local_shared.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. package local
  2. import (
  3. "context"
  4. "errors"
  5. "math/rand"
  6. "syscall"
  7. "time"
  8. "github.com/sagernet/sing-box/dns"
  9. "github.com/sagernet/sing-box/dns/transport"
  10. "github.com/sagernet/sing/common/buf"
  11. E "github.com/sagernet/sing/common/exceptions"
  12. M "github.com/sagernet/sing/common/metadata"
  13. N "github.com/sagernet/sing/common/network"
  14. mDNS "github.com/miekg/dns"
  15. )
  16. func (t *Transport) exchange(ctx context.Context, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
  17. systemConfig := getSystemDNSConfig(t.ctx)
  18. if systemConfig.singleRequest || !(message.Question[0].Qtype == mDNS.TypeA || message.Question[0].Qtype == mDNS.TypeAAAA) {
  19. return t.exchangeSingleRequest(ctx, systemConfig, message, domain)
  20. } else {
  21. return t.exchangeParallel(ctx, systemConfig, message, domain)
  22. }
  23. }
  24. func (t *Transport) exchangeSingleRequest(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
  25. var lastErr error
  26. for _, fqdn := range systemConfig.nameList(domain) {
  27. response, err := t.tryOneName(ctx, systemConfig, fqdn, message)
  28. if err != nil {
  29. lastErr = err
  30. continue
  31. }
  32. return response, nil
  33. }
  34. return nil, lastErr
  35. }
  36. func (t *Transport) exchangeParallel(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
  37. returned := make(chan struct{})
  38. defer close(returned)
  39. type queryResult struct {
  40. response *mDNS.Msg
  41. err error
  42. }
  43. results := make(chan queryResult)
  44. startRacer := func(ctx context.Context, fqdn string) {
  45. response, err := t.tryOneName(ctx, systemConfig, fqdn, message)
  46. if err == nil {
  47. if response.Rcode != mDNS.RcodeSuccess {
  48. err = dns.RcodeError(response.Rcode)
  49. } else if len(dns.MessageToAddresses(response)) == 0 {
  50. err = E.New(fqdn, ": empty result")
  51. }
  52. }
  53. select {
  54. case results <- queryResult{response, err}:
  55. case <-returned:
  56. }
  57. }
  58. queryCtx, queryCancel := context.WithCancel(ctx)
  59. defer queryCancel()
  60. var nameCount int
  61. for _, fqdn := range systemConfig.nameList(domain) {
  62. nameCount++
  63. go startRacer(queryCtx, fqdn)
  64. }
  65. var errors []error
  66. for {
  67. select {
  68. case <-ctx.Done():
  69. return nil, ctx.Err()
  70. case result := <-results:
  71. if result.err == nil {
  72. return result.response, nil
  73. }
  74. errors = append(errors, result.err)
  75. if len(errors) == nameCount {
  76. return nil, E.Errors(errors...)
  77. }
  78. }
  79. }
  80. }
  81. func (t *Transport) tryOneName(ctx context.Context, config *dnsConfig, fqdn string, message *mDNS.Msg) (*mDNS.Msg, error) {
  82. serverOffset := config.serverOffset()
  83. sLen := uint32(len(config.servers))
  84. var lastErr error
  85. for i := 0; i < config.attempts; i++ {
  86. for j := uint32(0); j < sLen; j++ {
  87. server := config.servers[(serverOffset+j)%sLen]
  88. question := message.Question[0]
  89. question.Name = fqdn
  90. response, err := t.exchangeOne(ctx, M.ParseSocksaddr(server), question, config.timeout, config.useTCP, config.trustAD)
  91. if err != nil {
  92. lastErr = err
  93. continue
  94. }
  95. return response, nil
  96. }
  97. }
  98. return nil, E.Cause(lastErr, fqdn)
  99. }
  100. func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) {
  101. if server.Port == 0 {
  102. server.Port = 53
  103. }
  104. request := &mDNS.Msg{
  105. MsgHdr: mDNS.MsgHdr{
  106. Id: uint16(rand.Uint32()),
  107. RecursionDesired: true,
  108. AuthenticatedData: ad,
  109. },
  110. Question: []mDNS.Question{question},
  111. Compress: true,
  112. }
  113. request.SetEdns0(buf.UDPBufferSize, false)
  114. if !useTCP {
  115. return t.exchangeUDP(ctx, server, request, timeout)
  116. } else {
  117. return t.exchangeTCP(ctx, server, request, timeout)
  118. }
  119. }
  120. func (t *Transport) exchangeUDP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg, timeout time.Duration) (*mDNS.Msg, error) {
  121. conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, server)
  122. if err != nil {
  123. return nil, err
  124. }
  125. defer conn.Close()
  126. if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
  127. newDeadline := time.Now().Add(timeout)
  128. if deadline.After(newDeadline) {
  129. deadline = newDeadline
  130. }
  131. conn.SetDeadline(deadline)
  132. }
  133. buffer := buf.Get(buf.UDPBufferSize)
  134. defer buf.Put(buffer)
  135. rawMessage, err := request.PackBuffer(buffer)
  136. if err != nil {
  137. return nil, E.Cause(err, "pack request")
  138. }
  139. _, err = conn.Write(rawMessage)
  140. if err != nil {
  141. if errors.Is(err, syscall.EMSGSIZE) {
  142. return t.exchangeTCP(ctx, server, request, timeout)
  143. }
  144. return nil, E.Cause(err, "write request")
  145. }
  146. n, err := conn.Read(buffer)
  147. if err != nil {
  148. if errors.Is(err, syscall.EMSGSIZE) {
  149. return t.exchangeTCP(ctx, server, request, timeout)
  150. }
  151. return nil, E.Cause(err, "read response")
  152. }
  153. var response mDNS.Msg
  154. err = response.Unpack(buffer[:n])
  155. if err != nil {
  156. return nil, E.Cause(err, "unpack response")
  157. }
  158. if response.Truncated {
  159. return t.exchangeTCP(ctx, server, request, timeout)
  160. }
  161. return &response, nil
  162. }
  163. func (t *Transport) exchangeTCP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg, timeout time.Duration) (*mDNS.Msg, error) {
  164. conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, server)
  165. if err != nil {
  166. return nil, err
  167. }
  168. defer conn.Close()
  169. if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
  170. newDeadline := time.Now().Add(timeout)
  171. if deadline.After(newDeadline) {
  172. deadline = newDeadline
  173. }
  174. conn.SetDeadline(deadline)
  175. }
  176. err = transport.WriteMessage(conn, 0, request)
  177. if err != nil {
  178. return nil, err
  179. }
  180. return transport.ReadMessage(conn)
  181. }