dhcp_shared.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. package dhcp
  2. import (
  3. "context"
  4. "math/rand"
  5. "strings"
  6. "time"
  7. C "github.com/sagernet/sing-box/constant"
  8. "github.com/sagernet/sing-box/dns"
  9. "github.com/sagernet/sing/common/buf"
  10. E "github.com/sagernet/sing/common/exceptions"
  11. M "github.com/sagernet/sing/common/metadata"
  12. N "github.com/sagernet/sing/common/network"
  13. mDNS "github.com/miekg/dns"
  14. )
  15. const (
  16. // net.maxDNSPacketSize
  17. maxDNSPacketSize = 1232
  18. )
  19. func (t *Transport) exchangeSingleRequest(ctx context.Context, servers []M.Socksaddr, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
  20. var lastErr error
  21. for _, fqdn := range t.nameList(domain) {
  22. response, err := t.tryOneName(ctx, servers, fqdn, message)
  23. if err != nil {
  24. lastErr = err
  25. continue
  26. }
  27. return response, nil
  28. }
  29. return nil, lastErr
  30. }
  31. func (t *Transport) exchangeParallel(ctx context.Context, servers []M.Socksaddr, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
  32. returned := make(chan struct{})
  33. defer close(returned)
  34. type queryResult struct {
  35. response *mDNS.Msg
  36. err error
  37. }
  38. results := make(chan queryResult)
  39. startRacer := func(ctx context.Context, fqdn string) {
  40. response, err := t.tryOneName(ctx, servers, fqdn, message)
  41. if err == nil {
  42. if response.Rcode != mDNS.RcodeSuccess {
  43. err = dns.RcodeError(response.Rcode)
  44. } else if len(dns.MessageToAddresses(response)) == 0 {
  45. err = E.New(fqdn, ": empty result")
  46. }
  47. }
  48. select {
  49. case results <- queryResult{response, err}:
  50. case <-returned:
  51. }
  52. }
  53. queryCtx, queryCancel := context.WithCancel(ctx)
  54. defer queryCancel()
  55. var nameCount int
  56. for _, fqdn := range t.nameList(domain) {
  57. nameCount++
  58. go startRacer(queryCtx, fqdn)
  59. }
  60. var errors []error
  61. for {
  62. select {
  63. case <-ctx.Done():
  64. return nil, ctx.Err()
  65. case result := <-results:
  66. if result.err == nil {
  67. return result.response, nil
  68. }
  69. errors = append(errors, result.err)
  70. if len(errors) == nameCount {
  71. return nil, E.Errors(errors...)
  72. }
  73. }
  74. }
  75. }
  76. func (t *Transport) tryOneName(ctx context.Context, servers []M.Socksaddr, fqdn string, message *mDNS.Msg) (*mDNS.Msg, error) {
  77. sLen := len(servers)
  78. var lastErr error
  79. for i := 0; i < t.attempts; i++ {
  80. for j := 0; j < sLen; j++ {
  81. server := servers[j]
  82. question := message.Question[0]
  83. question.Name = fqdn
  84. response, err := t.exchangeOne(ctx, server, question, C.DNSTimeout, false, true)
  85. if err != nil {
  86. lastErr = err
  87. continue
  88. }
  89. return response, nil
  90. }
  91. }
  92. return nil, E.Cause(lastErr, fqdn)
  93. }
  94. func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) {
  95. if server.Port == 0 {
  96. server.Port = 53
  97. }
  98. var networks []string
  99. if useTCP {
  100. networks = []string{N.NetworkTCP}
  101. } else {
  102. networks = []string{N.NetworkUDP, N.NetworkTCP}
  103. }
  104. request := &mDNS.Msg{
  105. MsgHdr: mDNS.MsgHdr{
  106. Id: uint16(rand.Uint32()),
  107. RecursionDesired: true,
  108. AuthenticatedData: ad,
  109. },
  110. Question: []mDNS.Question{question},
  111. Compress: true,
  112. }
  113. request.SetEdns0(maxDNSPacketSize, false)
  114. buffer := buf.Get(buf.UDPBufferSize)
  115. defer buf.Put(buffer)
  116. for _, network := range networks {
  117. ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
  118. defer cancel()
  119. conn, err := t.dialer.DialContext(ctx, network, server)
  120. if err != nil {
  121. return nil, err
  122. }
  123. defer conn.Close()
  124. if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
  125. conn.SetDeadline(deadline)
  126. }
  127. rawMessage, err := request.PackBuffer(buffer)
  128. if err != nil {
  129. return nil, E.Cause(err, "pack request")
  130. }
  131. _, err = conn.Write(rawMessage)
  132. if err != nil {
  133. return nil, E.Cause(err, "write request")
  134. }
  135. n, err := conn.Read(buffer)
  136. if err != nil {
  137. return nil, E.Cause(err, "read response")
  138. }
  139. var response mDNS.Msg
  140. err = response.Unpack(buffer[:n])
  141. if err != nil {
  142. return nil, E.Cause(err, "unpack response")
  143. }
  144. if response.Truncated && network == N.NetworkUDP {
  145. continue
  146. }
  147. return &response, nil
  148. }
  149. panic("unexpected")
  150. }
  151. func (t *Transport) nameList(name string) []string {
  152. l := len(name)
  153. rooted := l > 0 && name[l-1] == '.'
  154. if l > 254 || l == 254 && !rooted {
  155. return nil
  156. }
  157. if rooted {
  158. if avoidDNS(name) {
  159. return nil
  160. }
  161. return []string{name}
  162. }
  163. hasNdots := strings.Count(name, ".") >= t.ndots
  164. name += "."
  165. // l++
  166. names := make([]string, 0, 1+len(t.search))
  167. if hasNdots && !avoidDNS(name) {
  168. names = append(names, name)
  169. }
  170. for _, suffix := range t.search {
  171. fqdn := name + suffix
  172. if !avoidDNS(fqdn) && len(fqdn) <= 254 {
  173. names = append(names, fqdn)
  174. }
  175. }
  176. if !hasNdots && !avoidDNS(name) {
  177. names = append(names, name)
  178. }
  179. return names
  180. }
  181. func avoidDNS(name string) bool {
  182. if name == "" {
  183. return true
  184. }
  185. if name[len(name)-1] == '.' {
  186. name = name[:len(name)-1]
  187. }
  188. return strings.HasSuffix(name, ".onion")
  189. }