dhcp_shared.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. package dhcp
  2. import (
  3. "context"
  4. "errors"
  5. "math/rand"
  6. "strings"
  7. "syscall"
  8. "github.com/sagernet/sing-box/dns"
  9. "github.com/sagernet/sing-box/dns/transport"
  10. "github.com/sagernet/sing/common/buf"
  11. E "github.com/sagernet/sing/common/exceptions"
  12. M "github.com/sagernet/sing/common/metadata"
  13. N "github.com/sagernet/sing/common/network"
  14. mDNS "github.com/miekg/dns"
  15. )
  16. func (t *Transport) exchangeSingleRequest(ctx context.Context, servers []M.Socksaddr, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
  17. var lastErr error
  18. for _, fqdn := range t.nameList(domain) {
  19. response, err := t.tryOneName(ctx, servers, fqdn, message)
  20. if err != nil {
  21. lastErr = err
  22. continue
  23. }
  24. return response, nil
  25. }
  26. return nil, lastErr
  27. }
  28. func (t *Transport) exchangeParallel(ctx context.Context, servers []M.Socksaddr, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
  29. returned := make(chan struct{})
  30. defer close(returned)
  31. type queryResult struct {
  32. response *mDNS.Msg
  33. err error
  34. }
  35. results := make(chan queryResult)
  36. startRacer := func(ctx context.Context, fqdn string) {
  37. response, err := t.tryOneName(ctx, servers, fqdn, message)
  38. if err == nil {
  39. if response.Rcode != mDNS.RcodeSuccess {
  40. err = dns.RcodeError(response.Rcode)
  41. } else if len(dns.MessageToAddresses(response)) == 0 {
  42. err = dns.RcodeSuccess
  43. }
  44. }
  45. select {
  46. case results <- queryResult{response, err}:
  47. case <-returned:
  48. }
  49. }
  50. queryCtx, queryCancel := context.WithCancel(ctx)
  51. defer queryCancel()
  52. var nameCount int
  53. for _, fqdn := range t.nameList(domain) {
  54. nameCount++
  55. go startRacer(queryCtx, fqdn)
  56. }
  57. var errors []error
  58. for {
  59. select {
  60. case <-ctx.Done():
  61. return nil, ctx.Err()
  62. case result := <-results:
  63. if result.err == nil {
  64. return result.response, nil
  65. }
  66. errors = append(errors, result.err)
  67. if len(errors) == nameCount {
  68. return nil, E.Errors(errors...)
  69. }
  70. }
  71. }
  72. }
  73. func (t *Transport) tryOneName(ctx context.Context, servers []M.Socksaddr, fqdn string, message *mDNS.Msg) (*mDNS.Msg, error) {
  74. sLen := len(servers)
  75. var lastErr error
  76. for i := 0; i < t.attempts; i++ {
  77. for j := 0; j < sLen; j++ {
  78. server := servers[j]
  79. question := message.Question[0]
  80. question.Name = fqdn
  81. response, err := t.exchangeOne(ctx, server, question)
  82. if err != nil {
  83. lastErr = err
  84. continue
  85. }
  86. return response, nil
  87. }
  88. }
  89. return nil, E.Cause(lastErr, fqdn)
  90. }
  91. func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question) (*mDNS.Msg, error) {
  92. if server.Port == 0 {
  93. server.Port = 53
  94. }
  95. request := &mDNS.Msg{
  96. MsgHdr: mDNS.MsgHdr{
  97. Id: uint16(rand.Uint32()),
  98. RecursionDesired: true,
  99. AuthenticatedData: true,
  100. },
  101. Question: []mDNS.Question{question},
  102. Compress: true,
  103. }
  104. request.SetEdns0(buf.UDPBufferSize, false)
  105. return t.exchangeUDP(ctx, server, request)
  106. }
  107. func (t *Transport) exchangeUDP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg) (*mDNS.Msg, error) {
  108. conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, server)
  109. if err != nil {
  110. return nil, err
  111. }
  112. defer conn.Close()
  113. if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
  114. conn.SetDeadline(deadline)
  115. }
  116. buffer := buf.Get(buf.UDPBufferSize)
  117. defer buf.Put(buffer)
  118. rawMessage, err := request.PackBuffer(buffer)
  119. if err != nil {
  120. return nil, E.Cause(err, "pack request")
  121. }
  122. _, err = conn.Write(rawMessage)
  123. if err != nil {
  124. if errors.Is(err, syscall.EMSGSIZE) {
  125. return t.exchangeTCP(ctx, server, request)
  126. }
  127. return nil, E.Cause(err, "write request")
  128. }
  129. n, err := conn.Read(buffer)
  130. if err != nil {
  131. if errors.Is(err, syscall.EMSGSIZE) {
  132. return t.exchangeTCP(ctx, server, request)
  133. }
  134. return nil, E.Cause(err, "read response")
  135. }
  136. var response mDNS.Msg
  137. err = response.Unpack(buffer[:n])
  138. if err != nil {
  139. return nil, E.Cause(err, "unpack response")
  140. }
  141. if response.Truncated {
  142. return t.exchangeTCP(ctx, server, request)
  143. }
  144. return &response, nil
  145. }
  146. func (t *Transport) exchangeTCP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg) (*mDNS.Msg, error) {
  147. conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, server)
  148. if err != nil {
  149. return nil, err
  150. }
  151. defer conn.Close()
  152. if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
  153. conn.SetDeadline(deadline)
  154. }
  155. err = transport.WriteMessage(conn, 0, request)
  156. if err != nil {
  157. return nil, err
  158. }
  159. return transport.ReadMessage(conn)
  160. }
  161. func (t *Transport) nameList(name string) []string {
  162. l := len(name)
  163. rooted := l > 0 && name[l-1] == '.'
  164. if l > 254 || l == 254 && !rooted {
  165. return nil
  166. }
  167. if rooted {
  168. if avoidDNS(name) {
  169. return nil
  170. }
  171. return []string{name}
  172. }
  173. hasNdots := strings.Count(name, ".") >= t.ndots
  174. name += "."
  175. // l++
  176. names := make([]string, 0, 1+len(t.search))
  177. if hasNdots && !avoidDNS(name) {
  178. names = append(names, name)
  179. }
  180. for _, suffix := range t.search {
  181. fqdn := name + suffix
  182. if !avoidDNS(fqdn) && len(fqdn) <= 254 {
  183. names = append(names, fqdn)
  184. }
  185. }
  186. if !hasNdots && !avoidDNS(name) {
  187. names = append(names, name)
  188. }
  189. return names
  190. }
  191. func avoidDNS(name string) bool {
  192. if name == "" {
  193. return true
  194. }
  195. if name[len(name)-1] == '.' {
  196. name = name[:len(name)-1]
  197. }
  198. return strings.HasSuffix(name, ".onion")
  199. }