tcp.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. package transport
import (
	"context"
	"encoding/binary"
	"errors"
	"io"

	mDNS "github.com/miekg/dns"
	"github.com/sagernet/sing-box/adapter"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing-box/dns"
	"github.com/sagernet/sing-box/log"
	"github.com/sagernet/sing-box/option"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
)
  17. var _ adapter.DNSTransport = (*TCPTransport)(nil)
  18. func RegisterTCP(registry *dns.TransportRegistry) {
  19. dns.RegisterTransport[option.RemoteDNSServerOptions](registry, C.DNSTypeTCP, NewTCP)
  20. }
  21. type TCPTransport struct {
  22. dns.TransportAdapter
  23. dialer N.Dialer
  24. serverAddr M.Socksaddr
  25. }
  26. func NewTCP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) {
  27. transportDialer, err := dns.NewRemoteDialer(ctx, options)
  28. if err != nil {
  29. return nil, err
  30. }
  31. serverAddr := options.ServerOptions.Build()
  32. if serverAddr.Port == 0 {
  33. serverAddr.Port = 53
  34. }
  35. return &TCPTransport{
  36. TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeTCP, tag, options),
  37. dialer: transportDialer,
  38. serverAddr: serverAddr,
  39. }, nil
  40. }
  41. func (t *TCPTransport) Reset() {
  42. }
  43. func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
  44. conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
  45. if err != nil {
  46. return nil, err
  47. }
  48. defer conn.Close()
  49. err = WriteMessage(conn, 0, message)
  50. if err != nil {
  51. return nil, err
  52. }
  53. return ReadMessage(conn)
  54. }
  55. func ReadMessage(reader io.Reader) (*mDNS.Msg, error) {
  56. var responseLen uint16
  57. err := binary.Read(reader, binary.BigEndian, &responseLen)
  58. if err != nil {
  59. return nil, err
  60. }
  61. if responseLen < 10 {
  62. return nil, mDNS.ErrShortRead
  63. }
  64. buffer := buf.NewSize(int(responseLen))
  65. defer buffer.Release()
  66. _, err = buffer.ReadFullFrom(reader, int(responseLen))
  67. if err != nil {
  68. return nil, err
  69. }
  70. var message mDNS.Msg
  71. err = message.Unpack(buffer.Bytes())
  72. return &message, err
  73. }
  74. func WriteMessage(writer io.Writer, messageId uint16, message *mDNS.Msg) error {
  75. requestLen := message.Len()
  76. buffer := buf.NewSize(3 + requestLen)
  77. defer buffer.Release()
  78. common.Must(binary.Write(buffer, binary.BigEndian, uint16(requestLen)))
  79. exMessage := *message
  80. exMessage.Id = messageId
  81. exMessage.Compress = true
  82. rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes())
  83. if err != nil {
  84. return err
  85. }
  86. buffer.Truncate(2 + len(rawMessage))
  87. return common.Error(writer.Write(buffer.Bytes()))
  88. }