nameserver_udp.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. package dns
  2. import (
  3. "context"
  4. go_errors "errors"
  5. "strings"
  6. "sync"
  7. "sync/atomic"
  8. "time"
  9. "github.com/xtls/xray-core/common"
  10. "github.com/xtls/xray-core/common/errors"
  11. "github.com/xtls/xray-core/common/log"
  12. "github.com/xtls/xray-core/common/net"
  13. "github.com/xtls/xray-core/common/protocol/dns"
  14. udp_proto "github.com/xtls/xray-core/common/protocol/udp"
  15. "github.com/xtls/xray-core/common/task"
  16. dns_feature "github.com/xtls/xray-core/features/dns"
  17. "github.com/xtls/xray-core/features/routing"
  18. "github.com/xtls/xray-core/transport/internet/udp"
  19. "golang.org/x/net/dns/dnsmessage"
  20. )
  21. // ClassicNameServer implemented traditional UDP DNS.
  22. type ClassicNameServer struct {
  23. sync.RWMutex
  24. cacheController *CacheController
  25. address *net.Destination
  26. requests map[uint16]*udpDnsRequest
  27. udpServer *udp.Dispatcher
  28. requestsCleanup *task.Periodic
  29. reqID uint32
  30. clientIP net.IP
  31. }
  32. type udpDnsRequest struct {
  33. dnsRequest
  34. ctx context.Context
  35. }
  36. // NewClassicNameServer creates udp server object for remote resolving.
  37. func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, disableCache bool, clientIP net.IP) *ClassicNameServer {
  38. // default to 53 if unspecific
  39. if address.Port == 0 {
  40. address.Port = net.Port(53)
  41. }
  42. s := &ClassicNameServer{
  43. cacheController: NewCacheController(strings.ToUpper(address.String()), disableCache),
  44. address: &address,
  45. requests: make(map[uint16]*udpDnsRequest),
  46. clientIP: clientIP,
  47. }
  48. s.requestsCleanup = &task.Periodic{
  49. Interval: time.Minute,
  50. Execute: s.RequestsCleanup,
  51. }
  52. s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse)
  53. errors.LogInfo(context.Background(), "DNS: created UDP client initialized for ", address.NetAddr())
  54. return s
  55. }
  56. // Name implements Server.
  57. func (s *ClassicNameServer) Name() string {
  58. return s.cacheController.name
  59. }
  60. // RequestsCleanup clears expired items from cache
  61. func (s *ClassicNameServer) RequestsCleanup() error {
  62. now := time.Now()
  63. s.Lock()
  64. defer s.Unlock()
  65. if len(s.requests) == 0 {
  66. return errors.New(s.Name(), " nothing to do. stopping...")
  67. }
  68. for id, req := range s.requests {
  69. if req.expire.Before(now) {
  70. delete(s.requests, id)
  71. }
  72. }
  73. if len(s.requests) == 0 {
  74. s.requests = make(map[uint16]*udpDnsRequest)
  75. }
  76. return nil
  77. }
  78. // HandleResponse handles udp response packet from remote DNS server.
  79. func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
  80. ipRec, err := parseResponse(packet.Payload.Bytes())
  81. if err != nil {
  82. errors.LogError(ctx, s.Name(), " fail to parse responded DNS udp")
  83. return
  84. }
  85. s.Lock()
  86. id := ipRec.ReqID
  87. req, ok := s.requests[id]
  88. if ok {
  89. // remove the pending request
  90. delete(s.requests, id)
  91. }
  92. s.Unlock()
  93. if !ok {
  94. errors.LogError(ctx, s.Name(), " cannot find the pending request")
  95. return
  96. }
  97. // if truncated, retry with EDNS0 option(udp payload size: 1350)
  98. if ipRec.RawHeader.Truncated {
  99. // if already has EDNS0 option, no need to retry
  100. if len(req.msg.Additionals) == 0 {
  101. // copy necessary meta data from original request
  102. // and add EDNS0 option
  103. opt := new(dnsmessage.Resource)
  104. common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
  105. opt.Body = &dnsmessage.OPTResource{}
  106. newMsg := *req.msg
  107. newReq := *req
  108. newMsg.Additionals = append(newMsg.Additionals, *opt)
  109. newMsg.ID = s.newReqID()
  110. newReq.msg = &newMsg
  111. s.addPendingRequest(&newReq)
  112. b, _ := dns.PackMessage(newReq.msg)
  113. s.udpServer.Dispatch(toDnsContext(newReq.ctx, s.address.String()), *s.address, b)
  114. return
  115. }
  116. }
  117. s.cacheController.updateIP(&req.dnsRequest, ipRec)
  118. }
  119. func (s *ClassicNameServer) newReqID() uint16 {
  120. return uint16(atomic.AddUint32(&s.reqID, 1))
  121. }
  122. func (s *ClassicNameServer) addPendingRequest(req *udpDnsRequest) {
  123. s.Lock()
  124. id := req.msg.ID
  125. req.expire = time.Now().Add(time.Second * 8)
  126. s.requests[id] = req
  127. s.Unlock()
  128. common.Must(s.requestsCleanup.Start())
  129. }
  130. func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domain string, option dns_feature.IPOption) {
  131. errors.LogDebug(ctx, s.Name(), " querying DNS for: ", domain)
  132. reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
  133. for _, req := range reqs {
  134. udpReq := &udpDnsRequest{
  135. dnsRequest: *req,
  136. ctx: ctx,
  137. }
  138. s.addPendingRequest(udpReq)
  139. b, _ := dns.PackMessage(req.msg)
  140. s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b)
  141. }
  142. }
  143. // QueryIP implements Server.
  144. func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
  145. fqdn := Fqdn(domain)
  146. sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
  147. defer closeSubscribers(sub4, sub6)
  148. if s.cacheController.disableCache {
  149. errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
  150. } else {
  151. ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
  152. if !go_errors.Is(err, errRecordNotFound) {
  153. errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
  154. log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
  155. return ips, ttl, err
  156. }
  157. }
  158. noResponseErrCh := make(chan error, 2)
  159. s.sendQuery(ctx, noResponseErrCh, fqdn, option)
  160. start := time.Now()
  161. if sub4 != nil {
  162. select {
  163. case <-ctx.Done():
  164. return nil, 0, ctx.Err()
  165. case err := <-noResponseErrCh:
  166. return nil, 0, err
  167. case <-sub4.Wait():
  168. sub4.Close()
  169. }
  170. }
  171. if sub6 != nil {
  172. select {
  173. case <-ctx.Done():
  174. return nil, 0, ctx.Err()
  175. case err := <-noResponseErrCh:
  176. return nil, 0, err
  177. case <-sub6.Wait():
  178. sub6.Close()
  179. }
  180. }
  181. ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
  182. log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
  183. return ips, ttl, err
  184. }