client.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633
  1. package dns
  2. import (
  3. "context"
  4. "net"
  5. "net/netip"
  6. "strings"
  7. "time"
  8. "github.com/sagernet/sing-box/adapter"
  9. C "github.com/sagernet/sing-box/constant"
  10. "github.com/sagernet/sing/common"
  11. E "github.com/sagernet/sing/common/exceptions"
  12. "github.com/sagernet/sing/common/logger"
  13. M "github.com/sagernet/sing/common/metadata"
  14. "github.com/sagernet/sing/common/task"
  15. "github.com/sagernet/sing/contrab/freelru"
  16. "github.com/sagernet/sing/contrab/maphash"
  17. dns "github.com/miekg/dns"
  18. )
  19. var (
  20. ErrNoRawSupport = E.New("no raw query support by current transport")
  21. ErrNotCached = E.New("not cached")
  22. ErrResponseRejected = E.New("response rejected")
  23. ErrResponseRejectedCached = E.Extend(ErrResponseRejected, "cached")
  24. )
  25. var _ adapter.DNSClient = (*Client)(nil)
  26. type Client struct {
  27. timeout time.Duration
  28. disableCache bool
  29. disableExpire bool
  30. independentCache bool
  31. rdrc adapter.RDRCStore
  32. initRDRCFunc func() adapter.RDRCStore
  33. logger logger.ContextLogger
  34. cache freelru.Cache[dns.Question, *dns.Msg]
  35. transportCache freelru.Cache[transportCacheKey, *dns.Msg]
  36. }
  37. type ClientOptions struct {
  38. Timeout time.Duration
  39. DisableCache bool
  40. DisableExpire bool
  41. IndependentCache bool
  42. CacheCapacity uint32
  43. RDRC func() adapter.RDRCStore
  44. Logger logger.ContextLogger
  45. }
  46. func NewClient(options ClientOptions) *Client {
  47. client := &Client{
  48. timeout: options.Timeout,
  49. disableCache: options.DisableCache,
  50. disableExpire: options.DisableExpire,
  51. independentCache: options.IndependentCache,
  52. initRDRCFunc: options.RDRC,
  53. logger: options.Logger,
  54. }
  55. if client.timeout == 0 {
  56. client.timeout = C.DNSTimeout
  57. }
  58. cacheCapacity := options.CacheCapacity
  59. if cacheCapacity < 1024 {
  60. cacheCapacity = 1024
  61. }
  62. if !client.disableCache {
  63. if !client.independentCache {
  64. client.cache = common.Must1(freelru.NewSharded[dns.Question, *dns.Msg](cacheCapacity, maphash.NewHasher[dns.Question]().Hash32))
  65. } else {
  66. client.transportCache = common.Must1(freelru.NewSharded[transportCacheKey, *dns.Msg](cacheCapacity, maphash.NewHasher[transportCacheKey]().Hash32))
  67. }
  68. }
  69. return client
  70. }
  71. type transportCacheKey struct {
  72. dns.Question
  73. transportTag string
  74. }
  75. func (c *Client) Start() {
  76. if c.initRDRCFunc != nil {
  77. c.rdrc = c.initRDRCFunc()
  78. }
  79. }
  80. func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error) {
  81. if len(message.Question) == 0 {
  82. if c.logger != nil {
  83. c.logger.WarnContext(ctx, "bad question size: ", len(message.Question))
  84. }
  85. responseMessage := dns.Msg{
  86. MsgHdr: dns.MsgHdr{
  87. Id: message.Id,
  88. Response: true,
  89. Rcode: dns.RcodeFormatError,
  90. },
  91. Question: message.Question,
  92. }
  93. return &responseMessage, nil
  94. }
  95. question := message.Question[0]
  96. if options.ClientSubnet.IsValid() {
  97. message = SetClientSubnet(message, options.ClientSubnet, true)
  98. }
  99. isSimpleRequest := len(message.Question) == 1 &&
  100. len(message.Ns) == 0 &&
  101. len(message.Extra) == 0 &&
  102. !options.ClientSubnet.IsValid()
  103. disableCache := !isSimpleRequest || c.disableCache || options.DisableCache
  104. if !disableCache {
  105. response, ttl := c.loadResponse(question, transport)
  106. if response != nil {
  107. logCachedResponse(c.logger, ctx, response, ttl)
  108. response.Id = message.Id
  109. return response, nil
  110. }
  111. }
  112. if question.Qtype == dns.TypeA && options.Strategy == C.DomainStrategyIPv6Only || question.Qtype == dns.TypeAAAA && options.Strategy == C.DomainStrategyIPv4Only {
  113. responseMessage := dns.Msg{
  114. MsgHdr: dns.MsgHdr{
  115. Id: message.Id,
  116. Response: true,
  117. Rcode: dns.RcodeSuccess,
  118. },
  119. Question: []dns.Question{question},
  120. }
  121. if c.logger != nil {
  122. c.logger.DebugContext(ctx, "strategy rejected")
  123. }
  124. return &responseMessage, nil
  125. }
  126. messageId := message.Id
  127. contextTransport, clientSubnetLoaded := transportTagFromContext(ctx)
  128. if clientSubnetLoaded && transport.Tag() == contextTransport {
  129. return nil, E.New("DNS query loopback in transport[", contextTransport, "]")
  130. }
  131. ctx = contextWithTransportTag(ctx, transport.Tag())
  132. if responseChecker != nil && c.rdrc != nil {
  133. rejected := c.rdrc.LoadRDRC(transport.Tag(), question.Name, question.Qtype)
  134. if rejected {
  135. return nil, ErrResponseRejectedCached
  136. }
  137. }
  138. ctx, cancel := context.WithTimeout(ctx, c.timeout)
  139. response, err := transport.Exchange(ctx, message)
  140. cancel()
  141. if err != nil {
  142. return nil, err
  143. }
  144. /*if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
  145. validResponse := response
  146. loop:
  147. for {
  148. var (
  149. addresses int
  150. queryCNAME string
  151. )
  152. for _, rawRR := range validResponse.Answer {
  153. switch rr := rawRR.(type) {
  154. case *dns.A:
  155. break loop
  156. case *dns.AAAA:
  157. break loop
  158. case *dns.CNAME:
  159. queryCNAME = rr.Target
  160. }
  161. }
  162. if queryCNAME == "" {
  163. break
  164. }
  165. exMessage := *message
  166. exMessage.Question = []dns.Question{{
  167. Name: queryCNAME,
  168. Qtype: question.Qtype,
  169. }}
  170. validResponse, err = c.Exchange(ctx, transport, &exMessage, options, responseChecker)
  171. if err != nil {
  172. return nil, err
  173. }
  174. }
  175. if validResponse != response {
  176. response.Answer = append(response.Answer, validResponse.Answer...)
  177. }
  178. }*/
  179. if responseChecker != nil {
  180. addr, addrErr := MessageToAddresses(response)
  181. if addrErr != nil || !responseChecker(addr) {
  182. if c.rdrc != nil {
  183. c.rdrc.SaveRDRCAsync(transport.Tag(), question.Name, question.Qtype, c.logger)
  184. }
  185. logRejectedResponse(c.logger, ctx, response)
  186. return response, ErrResponseRejected
  187. }
  188. }
  189. if question.Qtype == dns.TypeHTTPS {
  190. if options.Strategy == C.DomainStrategyIPv4Only || options.Strategy == C.DomainStrategyIPv6Only {
  191. for _, rr := range response.Answer {
  192. https, isHTTPS := rr.(*dns.HTTPS)
  193. if !isHTTPS {
  194. continue
  195. }
  196. content := https.SVCB
  197. content.Value = common.Filter(content.Value, func(it dns.SVCBKeyValue) bool {
  198. if options.Strategy == C.DomainStrategyIPv4Only {
  199. return it.Key() != dns.SVCB_IPV6HINT
  200. } else {
  201. return it.Key() != dns.SVCB_IPV4HINT
  202. }
  203. })
  204. https.SVCB = content
  205. }
  206. }
  207. }
  208. var timeToLive uint32
  209. for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
  210. for _, record := range recordList {
  211. if timeToLive == 0 || record.Header().Ttl > 0 && record.Header().Ttl < timeToLive {
  212. timeToLive = record.Header().Ttl
  213. }
  214. }
  215. }
  216. if options.RewriteTTL != nil {
  217. timeToLive = *options.RewriteTTL
  218. }
  219. for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
  220. for _, record := range recordList {
  221. record.Header().Ttl = timeToLive
  222. }
  223. }
  224. response.Id = messageId
  225. if !disableCache {
  226. c.storeCache(transport, question, response, timeToLive)
  227. }
  228. logExchangedResponse(c.logger, ctx, response, timeToLive)
  229. return response, err
  230. }
  231. func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
  232. domain = FqdnToDomain(domain)
  233. dnsName := dns.Fqdn(domain)
  234. if options.Strategy == C.DomainStrategyIPv4Only {
  235. return c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, options, responseChecker)
  236. } else if options.Strategy == C.DomainStrategyIPv6Only {
  237. return c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, options, responseChecker)
  238. }
  239. var response4 []netip.Addr
  240. var response6 []netip.Addr
  241. var group task.Group
  242. group.Append("exchange4", func(ctx context.Context) error {
  243. response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, options, responseChecker)
  244. if err != nil {
  245. return err
  246. }
  247. response4 = response
  248. return nil
  249. })
  250. group.Append("exchange6", func(ctx context.Context) error {
  251. response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, options, responseChecker)
  252. if err != nil {
  253. return err
  254. }
  255. response6 = response
  256. return nil
  257. })
  258. err := group.Run(ctx)
  259. if len(response4) == 0 && len(response6) == 0 {
  260. return nil, err
  261. }
  262. return sortAddresses(response4, response6, options.Strategy), nil
  263. }
  264. func (c *Client) ClearCache() {
  265. if c.cache != nil {
  266. c.cache.Purge()
  267. }
  268. if c.transportCache != nil {
  269. c.transportCache.Purge()
  270. }
  271. }
  272. func (c *Client) LookupCache(domain string, strategy C.DomainStrategy) ([]netip.Addr, bool) {
  273. if c.disableCache || c.independentCache {
  274. return nil, false
  275. }
  276. if dns.IsFqdn(domain) {
  277. domain = domain[:len(domain)-1]
  278. }
  279. dnsName := dns.Fqdn(domain)
  280. if strategy == C.DomainStrategyIPv4Only {
  281. response, err := c.questionCache(dns.Question{
  282. Name: dnsName,
  283. Qtype: dns.TypeA,
  284. Qclass: dns.ClassINET,
  285. }, nil)
  286. if err != ErrNotCached {
  287. return response, true
  288. }
  289. } else if strategy == C.DomainStrategyIPv6Only {
  290. response, err := c.questionCache(dns.Question{
  291. Name: dnsName,
  292. Qtype: dns.TypeAAAA,
  293. Qclass: dns.ClassINET,
  294. }, nil)
  295. if err != ErrNotCached {
  296. return response, true
  297. }
  298. } else {
  299. response4, _ := c.questionCache(dns.Question{
  300. Name: dnsName,
  301. Qtype: dns.TypeA,
  302. Qclass: dns.ClassINET,
  303. }, nil)
  304. response6, _ := c.questionCache(dns.Question{
  305. Name: dnsName,
  306. Qtype: dns.TypeAAAA,
  307. Qclass: dns.ClassINET,
  308. }, nil)
  309. if len(response4) > 0 || len(response6) > 0 {
  310. return sortAddresses(response4, response6, strategy), true
  311. }
  312. }
  313. return nil, false
  314. }
  315. func (c *Client) ExchangeCache(ctx context.Context, message *dns.Msg) (*dns.Msg, bool) {
  316. if c.disableCache || c.independentCache || len(message.Question) != 1 {
  317. return nil, false
  318. }
  319. question := message.Question[0]
  320. response, ttl := c.loadResponse(question, nil)
  321. if response == nil {
  322. return nil, false
  323. }
  324. logCachedResponse(c.logger, ctx, response, ttl)
  325. response.Id = message.Id
  326. return response, true
  327. }
  328. func sortAddresses(response4 []netip.Addr, response6 []netip.Addr, strategy C.DomainStrategy) []netip.Addr {
  329. if strategy == C.DomainStrategyPreferIPv6 {
  330. return append(response6, response4...)
  331. } else {
  332. return append(response4, response6...)
  333. }
  334. }
  335. func (c *Client) storeCache(transport adapter.DNSTransport, question dns.Question, message *dns.Msg, timeToLive uint32) {
  336. if timeToLive == 0 {
  337. return
  338. }
  339. if c.disableExpire {
  340. if !c.independentCache {
  341. c.cache.Add(question, message)
  342. } else {
  343. c.transportCache.Add(transportCacheKey{
  344. Question: question,
  345. transportTag: transport.Tag(),
  346. }, message)
  347. }
  348. return
  349. }
  350. if !c.independentCache {
  351. c.cache.AddWithLifetime(question, message, time.Second*time.Duration(timeToLive))
  352. } else {
  353. c.transportCache.AddWithLifetime(transportCacheKey{
  354. Question: question,
  355. transportTag: transport.Tag(),
  356. }, message, time.Second*time.Duration(timeToLive))
  357. }
  358. }
  359. func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTransport, name string, qType uint16, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
  360. question := dns.Question{
  361. Name: name,
  362. Qtype: qType,
  363. Qclass: dns.ClassINET,
  364. }
  365. disableCache := c.disableCache || options.DisableCache
  366. if !disableCache {
  367. cachedAddresses, err := c.questionCache(question, transport)
  368. if err != ErrNotCached {
  369. return cachedAddresses, err
  370. }
  371. }
  372. message := dns.Msg{
  373. MsgHdr: dns.MsgHdr{
  374. RecursionDesired: true,
  375. },
  376. Question: []dns.Question{question},
  377. }
  378. response, err := c.Exchange(ctx, transport, &message, options, responseChecker)
  379. if err != nil {
  380. return nil, err
  381. }
  382. return MessageToAddresses(response)
  383. }
  384. func (c *Client) questionCache(question dns.Question, transport adapter.DNSTransport) ([]netip.Addr, error) {
  385. response, _ := c.loadResponse(question, transport)
  386. if response == nil {
  387. return nil, ErrNotCached
  388. }
  389. return MessageToAddresses(response)
  390. }
  391. func (c *Client) loadResponse(question dns.Question, transport adapter.DNSTransport) (*dns.Msg, int) {
  392. var (
  393. response *dns.Msg
  394. loaded bool
  395. )
  396. if c.disableExpire {
  397. if !c.independentCache {
  398. response, loaded = c.cache.Get(question)
  399. } else {
  400. response, loaded = c.transportCache.Get(transportCacheKey{
  401. Question: question,
  402. transportTag: transport.Tag(),
  403. })
  404. }
  405. if !loaded {
  406. return nil, 0
  407. }
  408. return response.Copy(), 0
  409. } else {
  410. var expireAt time.Time
  411. if !c.independentCache {
  412. response, expireAt, loaded = c.cache.GetWithLifetime(question)
  413. } else {
  414. response, expireAt, loaded = c.transportCache.GetWithLifetime(transportCacheKey{
  415. Question: question,
  416. transportTag: transport.Tag(),
  417. })
  418. }
  419. if !loaded {
  420. return nil, 0
  421. }
  422. timeNow := time.Now()
  423. if timeNow.After(expireAt) {
  424. if !c.independentCache {
  425. c.cache.Remove(question)
  426. } else {
  427. c.transportCache.Remove(transportCacheKey{
  428. Question: question,
  429. transportTag: transport.Tag(),
  430. })
  431. }
  432. return nil, 0
  433. }
  434. var originTTL int
  435. for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
  436. for _, record := range recordList {
  437. if originTTL == 0 || record.Header().Ttl > 0 && int(record.Header().Ttl) < originTTL {
  438. originTTL = int(record.Header().Ttl)
  439. }
  440. }
  441. }
  442. nowTTL := int(expireAt.Sub(timeNow).Seconds())
  443. if nowTTL < 0 {
  444. nowTTL = 0
  445. }
  446. response = response.Copy()
  447. if originTTL > 0 {
  448. duration := uint32(originTTL - nowTTL)
  449. for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
  450. for _, record := range recordList {
  451. record.Header().Ttl = record.Header().Ttl - duration
  452. }
  453. }
  454. } else {
  455. for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
  456. for _, record := range recordList {
  457. record.Header().Ttl = uint32(nowTTL)
  458. }
  459. }
  460. }
  461. return response, nowTTL
  462. }
  463. }
  464. func MessageToAddresses(response *dns.Msg) ([]netip.Addr, error) {
  465. if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError {
  466. return nil, RcodeError(response.Rcode)
  467. }
  468. addresses := make([]netip.Addr, 0, len(response.Answer))
  469. for _, rawAnswer := range response.Answer {
  470. switch answer := rawAnswer.(type) {
  471. case *dns.A:
  472. addresses = append(addresses, M.AddrFromIP(answer.A))
  473. case *dns.AAAA:
  474. addresses = append(addresses, M.AddrFromIP(answer.AAAA))
  475. case *dns.HTTPS:
  476. for _, value := range answer.SVCB.Value {
  477. if value.Key() == dns.SVCB_IPV4HINT || value.Key() == dns.SVCB_IPV6HINT {
  478. addresses = append(addresses, common.Map(strings.Split(value.String(), ","), M.ParseAddr)...)
  479. }
  480. }
  481. }
  482. }
  483. return addresses, nil
  484. }
  485. func wrapError(err error) error {
  486. switch dnsErr := err.(type) {
  487. case *net.DNSError:
  488. if dnsErr.IsNotFound {
  489. return RcodeNameError
  490. }
  491. case *net.AddrError:
  492. return RcodeNameError
  493. }
  494. return err
  495. }
  // transportKey is the unexported context key under which the tag of the
  // transport serving the current query is stored.
  type transportKey struct{}

  // contextWithTransportTag returns a child of ctx that records transportTag.
  func contextWithTransportTag(ctx context.Context, transportTag string) context.Context {
  	var key transportKey
  	return context.WithValue(ctx, key, transportTag)
  }

  // transportTagFromContext returns the transport tag recorded in ctx and
  // whether one was present.
  func transportTagFromContext(ctx context.Context) (string, bool) {
  	tag, ok := ctx.Value(transportKey{}).(string)
  	return tag, ok
  }
  504. func FixedResponse(id uint16, question dns.Question, addresses []netip.Addr, timeToLive uint32) *dns.Msg {
  505. response := dns.Msg{
  506. MsgHdr: dns.MsgHdr{
  507. Id: id,
  508. Rcode: dns.RcodeSuccess,
  509. Response: true,
  510. },
  511. Question: []dns.Question{question},
  512. }
  513. for _, address := range addresses {
  514. if address.Is4() && question.Qtype == dns.TypeA {
  515. response.Answer = append(response.Answer, &dns.A{
  516. Hdr: dns.RR_Header{
  517. Name: question.Name,
  518. Rrtype: dns.TypeA,
  519. Class: dns.ClassINET,
  520. Ttl: timeToLive,
  521. },
  522. A: address.AsSlice(),
  523. })
  524. } else if address.Is6() && question.Qtype == dns.TypeAAAA {
  525. response.Answer = append(response.Answer, &dns.AAAA{
  526. Hdr: dns.RR_Header{
  527. Name: question.Name,
  528. Rrtype: dns.TypeAAAA,
  529. Class: dns.ClassINET,
  530. Ttl: timeToLive,
  531. },
  532. AAAA: address.AsSlice(),
  533. })
  534. }
  535. }
  536. return &response
  537. }
  538. func FixedResponseCNAME(id uint16, question dns.Question, record string, timeToLive uint32) *dns.Msg {
  539. response := dns.Msg{
  540. MsgHdr: dns.MsgHdr{
  541. Id: id,
  542. Rcode: dns.RcodeSuccess,
  543. Response: true,
  544. },
  545. Question: []dns.Question{question},
  546. Answer: []dns.RR{
  547. &dns.CNAME{
  548. Hdr: dns.RR_Header{
  549. Name: question.Name,
  550. Rrtype: dns.TypeCNAME,
  551. Class: dns.ClassINET,
  552. Ttl: timeToLive,
  553. },
  554. Target: record,
  555. },
  556. },
  557. }
  558. return &response
  559. }
  560. func FixedResponseTXT(id uint16, question dns.Question, records []string, timeToLive uint32) *dns.Msg {
  561. response := dns.Msg{
  562. MsgHdr: dns.MsgHdr{
  563. Id: id,
  564. Rcode: dns.RcodeSuccess,
  565. Response: true,
  566. },
  567. Question: []dns.Question{question},
  568. Answer: []dns.RR{
  569. &dns.TXT{
  570. Hdr: dns.RR_Header{
  571. Name: question.Name,
  572. Rrtype: dns.TypeA,
  573. Class: dns.ClassINET,
  574. Ttl: timeToLive,
  575. },
  576. Txt: records,
  577. },
  578. },
  579. }
  580. return &response
  581. }
  582. func FixedResponseMX(id uint16, question dns.Question, records []*net.MX, timeToLive uint32) *dns.Msg {
  583. response := dns.Msg{
  584. MsgHdr: dns.MsgHdr{
  585. Id: id,
  586. Rcode: dns.RcodeSuccess,
  587. Response: true,
  588. },
  589. Question: []dns.Question{question},
  590. }
  591. for _, record := range records {
  592. response.Answer = append(response.Answer, &dns.MX{
  593. Hdr: dns.RR_Header{
  594. Name: question.Name,
  595. Rrtype: dns.TypeA,
  596. Class: dns.ClassINET,
  597. Ttl: timeToLive,
  598. },
  599. Preference: record.Pref,
  600. Mx: record.Host,
  601. })
  602. }
  603. return &response
  604. }