transport_tcp.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. package dns
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "net"
  6. "os"
  7. "sync"
  8. "github.com/sagernet/sing/common"
  9. "github.com/sagernet/sing/common/buf"
  10. E "github.com/sagernet/sing/common/exceptions"
  11. M "github.com/sagernet/sing/common/metadata"
  12. N "github.com/sagernet/sing/common/network"
  13. "github.com/sagernet/sing/common/task"
  14. "github.com/sagernet/sing-box/adapter"
  15. "github.com/sagernet/sing-box/log"
  16. "golang.org/x/net/dns/dnsmessage"
  17. )
  18. var _ adapter.DNSTransport = (*TCPTransport)(nil)
  19. type TCPTransport struct {
  20. myTransportAdapter
  21. }
  22. func NewTCPTransport(ctx context.Context, dialer N.Dialer, logger log.Logger, destination M.Socksaddr) *TCPTransport {
  23. return &TCPTransport{
  24. myTransportAdapter{
  25. ctx: ctx,
  26. dialer: dialer,
  27. logger: logger,
  28. destination: destination,
  29. done: make(chan struct{}),
  30. },
  31. }
  32. }
  33. func (t *TCPTransport) offer() (*dnsConnection, error) {
  34. t.access.RLock()
  35. connection := t.connection
  36. t.access.RUnlock()
  37. if connection != nil {
  38. select {
  39. case <-connection.done:
  40. default:
  41. return connection, nil
  42. }
  43. }
  44. t.access.Lock()
  45. connection = t.connection
  46. if connection != nil {
  47. select {
  48. case <-connection.done:
  49. default:
  50. t.access.Unlock()
  51. return connection, nil
  52. }
  53. }
  54. tcpConn, err := t.dialer.DialContext(t.ctx, "tcp", t.destination)
  55. if err != nil {
  56. return nil, err
  57. }
  58. connection = &dnsConnection{
  59. Conn: tcpConn,
  60. done: make(chan struct{}),
  61. callbacks: make(map[uint16]chan *dnsmessage.Message),
  62. }
  63. t.connection = connection
  64. t.access.Unlock()
  65. go t.newConnection(connection)
  66. return connection, nil
  67. }
  68. func (t *TCPTransport) newConnection(conn *dnsConnection) {
  69. defer close(conn.done)
  70. defer conn.Close()
  71. err := task.Any(t.ctx, func(ctx context.Context) error {
  72. return t.loopIn(conn)
  73. }, func(ctx context.Context) error {
  74. select {
  75. case <-ctx.Done():
  76. return nil
  77. case <-t.done:
  78. return os.ErrClosed
  79. }
  80. })
  81. conn.err = err
  82. if err != nil {
  83. t.logger.Debug("connection closed: ", err)
  84. }
  85. }
  86. func (t *TCPTransport) loopIn(conn *dnsConnection) error {
  87. _buffer := buf.StackNewSize(1024)
  88. defer common.KeepAlive(_buffer)
  89. buffer := common.Dup(_buffer)
  90. defer buffer.Release()
  91. for {
  92. buffer.FullReset()
  93. _, err := buffer.ReadFullFrom(conn, 2)
  94. if err != nil {
  95. return err
  96. }
  97. length := binary.BigEndian.Uint16(buffer.Bytes())
  98. if length > 512 {
  99. return E.New("invalid length received: ", length)
  100. }
  101. buffer.FullReset()
  102. _, err = buffer.ReadFullFrom(conn, int(length))
  103. if err != nil {
  104. return err
  105. }
  106. var message dnsmessage.Message
  107. err = message.Unpack(buffer.Bytes())
  108. if err != nil {
  109. return err
  110. }
  111. conn.access.Lock()
  112. callback, loaded := conn.callbacks[message.ID]
  113. if loaded {
  114. delete(conn.callbacks, message.ID)
  115. }
  116. conn.access.Unlock()
  117. if !loaded {
  118. continue
  119. }
  120. callback <- &message
  121. }
  122. }
  123. type dnsConnection struct {
  124. net.Conn
  125. done chan struct{}
  126. err error
  127. access sync.Mutex
  128. queryId uint16
  129. callbacks map[uint16]chan *dnsmessage.Message
  130. }
  131. func (t *TCPTransport) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) {
  132. var connection *dnsConnection
  133. err := task.Run(ctx, func() error {
  134. var innerErr error
  135. connection, innerErr = t.offer()
  136. return innerErr
  137. })
  138. if err != nil {
  139. return nil, err
  140. }
  141. connection.access.Lock()
  142. connection.queryId++
  143. message.ID = connection.queryId
  144. callback := make(chan *dnsmessage.Message)
  145. connection.callbacks[message.ID] = callback
  146. connection.access.Unlock()
  147. _buffer := buf.StackNewSize(1024)
  148. defer common.KeepAlive(_buffer)
  149. buffer := common.Dup(_buffer)
  150. defer buffer.Release()
  151. length := buffer.Extend(2)
  152. rawMessage, err := message.AppendPack(buffer.Index(2))
  153. if err != nil {
  154. return nil, err
  155. }
  156. buffer.Truncate(2 + len(rawMessage))
  157. binary.BigEndian.PutUint16(length, uint16(len(rawMessage)))
  158. err = task.Run(ctx, func() error {
  159. return common.Error(connection.Write(buffer.Bytes()))
  160. })
  161. if err != nil {
  162. return nil, err
  163. }
  164. select {
  165. case response := <-callback:
  166. return response, nil
  167. case <-connection.done:
  168. return nil, connection.err
  169. case <-ctx.Done():
  170. return nil, ctx.Err()
  171. }
  172. }