local.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. package local
  import (
  	"context"
  	"encoding/binary"
  	"io"
  	"math/rand"
  	"time"

  	"github.com/sagernet/sing-box/adapter"
  	C "github.com/sagernet/sing-box/constant"
  	"github.com/sagernet/sing-box/dns"
  	"github.com/sagernet/sing-box/dns/transport/hosts"
  	"github.com/sagernet/sing-box/log"
  	"github.com/sagernet/sing-box/option"
  	"github.com/sagernet/sing/common/buf"
  	E "github.com/sagernet/sing/common/exceptions"
  	M "github.com/sagernet/sing/common/metadata"
  	N "github.com/sagernet/sing/common/network"

  	mDNS "github.com/miekg/dns"
  )
  18. func RegisterTransport(registry *dns.TransportRegistry) {
  19. dns.RegisterTransport[option.LocalDNSServerOptions](registry, C.DNSTypeLocal, NewTransport)
  20. }
  21. var _ adapter.DNSTransport = (*Transport)(nil)
  22. type Transport struct {
  23. dns.TransportAdapter
  24. hosts *hosts.File
  25. dialer N.Dialer
  26. }
  27. func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) {
  28. transportDialer, err := dns.NewLocalDialer(ctx, options)
  29. if err != nil {
  30. return nil, err
  31. }
  32. return &Transport{
  33. TransportAdapter: dns.NewTransportAdapterWithLocalOptions(C.DNSTypeLocal, tag, options),
  34. hosts: hosts.NewFile(hosts.DefaultPath),
  35. dialer: transportDialer,
  36. }, nil
  37. }
  38. func (t *Transport) Reset() {
  39. }
  40. func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
  41. question := message.Question[0]
  42. domain := dns.FqdnToDomain(question.Name)
  43. if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
  44. addresses := t.hosts.Lookup(domain)
  45. if len(addresses) > 0 {
  46. return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil
  47. }
  48. }
  49. systemConfig := getSystemDNSConfig()
  50. if systemConfig.singleRequest || !(message.Question[0].Qtype == mDNS.TypeA || message.Question[0].Qtype == mDNS.TypeAAAA) {
  51. return t.exchangeSingleRequest(ctx, systemConfig, message, domain)
  52. } else {
  53. return t.exchangeParallel(ctx, systemConfig, message, domain)
  54. }
  55. }
  56. func (t *Transport) exchangeSingleRequest(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
  57. var lastErr error
  58. for _, fqdn := range systemConfig.nameList(domain) {
  59. response, err := t.tryOneName(ctx, systemConfig, fqdn, message)
  60. if err != nil {
  61. lastErr = err
  62. continue
  63. }
  64. return response, nil
  65. }
  66. return nil, lastErr
  67. }
  68. func (t *Transport) exchangeParallel(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
  69. returned := make(chan struct{})
  70. defer close(returned)
  71. type queryResult struct {
  72. response *mDNS.Msg
  73. err error
  74. }
  75. results := make(chan queryResult)
  76. startRacer := func(ctx context.Context, fqdn string) {
  77. response, err := t.tryOneName(ctx, systemConfig, fqdn, message)
  78. addresses, _ := dns.MessageToAddresses(response)
  79. if len(addresses) == 0 {
  80. err = E.New(fqdn, ": empty result")
  81. }
  82. select {
  83. case results <- queryResult{response, err}:
  84. case <-returned:
  85. }
  86. }
  87. queryCtx, queryCancel := context.WithCancel(ctx)
  88. defer queryCancel()
  89. var nameCount int
  90. for _, fqdn := range systemConfig.nameList(domain) {
  91. nameCount++
  92. go startRacer(queryCtx, fqdn)
  93. }
  94. var errors []error
  95. for {
  96. select {
  97. case <-ctx.Done():
  98. return nil, ctx.Err()
  99. case result := <-results:
  100. if result.err == nil {
  101. return result.response, nil
  102. }
  103. errors = append(errors, result.err)
  104. if len(errors) == nameCount {
  105. return nil, E.Errors(errors...)
  106. }
  107. }
  108. }
  109. }
  110. func (t *Transport) tryOneName(ctx context.Context, config *dnsConfig, fqdn string, message *mDNS.Msg) (*mDNS.Msg, error) {
  111. serverOffset := config.serverOffset()
  112. sLen := uint32(len(config.servers))
  113. var lastErr error
  114. for i := 0; i < config.attempts; i++ {
  115. for j := uint32(0); j < sLen; j++ {
  116. server := config.servers[(serverOffset+j)%sLen]
  117. question := message.Question[0]
  118. question.Name = fqdn
  119. response, err := t.exchangeOne(ctx, M.ParseSocksaddr(server), question, config.timeout, config.useTCP, config.trustAD)
  120. if err != nil {
  121. lastErr = err
  122. continue
  123. }
  124. return response, nil
  125. }
  126. }
  127. return nil, E.Cause(lastErr, fqdn)
  128. }
  129. func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) {
  130. var networks []string
  131. if useTCP {
  132. networks = []string{N.NetworkTCP}
  133. } else {
  134. networks = []string{N.NetworkUDP, N.NetworkTCP}
  135. }
  136. request := &mDNS.Msg{
  137. MsgHdr: mDNS.MsgHdr{
  138. Id: uint16(rand.Uint32()),
  139. RecursionDesired: true,
  140. AuthenticatedData: ad,
  141. },
  142. Question: []mDNS.Question{question},
  143. Compress: true,
  144. }
  145. request.SetEdns0(maxDNSPacketSize, false)
  146. buffer := buf.Get(buf.UDPBufferSize)
  147. defer buf.Put(buffer)
  148. for _, network := range networks {
  149. ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
  150. defer cancel()
  151. conn, err := t.dialer.DialContext(ctx, network, server)
  152. if err != nil {
  153. return nil, err
  154. }
  155. defer conn.Close()
  156. if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
  157. conn.SetDeadline(deadline)
  158. }
  159. rawMessage, err := request.PackBuffer(buffer)
  160. if err != nil {
  161. return nil, E.Cause(err, "pack request")
  162. }
  163. _, err = conn.Write(rawMessage)
  164. if err != nil {
  165. return nil, E.Cause(err, "write request")
  166. }
  167. n, err := conn.Read(buffer)
  168. if err != nil {
  169. return nil, E.Cause(err, "read response")
  170. }
  171. var response mDNS.Msg
  172. err = response.Unpack(buffer[:n])
  173. if err != nil {
  174. return nil, E.Cause(err, "unpack response")
  175. }
  176. if response.Truncated && network == N.NetworkUDP {
  177. continue
  178. }
  179. return &response, nil
  180. }
  181. panic("unexpected")
  182. }