bootstrap_dns.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. // Copyright (c) Tailscale Inc & contributors
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package main
  4. import (
  5. "context"
  6. "encoding/binary"
  7. "encoding/json"
  8. "expvar"
  9. "log"
  10. "math/rand/v2"
  11. "net"
  12. "net/http"
  13. "net/netip"
  14. "strconv"
  15. "strings"
  16. "sync/atomic"
  17. "time"
  18. "tailscale.com/syncs"
  19. "tailscale.com/util/mak"
  20. "tailscale.com/util/slicesx"
  21. )
  22. const refreshTimeout = time.Minute
  23. type dnsEntryMap struct {
  24. IPs map[string][]net.IP
  25. Percent map[string]float64 // "foo.com" => 0.5 for 50%
  26. }
  27. var (
  28. dnsCache atomic.Pointer[dnsEntryMap]
  29. dnsCacheBytes syncs.AtomicValue[[]byte] // of JSON
  30. unpublishedDNSCache atomic.Pointer[dnsEntryMap]
  31. bootstrapLookupMap syncs.Map[string, bool]
  32. )
  33. var (
  34. bootstrapDNSRequests = expvar.NewInt("counter_bootstrap_dns_requests")
  35. publishedDNSHits = expvar.NewInt("counter_bootstrap_dns_published_hits")
  36. publishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_published_misses")
  37. unpublishedDNSHits = expvar.NewInt("counter_bootstrap_dns_unpublished_hits")
  38. unpublishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_misses")
  39. unpublishedDNSPercentMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_percent_misses")
  40. )
  41. func init() {
  42. expvar.Publish("counter_bootstrap_dns_queried_domains", expvar.Func(func() any {
  43. return bootstrapLookupMap.Len()
  44. }))
  45. }
  46. func refreshBootstrapDNSLoop() {
  47. if *bootstrapDNS == "" && *unpublishedDNS == "" {
  48. return
  49. }
  50. for {
  51. refreshBootstrapDNS()
  52. refreshUnpublishedDNS()
  53. time.Sleep(10 * time.Minute)
  54. }
  55. }
  56. func refreshBootstrapDNS() {
  57. if *bootstrapDNS == "" {
  58. return
  59. }
  60. ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
  61. defer cancel()
  62. dnsEntries := resolveList(ctx, *bootstrapDNS)
  63. // Randomize the order of the IPs for each name to avoid the client biasing
  64. // to IPv6
  65. for _, vv := range dnsEntries.IPs {
  66. slicesx.Shuffle(vv)
  67. }
  68. j, err := json.MarshalIndent(dnsEntries.IPs, "", "\t")
  69. if err != nil {
  70. // leave the old values in place
  71. return
  72. }
  73. dnsCache.Store(dnsEntries)
  74. dnsCacheBytes.Store(j)
  75. }
  76. func refreshUnpublishedDNS() {
  77. if *unpublishedDNS == "" {
  78. return
  79. }
  80. ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
  81. defer cancel()
  82. dnsEntries := resolveList(ctx, *unpublishedDNS)
  83. unpublishedDNSCache.Store(dnsEntries)
  84. }
  85. // resolveList takes a comma-separated list of DNS names to resolve.
  86. //
  87. // If an entry contains a slash, it's two DNS names: the first is the one to
  88. // resolve and the second is that of a TXT recording containing the rollout
  89. // percentage in range "0".."100". If the TXT record doesn't exist or is
  90. // malformed, the percentage is 0. If the TXT record is not provided (there's no
  91. // slash), then the percentage is 100.
  92. func resolveList(ctx context.Context, list string) *dnsEntryMap {
  93. ents := strings.Split(list, ",")
  94. ret := &dnsEntryMap{}
  95. var r net.Resolver
  96. for _, ent := range ents {
  97. name, txtName, _ := strings.Cut(ent, "/")
  98. addrs, err := r.LookupIP(ctx, "ip", name)
  99. if err != nil {
  100. log.Printf("bootstrap DNS lookup %q: %v", name, err)
  101. continue
  102. }
  103. mak.Set(&ret.IPs, name, addrs)
  104. if txtName == "" {
  105. mak.Set(&ret.Percent, name, 1.0)
  106. continue
  107. }
  108. vals, err := r.LookupTXT(ctx, txtName)
  109. if err != nil {
  110. log.Printf("bootstrap DNS lookup %q: %v", txtName, err)
  111. continue
  112. }
  113. for _, v := range vals {
  114. if v, err := strconv.Atoi(v); err == nil && v >= 0 && v <= 100 {
  115. mak.Set(&ret.Percent, name, float64(v)/100)
  116. }
  117. }
  118. }
  119. return ret
  120. }
  121. func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) {
  122. bootstrapDNSRequests.Add(1)
  123. w.Header().Set("Content-Type", "application/json")
  124. // Bootstrap DNS requests occur cross-regions, and are randomized per
  125. // request, so keeping a connection open is pointlessly expensive.
  126. w.Header().Set("Connection", "close")
  127. // Try answering a query from our hidden map first
  128. if q := r.URL.Query().Get("q"); q != "" {
  129. bootstrapLookupMap.Store(q, true)
  130. if bootstrapLookupMap.Len() > 500 { // defensive
  131. bootstrapLookupMap.Clear()
  132. }
  133. if m := unpublishedDNSCache.Load(); m != nil && len(m.IPs[q]) > 0 {
  134. unpublishedDNSHits.Add(1)
  135. percent := m.Percent[q]
  136. if remoteAddrMatchesPercent(r.RemoteAddr, percent) {
  137. // Only return the specific query, not everything.
  138. m := map[string][]net.IP{q: m.IPs[q]}
  139. j, err := json.MarshalIndent(m, "", "\t")
  140. if err == nil {
  141. w.Write(j)
  142. return
  143. }
  144. } else {
  145. unpublishedDNSPercentMisses.Add(1)
  146. }
  147. }
  148. // If we have a "q" query for a name in the published cache
  149. // list, then track whether that's a hit/miss.
  150. m := dnsCache.Load()
  151. var inPub bool
  152. var ips []net.IP
  153. if m != nil {
  154. ips, inPub = m.IPs[q]
  155. }
  156. if inPub {
  157. if len(ips) > 0 {
  158. publishedDNSHits.Add(1)
  159. } else {
  160. publishedDNSMisses.Add(1)
  161. }
  162. } else {
  163. // If it wasn't in either cache, treat this as a query
  164. // for the unpublished cache, and thus a cache miss.
  165. unpublishedDNSMisses.Add(1)
  166. }
  167. }
  168. // Fall back to returning the public set of cached DNS names
  169. j := dnsCacheBytes.Load()
  170. w.Write(j)
  171. }
  172. // percent is [0.0, 1.0].
  173. func remoteAddrMatchesPercent(remoteAddr string, percent float64) bool {
  174. if percent == 0 {
  175. return false
  176. }
  177. if percent == 1 {
  178. return true
  179. }
  180. reqIPStr, _, err := net.SplitHostPort(remoteAddr)
  181. if err != nil {
  182. return false
  183. }
  184. reqIP, err := netip.ParseAddr(reqIPStr)
  185. if err != nil {
  186. return false
  187. }
  188. if reqIP.IsLoopback() {
  189. // For local testing.
  190. return rand.Float64() < 0.5
  191. }
  192. reqIP16 := reqIP.As16()
  193. rndSrc := rand.NewPCG(binary.LittleEndian.Uint64(reqIP16[:8]), binary.LittleEndian.Uint64(reqIP16[8:]))
  194. rnd := rand.New(rndSrc)
  195. return percent > rnd.Float64()
  196. }