local.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. package local
import (
	"context"
	"encoding/binary"
	"errors"
	"io"
	"math/rand"
	"syscall"
	"time"

	"github.com/sagernet/sing-box/adapter"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing-box/dns"
	"github.com/sagernet/sing-box/dns/transport/hosts"
	"github.com/sagernet/sing-box/log"
	"github.com/sagernet/sing-box/option"
	"github.com/sagernet/sing/common/buf"
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"

	mDNS "github.com/miekg/dns"
)
  20. var _ adapter.DNSTransport = (*Transport)(nil)
  21. type Transport struct {
  22. dns.TransportAdapter
  23. ctx context.Context
  24. hosts *hosts.File
  25. dialer N.Dialer
  26. }
  27. func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) {
  28. transportDialer, err := dns.NewLocalDialer(ctx, options)
  29. if err != nil {
  30. return nil, err
  31. }
  32. return &Transport{
  33. TransportAdapter: dns.NewTransportAdapterWithLocalOptions(C.DNSTypeLocal, tag, options),
  34. ctx: ctx,
  35. hosts: hosts.NewFile(hosts.DefaultPath),
  36. dialer: transportDialer,
  37. }, nil
  38. }
  39. func (t *Transport) Start(stage adapter.StartStage) error {
  40. return nil
  41. }
  42. func (t *Transport) Close() error {
  43. return nil
  44. }
  45. func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
  46. question := message.Question[0]
  47. domain := dns.FqdnToDomain(question.Name)
  48. if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
  49. addresses := t.hosts.Lookup(domain)
  50. if len(addresses) > 0 {
  51. return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil
  52. }
  53. }
  54. systemConfig := getSystemDNSConfig(t.ctx)
  55. if systemConfig.singleRequest || !(message.Question[0].Qtype == mDNS.TypeA || message.Question[0].Qtype == mDNS.TypeAAAA) {
  56. return t.exchangeSingleRequest(ctx, systemConfig, message, domain)
  57. } else {
  58. return t.exchangeParallel(ctx, systemConfig, message, domain)
  59. }
  60. }
  61. func (t *Transport) exchangeSingleRequest(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
  62. var lastErr error
  63. for _, fqdn := range systemConfig.nameList(domain) {
  64. response, err := t.tryOneName(ctx, systemConfig, fqdn, message)
  65. if err != nil {
  66. lastErr = err
  67. continue
  68. }
  69. return response, nil
  70. }
  71. return nil, lastErr
  72. }
  73. func (t *Transport) exchangeParallel(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
  74. returned := make(chan struct{})
  75. defer close(returned)
  76. type queryResult struct {
  77. response *mDNS.Msg
  78. err error
  79. }
  80. results := make(chan queryResult)
  81. startRacer := func(ctx context.Context, fqdn string) {
  82. response, err := t.tryOneName(ctx, systemConfig, fqdn, message)
  83. if err == nil {
  84. if response.Rcode != mDNS.RcodeSuccess {
  85. err = dns.RcodeError(response.Rcode)
  86. } else if len(dns.MessageToAddresses(response)) == 0 {
  87. err = E.New(fqdn, ": empty result")
  88. }
  89. }
  90. select {
  91. case results <- queryResult{response, err}:
  92. case <-returned:
  93. }
  94. }
  95. queryCtx, queryCancel := context.WithCancel(ctx)
  96. defer queryCancel()
  97. var nameCount int
  98. for _, fqdn := range systemConfig.nameList(domain) {
  99. nameCount++
  100. go startRacer(queryCtx, fqdn)
  101. }
  102. var errors []error
  103. for {
  104. select {
  105. case <-ctx.Done():
  106. return nil, ctx.Err()
  107. case result := <-results:
  108. if result.err == nil {
  109. return result.response, nil
  110. }
  111. errors = append(errors, result.err)
  112. if len(errors) == nameCount {
  113. return nil, E.Errors(errors...)
  114. }
  115. }
  116. }
  117. }
  118. func (t *Transport) tryOneName(ctx context.Context, config *dnsConfig, fqdn string, message *mDNS.Msg) (*mDNS.Msg, error) {
  119. serverOffset := config.serverOffset()
  120. sLen := uint32(len(config.servers))
  121. var lastErr error
  122. for i := 0; i < config.attempts; i++ {
  123. for j := uint32(0); j < sLen; j++ {
  124. server := config.servers[(serverOffset+j)%sLen]
  125. question := message.Question[0]
  126. question.Name = fqdn
  127. response, err := t.exchangeOne(ctx, M.ParseSocksaddr(server), question, config.timeout, config.useTCP, config.trustAD)
  128. if err != nil {
  129. lastErr = err
  130. continue
  131. }
  132. return response, nil
  133. }
  134. }
  135. return nil, E.Cause(lastErr, fqdn)
  136. }
  137. func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) {
  138. if server.Port == 0 {
  139. server.Port = 53
  140. }
  141. var networks []string
  142. if useTCP {
  143. networks = []string{N.NetworkTCP}
  144. } else {
  145. networks = []string{N.NetworkUDP, N.NetworkTCP}
  146. }
  147. request := &mDNS.Msg{
  148. MsgHdr: mDNS.MsgHdr{
  149. Id: uint16(rand.Uint32()),
  150. RecursionDesired: true,
  151. AuthenticatedData: ad,
  152. },
  153. Question: []mDNS.Question{question},
  154. Compress: true,
  155. }
  156. request.SetEdns0(buf.UDPBufferSize, false)
  157. buffer := buf.Get(buf.UDPBufferSize)
  158. defer buf.Put(buffer)
  159. for _, network := range networks {
  160. ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
  161. defer cancel()
  162. conn, err := t.dialer.DialContext(ctx, network, server)
  163. if err != nil {
  164. return nil, err
  165. }
  166. defer conn.Close()
  167. if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
  168. conn.SetDeadline(deadline)
  169. }
  170. rawMessage, err := request.PackBuffer(buffer)
  171. if err != nil {
  172. return nil, E.Cause(err, "pack request")
  173. }
  174. _, err = conn.Write(rawMessage)
  175. if err != nil {
  176. if errors.Is(err, syscall.EMSGSIZE) && network == N.NetworkUDP {
  177. continue
  178. }
  179. return nil, E.Cause(err, "write request")
  180. }
  181. n, err := conn.Read(buffer)
  182. if err != nil {
  183. return nil, E.Cause(err, "read response")
  184. }
  185. var response mDNS.Msg
  186. err = response.Unpack(buffer[:n])
  187. if err != nil {
  188. return nil, E.Cause(err, "unpack response")
  189. }
  190. if response.Truncated && network == N.NetworkUDP {
  191. continue
  192. }
  193. return &response, nil
  194. }
  195. panic("unexpected")
  196. }