udp.go 4.9 KB


  1. package transport
  2. import (
  3. "context"
  4. "net"
  5. "os"
  6. "sync"
  7. "github.com/sagernet/sing-box/adapter"
  8. C "github.com/sagernet/sing-box/constant"
  9. "github.com/sagernet/sing-box/dns"
  10. "github.com/sagernet/sing-box/log"
  11. "github.com/sagernet/sing-box/option"
  12. "github.com/sagernet/sing/common/buf"
  13. "github.com/sagernet/sing/common/logger"
  14. M "github.com/sagernet/sing/common/metadata"
  15. N "github.com/sagernet/sing/common/network"
  16. mDNS "github.com/miekg/dns"
  17. )
  // Compile-time assertion that *UDPTransport satisfies adapter.DNSTransport;
  // the build fails here if a required method is missing.
  18. var _ adapter.DNSTransport = (*UDPTransport)(nil)
  19. func RegisterUDP(registry *dns.TransportRegistry) {
  20. dns.RegisterTransport[option.RemoteDNSServerOptions](registry, C.DNSTypeUDP, NewUDP)
  21. }
  // UDPTransport is a DNS transport that sends queries to serverAddr over a
  // single shared UDP connection. Concurrent queries on that connection are
  // told apart by rewriting their message IDs. Responses with the truncated
  // bit set are retried over TCP through tcpTransport.
  22. type UDPTransport struct {
  23. dns.TransportAdapter
  24. logger logger.ContextLogger
  25. dialer N.Dialer
  26. serverAddr M.Socksaddr
  // udpSize is the receive buffer size used by recvLoop. NewUDPRaw starts it
  // at 512, and exchange raises it when a query advertises a larger EDNS0 UDP size.
  27. udpSize int
  // tcpTransport re-sends queries whose UDP response came back truncated.
  28. tcpTransport *TCPTransport
  // access is held by open (around conn) and by Reset (around done).
  29. access sync.Mutex
  // conn is the current shared connection. open replaces it once its done
  // channel has been closed.
  30. conn *dnsConnection
  // done is closed by Reset, which makes waiting exchanges return
  // os.ErrClosed, and is then replaced with a fresh channel.
  31. done chan struct{}
  32. }
  33. func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) {
  34. transportDialer, err := dns.NewRemoteDialer(ctx, options)
  35. if err != nil {
  36. return nil, err
  37. }
  38. serverAddr := options.DNSServerAddressOptions.Build()
  39. if serverAddr.Port == 0 {
  40. serverAddr.Port = 53
  41. }
  42. return NewUDPRaw(logger, dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeUDP, tag, options), transportDialer, serverAddr), nil
  43. }
  44. func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr) *UDPTransport {
  45. return &UDPTransport{
  46. TransportAdapter: adapter,
  47. logger: logger,
  48. dialer: dialer,
  49. serverAddr: serverAddr,
  50. udpSize: 512,
  51. tcpTransport: &TCPTransport{
  52. dialer: dialer,
  53. serverAddr: serverAddr,
  54. },
  55. done: make(chan struct{}),
  56. }
  57. }
  58. func (t *UDPTransport) Reset() {
  59. t.access.Lock()
  60. defer t.access.Unlock()
  61. close(t.done)
  62. t.done = make(chan struct{})
  63. }
  64. func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
  65. response, err := t.exchange(ctx, message)
  66. if err != nil {
  67. return nil, err
  68. }
  69. if response.Truncated {
  70. t.logger.InfoContext(ctx, "response truncated, retrying with TCP")
  71. return t.tcpTransport.Exchange(ctx, message)
  72. }
  73. return response, nil
  74. }
  75. func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
  76. conn, err := t.open(ctx)
  77. if err != nil {
  78. return nil, err
  79. }
  80. if edns0Opt := message.IsEdns0(); edns0Opt != nil {
  81. if udpSize := int(edns0Opt.UDPSize()); udpSize > t.udpSize {
  82. t.udpSize = udpSize
  83. }
  84. }
  85. buffer := buf.NewSize(1 + message.Len())
  86. defer buffer.Release()
  87. exMessage := *message
  88. exMessage.Compress = true
  89. messageId := message.Id
  90. callback := &dnsCallback{
  91. done: make(chan struct{}),
  92. }
  93. conn.access.Lock()
  94. conn.queryId++
  95. exMessage.Id = conn.queryId
  96. conn.callbacks[exMessage.Id] = callback
  97. conn.access.Unlock()
  98. defer func() {
  99. conn.access.Lock()
  100. delete(conn.callbacks, messageId)
  101. conn.access.Unlock()
  102. }()
  103. rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes())
  104. if err != nil {
  105. return nil, err
  106. }
  107. _, err = conn.Write(rawMessage)
  108. if err != nil {
  109. conn.Close(err)
  110. return nil, err
  111. }
  112. select {
  113. case <-callback.done:
  114. callback.message.Id = messageId
  115. return callback.message, nil
  116. case <-conn.done:
  117. return nil, conn.err
  118. case <-t.done:
  119. return nil, os.ErrClosed
  120. case <-ctx.Done():
  121. conn.Close(ctx.Err())
  122. return nil, ctx.Err()
  123. }
  124. }
  125. func (t *UDPTransport) open(ctx context.Context) (*dnsConnection, error) {
  126. t.access.Lock()
  127. defer t.access.Unlock()
  128. if t.conn != nil {
  129. select {
  130. case <-t.conn.done:
  131. default:
  132. return t.conn, nil
  133. }
  134. }
  135. conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
  136. if err != nil {
  137. return nil, err
  138. }
  139. dnsConn := &dnsConnection{
  140. Conn: conn,
  141. done: make(chan struct{}),
  142. callbacks: make(map[uint16]*dnsCallback),
  143. }
  144. go t.recvLoop(dnsConn)
  145. t.conn = dnsConn
  146. return dnsConn, nil
  147. }
  148. func (t *UDPTransport) recvLoop(conn *dnsConnection) {
  149. for {
  150. buffer := buf.NewSize(t.udpSize)
  151. _, err := buffer.ReadOnceFrom(conn)
  152. if err != nil {
  153. buffer.Release()
  154. conn.Close(err)
  155. return
  156. }
  157. var message mDNS.Msg
  158. err = message.Unpack(buffer.Bytes())
  159. buffer.Release()
  160. if err != nil {
  161. conn.Close(err)
  162. return
  163. }
  164. conn.access.RLock()
  165. callback, loaded := conn.callbacks[message.Id]
  166. conn.access.RUnlock()
  167. if !loaded {
  168. continue
  169. }
  170. callback.access.Lock()
  171. select {
  172. case <-callback.done:
  173. default:
  174. callback.message = &message
  175. close(callback.done)
  176. }
  177. callback.access.Unlock()
  178. }
  179. }
  // dnsConnection wraps one shared UDP socket together with the bookkeeping
  // needed to route each response back to the exchange that sent the query.
  180. type dnsConnection struct {
  181. net.Conn
  // access guards queryId and callbacks.
  182. access sync.RWMutex
  // done is closed, at most once, by Close when the connection is torn down.
  183. done chan struct{}
  184. closeOnce sync.Once
  // err is the error recorded by the first Close call.
  185. err error
  // queryId is the last connection-local message ID handed out; as a uint16
  // it wraps around after 65535.
  186. queryId uint16
  // callbacks maps an in-flight query's connection-local ID to its pending callback.
  187. callbacks map[uint16]*dnsCallback
  188. }
  189. func (c *dnsConnection) Close(err error) {
  190. c.closeOnce.Do(func() {
  191. close(c.done)
  192. c.err = err
  193. })
  194. c.Conn.Close()
  195. }
  // dnsCallback is the hand-off point between an exchange waiting for a
  // response and recvLoop delivering it.
  196. type dnsCallback struct {
  // access serializes delivery, so message is set and done is closed at most once.
  197. access sync.Mutex
  // message is the decoded response; only valid after done has been closed.
  198. message *mDNS.Msg
  // done is closed by recvLoop once message has been set.
  199. done chan struct{}
  200. }