udp.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. package transport
  2. import (
  3. "context"
  4. "net"
  5. "os"
  6. "sync"
  7. "github.com/sagernet/sing-box/adapter"
  8. C "github.com/sagernet/sing-box/constant"
  9. "github.com/sagernet/sing-box/dns"
  10. "github.com/sagernet/sing-box/log"
  11. "github.com/sagernet/sing-box/option"
  12. "github.com/sagernet/sing/common/buf"
  13. "github.com/sagernet/sing/common/logger"
  14. M "github.com/sagernet/sing/common/metadata"
  15. N "github.com/sagernet/sing/common/network"
  16. mDNS "github.com/miekg/dns"
  17. )
  18. var _ adapter.DNSTransport = (*UDPTransport)(nil)
  19. func RegisterUDP(registry *dns.TransportRegistry) {
  20. dns.RegisterTransport[option.RemoteDNSServerOptions](registry, C.DNSTypeUDP, NewUDP)
  21. }
  22. type UDPTransport struct {
  23. dns.TransportAdapter
  24. logger logger.ContextLogger
  25. dialer N.Dialer
  26. serverAddr M.Socksaddr
  27. udpSize int
  28. tcpTransport *TCPTransport
  29. access sync.Mutex
  30. conn *dnsConnection
  31. done chan struct{}
  32. }
  33. func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) {
  34. transportDialer, err := dns.NewRemoteDialer(ctx, options)
  35. if err != nil {
  36. return nil, err
  37. }
  38. serverAddr := options.ServerOptions.Build()
  39. if serverAddr.Port == 0 {
  40. serverAddr.Port = 53
  41. }
  42. return NewUDPRaw(logger, dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeUDP, tag, options), transportDialer, serverAddr), nil
  43. }
  44. func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr) *UDPTransport {
  45. return &UDPTransport{
  46. TransportAdapter: adapter,
  47. logger: logger,
  48. dialer: dialer,
  49. serverAddr: serverAddr,
  50. udpSize: 512,
  51. tcpTransport: &TCPTransport{
  52. dialer: dialer,
  53. serverAddr: serverAddr,
  54. },
  55. done: make(chan struct{}),
  56. }
  57. }
  58. func (t *UDPTransport) Reset() {
  59. t.access.Lock()
  60. defer t.access.Unlock()
  61. close(t.done)
  62. t.done = make(chan struct{})
  63. }
  64. func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
  65. response, err := t.exchange(ctx, message)
  66. if err != nil {
  67. return nil, err
  68. }
  69. if response.Truncated {
  70. t.logger.InfoContext(ctx, "response truncated, retrying with TCP")
  71. return t.tcpTransport.Exchange(ctx, message)
  72. }
  73. return response, nil
  74. }
  75. func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
  76. conn, err := t.open(ctx)
  77. if err != nil {
  78. return nil, err
  79. }
  80. if edns0Opt := message.IsEdns0(); edns0Opt != nil {
  81. if udpSize := int(edns0Opt.UDPSize()); udpSize > t.udpSize {
  82. t.udpSize = udpSize
  83. }
  84. }
  85. buffer := buf.NewSize(1 + message.Len())
  86. defer buffer.Release()
  87. exMessage := *message
  88. exMessage.Compress = true
  89. messageId := message.Id
  90. callback := &dnsCallback{
  91. done: make(chan struct{}),
  92. }
  93. conn.access.Lock()
  94. conn.queryId++
  95. exMessage.Id = conn.queryId
  96. conn.callbacks[exMessage.Id] = callback
  97. conn.access.Unlock()
  98. defer func() {
  99. conn.access.Lock()
  100. delete(conn.callbacks, messageId)
  101. conn.access.Unlock()
  102. callback.access.Lock()
  103. select {
  104. case <-callback.done:
  105. default:
  106. close(callback.done)
  107. }
  108. callback.access.Unlock()
  109. }()
  110. rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes())
  111. if err != nil {
  112. return nil, err
  113. }
  114. _, err = conn.Write(rawMessage)
  115. if err != nil {
  116. conn.Close(err)
  117. return nil, err
  118. }
  119. select {
  120. case <-callback.done:
  121. callback.message.Id = messageId
  122. return callback.message, nil
  123. case <-conn.done:
  124. return nil, conn.err
  125. case <-t.done:
  126. return nil, os.ErrClosed
  127. case <-ctx.Done():
  128. conn.Close(ctx.Err())
  129. return nil, ctx.Err()
  130. }
  131. }
  132. func (t *UDPTransport) open(ctx context.Context) (*dnsConnection, error) {
  133. t.access.Lock()
  134. defer t.access.Unlock()
  135. conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
  136. if err != nil {
  137. return nil, err
  138. }
  139. dnsConn := &dnsConnection{
  140. Conn: conn,
  141. done: make(chan struct{}),
  142. callbacks: make(map[uint16]*dnsCallback),
  143. }
  144. go t.recvLoop(dnsConn)
  145. return dnsConn, nil
  146. }
  147. func (t *UDPTransport) recvLoop(conn *dnsConnection) {
  148. for {
  149. buffer := buf.NewSize(t.udpSize)
  150. _, err := buffer.ReadOnceFrom(conn)
  151. if err != nil {
  152. buffer.Release()
  153. conn.Close(err)
  154. return
  155. }
  156. var message mDNS.Msg
  157. err = message.Unpack(buffer.Bytes())
  158. buffer.Release()
  159. if err != nil {
  160. conn.Close(err)
  161. return
  162. }
  163. conn.access.RLock()
  164. callback, loaded := conn.callbacks[message.Id]
  165. conn.access.RUnlock()
  166. if !loaded {
  167. continue
  168. }
  169. callback.access.Lock()
  170. select {
  171. case <-callback.done:
  172. default:
  173. callback.message = &message
  174. close(callback.done)
  175. }
  176. callback.access.Unlock()
  177. }
  178. }
  179. type dnsConnection struct {
  180. net.Conn
  181. access sync.RWMutex
  182. done chan struct{}
  183. closeOnce sync.Once
  184. err error
  185. queryId uint16
  186. callbacks map[uint16]*dnsCallback
  187. }
  188. func (c *dnsConnection) Close(err error) {
  189. c.access.Lock()
  190. defer c.access.Unlock()
  191. c.closeOnce.Do(func() {
  192. close(c.done)
  193. c.err = err
  194. })
  195. c.Conn.Close()
  196. }
  197. type dnsCallback struct {
  198. access sync.Mutex
  199. message *mDNS.Msg
  200. done chan struct{}
  201. }