transport_udp.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. package dns
  2. import (
  3. "context"
  4. "os"
  5. "github.com/sagernet/sing-box/adapter"
  6. "github.com/sagernet/sing-box/log"
  7. "github.com/sagernet/sing/common"
  8. "github.com/sagernet/sing/common/buf"
  9. M "github.com/sagernet/sing/common/metadata"
  10. N "github.com/sagernet/sing/common/network"
  11. "github.com/sagernet/sing/common/task"
  12. "golang.org/x/net/dns/dnsmessage"
  13. )
  14. var _ adapter.DNSTransport = (*UDPTransport)(nil)
  15. type UDPTransport struct {
  16. myTransportAdapter
  17. }
  18. func NewUDPTransport(ctx context.Context, dialer N.Dialer, logger log.Logger, destination M.Socksaddr) *UDPTransport {
  19. return &UDPTransport{
  20. myTransportAdapter{
  21. ctx: ctx,
  22. dialer: dialer,
  23. logger: logger,
  24. destination: destination,
  25. done: make(chan struct{}),
  26. },
  27. }
  28. }
  29. func (t *UDPTransport) offer() (*dnsConnection, error) {
  30. t.access.RLock()
  31. connection := t.connection
  32. t.access.RUnlock()
  33. if connection != nil {
  34. select {
  35. case <-connection.done:
  36. default:
  37. return connection, nil
  38. }
  39. }
  40. t.access.Lock()
  41. connection = t.connection
  42. if connection != nil {
  43. select {
  44. case <-connection.done:
  45. default:
  46. t.access.Unlock()
  47. return connection, nil
  48. }
  49. }
  50. tcpConn, err := t.dialer.DialContext(t.ctx, "udp", t.destination)
  51. if err != nil {
  52. return nil, err
  53. }
  54. connection = &dnsConnection{
  55. Conn: tcpConn,
  56. done: make(chan struct{}),
  57. callbacks: make(map[uint16]chan *dnsmessage.Message),
  58. }
  59. t.connection = connection
  60. t.access.Unlock()
  61. go t.newConnection(connection)
  62. return connection, nil
  63. }
  64. func (t *UDPTransport) newConnection(conn *dnsConnection) {
  65. defer close(conn.done)
  66. defer conn.Close()
  67. err := task.Any(t.ctx, func(ctx context.Context) error {
  68. return t.loopIn(conn)
  69. }, func(ctx context.Context) error {
  70. select {
  71. case <-ctx.Done():
  72. return nil
  73. case <-t.done:
  74. return os.ErrClosed
  75. }
  76. })
  77. conn.err = err
  78. if err != nil {
  79. t.logger.Debug("connection closed: ", err)
  80. }
  81. }
  82. func (t *UDPTransport) loopIn(conn *dnsConnection) error {
  83. _buffer := buf.StackNewSize(1024)
  84. defer common.KeepAlive(_buffer)
  85. buffer := common.Dup(_buffer)
  86. defer buffer.Release()
  87. for {
  88. buffer.FullReset()
  89. _, err := buffer.ReadFrom(conn)
  90. if err != nil {
  91. return err
  92. }
  93. var message dnsmessage.Message
  94. err = message.Unpack(buffer.Bytes())
  95. if err != nil {
  96. return err
  97. }
  98. conn.access.Lock()
  99. callback, loaded := conn.callbacks[message.ID]
  100. if loaded {
  101. delete(conn.callbacks, message.ID)
  102. }
  103. conn.access.Unlock()
  104. if !loaded {
  105. continue
  106. }
  107. callback <- &message
  108. }
  109. }
  110. func (t *UDPTransport) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) {
  111. var connection *dnsConnection
  112. err := task.Run(ctx, func() error {
  113. var innerErr error
  114. connection, innerErr = t.offer()
  115. return innerErr
  116. })
  117. if err != nil {
  118. return nil, err
  119. }
  120. connection.access.Lock()
  121. connection.queryId++
  122. message.ID = connection.queryId
  123. callback := make(chan *dnsmessage.Message)
  124. connection.callbacks[message.ID] = callback
  125. connection.access.Unlock()
  126. _buffer := buf.StackNewSize(1024)
  127. defer common.KeepAlive(_buffer)
  128. buffer := common.Dup(_buffer)
  129. defer buffer.Release()
  130. rawMessage, err := message.AppendPack(buffer.Index(0))
  131. if err != nil {
  132. return nil, err
  133. }
  134. buffer.Truncate(len(rawMessage))
  135. err = task.Run(ctx, func() error {
  136. return common.Error(connection.Write(buffer.Bytes()))
  137. })
  138. if err != nil {
  139. return nil, err
  140. }
  141. select {
  142. case response := <-callback:
  143. return response, nil
  144. case <-connection.done:
  145. return nil, connection.err
  146. case <-ctx.Done():
  147. return nil, ctx.Err()
  148. }
  149. }