tcp.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. package transport
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "io"
  6. "github.com/sagernet/sing-box/adapter"
  7. "github.com/sagernet/sing-box/common/dialer"
  8. C "github.com/sagernet/sing-box/constant"
  9. "github.com/sagernet/sing-box/dns"
  10. "github.com/sagernet/sing-box/log"
  11. "github.com/sagernet/sing-box/option"
  12. "github.com/sagernet/sing/common"
  13. "github.com/sagernet/sing/common/buf"
  14. E "github.com/sagernet/sing/common/exceptions"
  15. M "github.com/sagernet/sing/common/metadata"
  16. N "github.com/sagernet/sing/common/network"
  17. mDNS "github.com/miekg/dns"
  18. )
  19. var _ adapter.DNSTransport = (*TCPTransport)(nil)
  20. func RegisterTCP(registry *dns.TransportRegistry) {
  21. dns.RegisterTransport[option.RemoteDNSServerOptions](registry, C.DNSTypeTCP, NewTCP)
  22. }
  23. type TCPTransport struct {
  24. dns.TransportAdapter
  25. dialer N.Dialer
  26. serverAddr M.Socksaddr
  27. }
  // NewTCP builds a TCPTransport from options.
//
// It creates the remote dialer first and returns its error unchanged if that
// fails. The server address is then built from options. Port 0 is replaced by
// 53, and an address that is still invalid afterwards is rejected with an
// "invalid server address" error. logger is accepted to satisfy the
// constructor signature but is not used here.
func NewTCP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) {
	remoteDialer, dialErr := dns.NewRemoteDialer(ctx, options)
	if dialErr != nil {
		return nil, dialErr
	}

	address := options.DNSServerAddressOptions.Build()
	if address.Port == 0 {
		// No explicit port: use the standard DNS port.
		address.Port = 53
	}
	if !address.IsValid() {
		return nil, E.New("invalid server address: ", address)
	}

	transport := &TCPTransport{
		TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeTCP, tag, options),
		dialer:           remoteDialer,
		serverAddr:       address,
	}
	return transport, nil
}
  46. func (t *TCPTransport) Start(stage adapter.StartStage) error {
  47. if stage != adapter.StartStateStart {
  48. return nil
  49. }
  50. return dialer.InitializeDetour(t.dialer)
  51. }
  52. func (t *TCPTransport) Close() error {
  53. return nil
  54. }
  // Exchange sends message to the configured server over a newly dialed TCP
// connection and returns the decoded reply. The connection is closed before
// Exchange returns.
//
// The query is written with message ID 0 (see WriteMessage). Dial, write and
// read errors are returned unchanged.
//
// If ctx carries a deadline, it is also applied to the connection. The context
// given to DialContext is not guaranteed to bound the I/O that follows
// (for net.Dialer it only covers connection setup). Without this, an
// unresponsive server would block the read indefinitely.
func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
	conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	if deadline, ok := ctx.Deadline(); ok {
		// Best effort: some wrapped connections may not support deadlines.
		// In that case, keep the previous unbounded behavior rather than
		// failing the query.
		_ = conn.SetDeadline(deadline)
	}
	err = WriteMessage(conn, 0, message)
	if err != nil {
		return nil, err
	}
	return ReadMessage(conn)
}
  // ReadMessage reads one DNS message in DNS-over-TCP framing from reader: a
// two-byte big-endian length prefix followed by that many bytes of packed
// message, which are then unpacked.
//
// It returns mDNS.ErrShortRead when the declared length is smaller than a DNS
// header. Otherwise it returns any read or unpack error unchanged. Whenever
// the error is non-nil, the returned message is nil.
func ReadMessage(reader io.Reader) (*mDNS.Msg, error) {
	// Every DNS message begins with a fixed 12-byte header (RFC 1035 §4.1.1),
	// so a shorter payload cannot be a valid message.
	const headerLen = 12

	var responseLen uint16
	err := binary.Read(reader, binary.BigEndian, &responseLen)
	if err != nil {
		return nil, err
	}
	if responseLen < headerLen {
		return nil, mDNS.ErrShortRead
	}
	buffer := buf.NewSize(int(responseLen))
	defer buffer.Release()
	_, err = buffer.ReadFullFrom(reader, int(responseLen))
	if err != nil {
		return nil, err
	}
	var message mDNS.Msg
	// NOTE(review): buffer is released when this function returns, so this
	// relies on Unpack copying data out rather than aliasing the buffer.
	if err = message.Unpack(buffer.Bytes()); err != nil {
		return nil, err
	}
	return &message, nil
}
  // WriteMessage writes message to writer in DNS-over-TCP framing: a two-byte
// big-endian length prefix followed by the packed message.
//
// A shallow copy of message is packed with compression enabled and its ID set
// to messageId. The Msg struct passed in keeps its own ID and Compress flag.
// The length prefix is taken from the actual packed output, so it always
// matches the number of payload bytes that follow. Messages whose packed size
// exceeds 65535 bytes cannot be framed and produce an error. Pack and write
// errors are returned unchanged.
func WriteMessage(writer io.Writer, messageId uint16, message *mDNS.Msg) error {
	exMessage := *message
	exMessage.Id = messageId

	// PackBuffer packs the message uncompressed and then compresses it. It
	// only packs into the slice we supply if that slice holds the uncompressed
	// length plus one byte; otherwise it allocates its own. Len() with
	// Compress disabled yields that uncompressed length.
	exMessage.Compress = false
	packLen := exMessage.Len() + 1
	exMessage.Compress = true

	buffer := buf.NewSize(2 + packLen)
	defer buffer.Release()
	frame := buffer.FreeBytes()
	rawMessage, err := exMessage.PackBuffer(frame[2:])
	if err != nil {
		return err
	}
	if len(rawMessage) > 0xffff {
		return E.New("DNS message too large for TCP framing: ", len(rawMessage))
	}
	binary.BigEndian.PutUint16(frame, uint16(len(rawMessage)))
	// rawMessage normally already aliases frame[2:], making this a no-op.
	// If PackBuffer ever allocated its own slice, the copy still places the
	// payload right after the prefix. The packed size never exceeds the
	// uncompressed size, so it always fits.
	copy(frame[2:], rawMessage)
	buffer.Truncate(2 + len(rawMessage))
	return common.Error(writer.Write(buffer.Bytes()))
}