bootstrap_dns.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package main
  4. import (
  5. "context"
  6. "encoding/json"
  7. "expvar"
  8. "log"
  9. "net"
  10. "net/http"
  11. "strings"
  12. "time"
  13. "tailscale.com/syncs"
  14. "tailscale.com/util/slicesx"
  15. )
  16. const refreshTimeout = time.Minute
  17. type dnsEntryMap map[string][]net.IP
  18. var (
  19. dnsCache syncs.AtomicValue[dnsEntryMap]
  20. dnsCacheBytes syncs.AtomicValue[[]byte] // of JSON
  21. unpublishedDNSCache syncs.AtomicValue[dnsEntryMap]
  22. bootstrapLookupMap syncs.Map[string, bool]
  23. )
  24. var (
  25. bootstrapDNSRequests = expvar.NewInt("counter_bootstrap_dns_requests")
  26. publishedDNSHits = expvar.NewInt("counter_bootstrap_dns_published_hits")
  27. publishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_published_misses")
  28. unpublishedDNSHits = expvar.NewInt("counter_bootstrap_dns_unpublished_hits")
  29. unpublishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_misses")
  30. )
  31. func init() {
  32. expvar.Publish("counter_bootstrap_dns_queried_domains", expvar.Func(func() any {
  33. return bootstrapLookupMap.Len()
  34. }))
  35. }
  36. func refreshBootstrapDNSLoop() {
  37. if *bootstrapDNS == "" && *unpublishedDNS == "" {
  38. return
  39. }
  40. for {
  41. refreshBootstrapDNS()
  42. refreshUnpublishedDNS()
  43. time.Sleep(10 * time.Minute)
  44. }
  45. }
  46. func refreshBootstrapDNS() {
  47. if *bootstrapDNS == "" {
  48. return
  49. }
  50. ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
  51. defer cancel()
  52. dnsEntries := resolveList(ctx, strings.Split(*bootstrapDNS, ","))
  53. // Randomize the order of the IPs for each name to avoid the client biasing
  54. // to IPv6
  55. for k := range dnsEntries {
  56. ips := dnsEntries[k]
  57. slicesx.Shuffle(ips)
  58. dnsEntries[k] = ips
  59. }
  60. j, err := json.MarshalIndent(dnsEntries, "", "\t")
  61. if err != nil {
  62. // leave the old values in place
  63. return
  64. }
  65. dnsCache.Store(dnsEntries)
  66. dnsCacheBytes.Store(j)
  67. }
  68. func refreshUnpublishedDNS() {
  69. if *unpublishedDNS == "" {
  70. return
  71. }
  72. ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
  73. defer cancel()
  74. dnsEntries := resolveList(ctx, strings.Split(*unpublishedDNS, ","))
  75. unpublishedDNSCache.Store(dnsEntries)
  76. }
  77. func resolveList(ctx context.Context, names []string) dnsEntryMap {
  78. dnsEntries := make(dnsEntryMap)
  79. var r net.Resolver
  80. for _, name := range names {
  81. addrs, err := r.LookupIP(ctx, "ip", name)
  82. if err != nil {
  83. log.Printf("bootstrap DNS lookup %q: %v", name, err)
  84. continue
  85. }
  86. dnsEntries[name] = addrs
  87. }
  88. return dnsEntries
  89. }
  90. func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) {
  91. bootstrapDNSRequests.Add(1)
  92. w.Header().Set("Content-Type", "application/json")
  93. // Bootstrap DNS requests occur cross-regions, and are randomized per
  94. // request, so keeping a connection open is pointlessly expensive.
  95. w.Header().Set("Connection", "close")
  96. // Try answering a query from our hidden map first
  97. if q := r.URL.Query().Get("q"); q != "" {
  98. bootstrapLookupMap.Store(q, true)
  99. if ips, ok := unpublishedDNSCache.Load()[q]; ok && len(ips) > 0 {
  100. unpublishedDNSHits.Add(1)
  101. // Only return the specific query, not everything.
  102. m := dnsEntryMap{q: ips}
  103. j, err := json.MarshalIndent(m, "", "\t")
  104. if err == nil {
  105. w.Write(j)
  106. return
  107. }
  108. }
  109. // If we have a "q" query for a name in the published cache
  110. // list, then track whether that's a hit/miss.
  111. if m, ok := dnsCache.Load()[q]; ok {
  112. if len(m) > 0 {
  113. publishedDNSHits.Add(1)
  114. } else {
  115. publishedDNSMisses.Add(1)
  116. }
  117. } else {
  118. // If it wasn't in either cache, treat this as a query
  119. // for the unpublished cache, and thus a cache miss.
  120. unpublishedDNSMisses.Add(1)
  121. }
  122. }
  123. // Fall back to returning the public set of cached DNS names
  124. j := dnsCacheBytes.Load()
  125. w.Write(j)
  126. }