transport_tls.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. package dns
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "encoding/binary"
  6. "os"
  7. "github.com/sagernet/sing/common"
  8. "github.com/sagernet/sing/common/buf"
  9. E "github.com/sagernet/sing/common/exceptions"
  10. M "github.com/sagernet/sing/common/metadata"
  11. N "github.com/sagernet/sing/common/network"
  12. "github.com/sagernet/sing/common/task"
  13. "github.com/sagernet/sing-box/adapter"
  14. "github.com/sagernet/sing-box/log"
  15. "golang.org/x/net/dns/dnsmessage"
  16. )
  17. var _ adapter.DNSTransport = (*TLSTransport)(nil)
  18. type TLSTransport struct {
  19. myTransportAdapter
  20. }
  21. func NewTLSTransport(ctx context.Context, dialer N.Dialer, logger log.Logger, destination M.Socksaddr) *TLSTransport {
  22. return &TLSTransport{
  23. myTransportAdapter{
  24. ctx: ctx,
  25. dialer: dialer,
  26. logger: logger,
  27. destination: destination,
  28. done: make(chan struct{}),
  29. },
  30. }
  31. }
  32. func (t *TLSTransport) offer(ctx context.Context) (*dnsConnection, error) {
  33. t.access.RLock()
  34. connection := t.connection
  35. t.access.RUnlock()
  36. if connection != nil {
  37. select {
  38. case <-connection.done:
  39. default:
  40. return connection, nil
  41. }
  42. }
  43. t.access.Lock()
  44. connection = t.connection
  45. if connection != nil {
  46. select {
  47. case <-connection.done:
  48. default:
  49. t.access.Unlock()
  50. return connection, nil
  51. }
  52. }
  53. tcpConn, err := t.dialer.DialContext(t.ctx, "tcp", t.destination)
  54. if err != nil {
  55. return nil, err
  56. }
  57. tlsConn := tls.Client(tcpConn, &tls.Config{
  58. ServerName: t.destination.AddrString(),
  59. })
  60. err = task.Run(t.ctx, func() error {
  61. return tlsConn.HandshakeContext(ctx)
  62. })
  63. if err != nil {
  64. return nil, err
  65. }
  66. connection = &dnsConnection{
  67. Conn: tlsConn,
  68. done: make(chan struct{}),
  69. callbacks: make(map[uint16]chan *dnsmessage.Message),
  70. }
  71. t.connection = connection
  72. t.access.Unlock()
  73. go t.newConnection(connection)
  74. return connection, nil
  75. }
  76. func (t *TLSTransport) newConnection(conn *dnsConnection) {
  77. defer close(conn.done)
  78. defer conn.Close()
  79. err := task.Any(t.ctx, func(ctx context.Context) error {
  80. return t.loopIn(conn)
  81. }, func(ctx context.Context) error {
  82. select {
  83. case <-ctx.Done():
  84. return nil
  85. case <-t.done:
  86. return os.ErrClosed
  87. }
  88. })
  89. conn.err = err
  90. if err != nil {
  91. t.logger.Debug("connection closed: ", err)
  92. }
  93. }
  94. func (t *TLSTransport) loopIn(conn *dnsConnection) error {
  95. _buffer := buf.StackNewSize(1024)
  96. defer common.KeepAlive(_buffer)
  97. buffer := common.Dup(_buffer)
  98. defer buffer.Release()
  99. for {
  100. buffer.FullReset()
  101. _, err := buffer.ReadFullFrom(conn, 2)
  102. if err != nil {
  103. return err
  104. }
  105. length := binary.BigEndian.Uint16(buffer.Bytes())
  106. if length > 512 {
  107. return E.New("invalid length received: ", length)
  108. }
  109. buffer.FullReset()
  110. _, err = buffer.ReadFullFrom(conn, int(length))
  111. if err != nil {
  112. return err
  113. }
  114. var message dnsmessage.Message
  115. err = message.Unpack(buffer.Bytes())
  116. if err != nil {
  117. return err
  118. }
  119. conn.access.Lock()
  120. callback, loaded := conn.callbacks[message.ID]
  121. if loaded {
  122. delete(conn.callbacks, message.ID)
  123. }
  124. conn.access.Unlock()
  125. if !loaded {
  126. continue
  127. }
  128. callback <- &message
  129. }
  130. }
  131. func (t *TLSTransport) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) {
  132. var connection *dnsConnection
  133. err := task.Run(ctx, func() error {
  134. var innerErr error
  135. connection, innerErr = t.offer(ctx)
  136. return innerErr
  137. })
  138. if err != nil {
  139. return nil, err
  140. }
  141. connection.access.Lock()
  142. connection.queryId++
  143. message.ID = connection.queryId
  144. callback := make(chan *dnsmessage.Message)
  145. connection.callbacks[message.ID] = callback
  146. connection.access.Unlock()
  147. _buffer := buf.StackNewSize(1024)
  148. defer common.KeepAlive(_buffer)
  149. buffer := common.Dup(_buffer)
  150. defer buffer.Release()
  151. length := buffer.Extend(2)
  152. rawMessage, err := message.AppendPack(buffer.Index(2))
  153. if err != nil {
  154. return nil, err
  155. }
  156. buffer.Truncate(2 + len(rawMessage))
  157. binary.BigEndian.PutUint16(length, uint16(len(rawMessage)))
  158. err = task.Run(ctx, func() error {
  159. return common.Error(connection.Write(buffer.Bytes()))
  160. })
  161. if err != nil {
  162. return nil, err
  163. }
  164. select {
  165. case response := <-callback:
  166. return response, nil
  167. case <-connection.done:
  168. return nil, connection.err
  169. case <-ctx.Done():
  170. return nil, ctx.Err()
  171. }
  172. }