nameserver_udp.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. package dns
  2. import (
  3. "context"
  4. go_errors "errors"
  5. "strings"
  6. "sync"
  7. "sync/atomic"
  8. "time"
  9. "github.com/xtls/xray-core/common"
  10. "github.com/xtls/xray-core/common/errors"
  11. "github.com/xtls/xray-core/common/log"
  12. "github.com/xtls/xray-core/common/net"
  13. "github.com/xtls/xray-core/common/protocol/dns"
  14. udp_proto "github.com/xtls/xray-core/common/protocol/udp"
  15. "github.com/xtls/xray-core/common/task"
  16. dns_feature "github.com/xtls/xray-core/features/dns"
  17. "github.com/xtls/xray-core/features/routing"
  18. "github.com/xtls/xray-core/transport/internet/udp"
  19. "golang.org/x/net/dns/dnsmessage"
  20. )
// ClassicNameServer implements a traditional DNS client that sends plain UDP
// queries to a single upstream server.
type ClassicNameServer struct {
	// The embedded lock guards requests (taken around every access to it).
	sync.RWMutex
	// cacheController stores resolved records and provides the subscribers
	// that QueryIP waits on.
	cacheController *CacheController
	// address is the upstream server; its port defaults to 53 when unset.
	address *net.Destination
	// requests maps the DNS message ID of each in-flight query to its request.
	requests map[uint16]*udpDnsRequest
	// udpServer sends queries and delivers responses to HandleResponse.
	udpServer *udp.Dispatcher
	// requestsCleanup periodically (every minute) runs RequestsCleanup.
	requestsCleanup *task.Periodic
	// reqID is the DNS message ID counter; only accessed atomically via newReqID.
	reqID uint32
	// clientIP is passed to genEDNS0Options when building queries
	// (presumably for an EDNS Client Subnet option — confirm).
	clientIP net.IP
}
// udpDnsRequest is a pending UDP query together with the context it was
// issued under.
type udpDnsRequest struct {
	dnsRequest
	// ctx is the context passed to sendQuery; it is reused when the query is
	// re-sent after a truncated response.
	ctx context.Context
}
  36. // NewClassicNameServer creates udp server object for remote resolving.
  37. func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, disableCache bool, serveStale bool, serveExpiredTTL uint32, clientIP net.IP) *ClassicNameServer {
  38. // default to 53 if unspecific
  39. if address.Port == 0 {
  40. address.Port = net.Port(53)
  41. }
  42. s := &ClassicNameServer{
  43. cacheController: NewCacheController(strings.ToUpper(address.String()), disableCache, serveStale, serveExpiredTTL),
  44. address: &address,
  45. requests: make(map[uint16]*udpDnsRequest),
  46. clientIP: clientIP,
  47. }
  48. s.requestsCleanup = &task.Periodic{
  49. Interval: time.Minute,
  50. Execute: s.RequestsCleanup,
  51. }
  52. s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse)
  53. errors.LogInfo(context.Background(), "DNS: created UDP client initialized for ", address.NetAddr())
  54. return s
  55. }
  56. // Name implements Server.
  57. func (s *ClassicNameServer) Name() string {
  58. return s.cacheController.name
  59. }
  60. // RequestsCleanup clears expired items from cache
  61. func (s *ClassicNameServer) RequestsCleanup() error {
  62. now := time.Now()
  63. s.Lock()
  64. defer s.Unlock()
  65. if len(s.requests) == 0 {
  66. return errors.New(s.Name(), " nothing to do. stopping...")
  67. }
  68. for id, req := range s.requests {
  69. if req.expire.Before(now) {
  70. delete(s.requests, id)
  71. }
  72. }
  73. if len(s.requests) == 0 {
  74. s.requests = make(map[uint16]*udpDnsRequest)
  75. }
  76. return nil
  77. }
  78. // HandleResponse handles udp response packet from remote DNS server.
  79. func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
  80. payload := packet.Payload
  81. ipRec, err := parseResponse(payload.Bytes())
  82. payload.Release()
  83. if err != nil {
  84. errors.LogError(ctx, s.Name(), " fail to parse responded DNS udp")
  85. return
  86. }
  87. s.Lock()
  88. id := ipRec.ReqID
  89. req, ok := s.requests[id]
  90. if ok {
  91. // remove the pending request
  92. delete(s.requests, id)
  93. }
  94. s.Unlock()
  95. if !ok {
  96. errors.LogError(ctx, s.Name(), " cannot find the pending request")
  97. return
  98. }
  99. // if truncated, retry with EDNS0 option(udp payload size: 1350)
  100. if ipRec.RawHeader.Truncated {
  101. // if already has EDNS0 option, no need to retry
  102. if len(req.msg.Additionals) == 0 {
  103. // copy necessary meta data from original request
  104. // and add EDNS0 option
  105. opt := new(dnsmessage.Resource)
  106. common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
  107. opt.Body = &dnsmessage.OPTResource{}
  108. newMsg := *req.msg
  109. newReq := *req
  110. newMsg.Additionals = append(newMsg.Additionals, *opt)
  111. newMsg.ID = s.newReqID()
  112. newReq.msg = &newMsg
  113. s.addPendingRequest(&newReq)
  114. b, _ := dns.PackMessage(newReq.msg)
  115. copyDest := net.UDPDestination(s.address.Address, s.address.Port)
  116. b.UDP = &copyDest
  117. s.udpServer.Dispatch(toDnsContext(newReq.ctx, s.address.String()), *s.address, b)
  118. return
  119. }
  120. }
  121. s.cacheController.updateIP(&req.dnsRequest, ipRec)
  122. }
  123. func (s *ClassicNameServer) newReqID() uint16 {
  124. return uint16(atomic.AddUint32(&s.reqID, 1))
  125. }
  126. func (s *ClassicNameServer) addPendingRequest(req *udpDnsRequest) {
  127. s.Lock()
  128. id := req.msg.ID
  129. req.expire = time.Now().Add(time.Second * 8)
  130. s.requests[id] = req
  131. s.Unlock()
  132. common.Must(s.requestsCleanup.Start())
  133. }
  134. func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domain string, option dns_feature.IPOption) {
  135. errors.LogDebug(ctx, s.Name(), " querying DNS for: ", domain)
  136. reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
  137. for _, req := range reqs {
  138. udpReq := &udpDnsRequest{
  139. dnsRequest: *req,
  140. ctx: ctx,
  141. }
  142. s.addPendingRequest(udpReq)
  143. b, _ := dns.PackMessage(req.msg)
  144. copyDest := net.UDPDestination(s.address.Address, s.address.Port)
  145. b.UDP = &copyDest
  146. s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b)
  147. }
  148. }
// QueryIP implements Server.
//
// It resolves domain to IP addresses. Unless caching is disabled, the cache
// controller is consulted first:
//   - a cached result with a positive TTL is returned immediately;
//   - with serveStale enabled, an expired entry within the allowed staleness
//     is returned with TTL 1 while a refresh query is dispatched without
//     waiting for it;
//   - otherwise queries are sent for the record types that still need
//     resolving. The call then blocks until their subscribers fire or ctx is
//     done, and finally re-reads the cache.
//
// On the query path the returned TTL is never below 1.
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
	fqdn := Fqdn(domain)
	// Subscribers are registered before any query is sent.
	sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
	// Deferred-call arguments are evaluated here, so the original sub4/sub6 are
	// closed on return even though they may be set to nil below.
	defer closeSubscribers(sub4, sub6)
	// queryOption narrows option down to the record types that must be fetched.
	queryOption := option
	if s.cacheController.disableCache {
		errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
	} else {
		ips, ttl, isARecordExpired, isAAAARecordExpired, err := s.cacheController.findIPsForDomain(fqdn, option)
		// A cached A record that is not expired needs no new query and no wait.
		if sub4 != nil && !isARecordExpired {
			sub4.Close()
			sub4 = nil
			queryOption.IPv4Enable = false
		}
		// Same for AAAA.
		if sub6 != nil && !isAAAARecordExpired {
			sub6.Close()
			sub6 = nil
			queryOption.IPv6Enable = false
		}
		if !go_errors.Is(err, errRecordNotFound) {
			if ttl > 0 {
				errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
				log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
				return ips, uint32(ttl), err
			}
			// Serve the stale entry when serveExpiredTTL is 0 (no limit) or the
			// entry is still within it.
			// NOTE(review): this comparison assumes serveExpiredTTL is stored in
			// the same (possibly negative) scale as ttl — confirm in CacheController.
			if s.cacheController.serveStale && (s.cacheController.serveExpiredTTL == 0 || s.cacheController.serveExpiredTTL < ttl) {
				errors.LogDebugInner(ctx, err, s.Name(), " cache OPTIMISTE ", domain, " -> ", ips)
				// Refresh the cache without waiting; nil means no error channel.
				s.sendQuery(ctx, nil, fqdn, queryOption)
				return ips, 1, err
			}
		}
	}
	// Channel for sendQuery to report that no response will arrive; buffered
	// for up to two errors so a reporter need not block.
	noResponseErrCh := make(chan error, 2)
	s.sendQuery(ctx, noResponseErrCh, fqdn, queryOption)
	start := time.Now()
	// Wait for the A answer (if one was requested), bailing out on
	// cancellation or a reported error.
	if sub4 != nil {
		select {
		case <-ctx.Done():
			return nil, 0, ctx.Err()
		case err := <-noResponseErrCh:
			return nil, 0, err
		case <-sub4.Wait():
			sub4.Close()
		}
	}
	// Then wait for the AAAA answer the same way.
	if sub6 != nil {
		select {
		case <-ctx.Done():
			return nil, 0, ctx.Err()
		case err := <-noResponseErrCh:
			return nil, 0, err
		case <-sub6.Wait():
			sub6.Close()
		}
	}
	// Re-read with the caller's full option so records that were already
	// valid in the cache are included alongside the fresh ones.
	ips, ttl, _, _, err := s.cacheController.findIPsForDomain(fqdn, option)
	log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
	// Clamp non-positive TTLs to 1.
	var rTTL uint32
	if ttl <= 0 {
		rTTL = 1
	} else {
		rTTL = uint32(ttl)
	}
	return ips, rTTL, err
}