bootstrap_dns.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package main
  4. import (
  5. "context"
  6. "encoding/json"
  7. "expvar"
  8. "log"
  9. "net"
  10. "net/http"
  11. "strings"
  12. "time"
  13. "tailscale.com/syncs"
  14. "tailscale.com/util/slicesx"
  15. )
// refreshTimeout bounds a single refresh pass: refreshBootstrapDNS and
// refreshUnpublishedDNS each create a context with this timeout and use it
// for all of the lookups in that pass.
const refreshTimeout = time.Minute
// dnsEntryMap maps a looked-up name to the IP addresses it resolved to.
type dnsEntryMap map[string][]net.IP
var (
	// dnsCache holds the entries most recently stored by refreshBootstrapDNS
	// for the names listed in the bootstrapDNS flag.
	dnsCache syncs.AtomicValue[dnsEntryMap]
	// dnsCacheBytes holds the indented JSON encoding of the map stored in
	// dnsCache; handleBootstrapDNS writes it out as the default response.
	dnsCacheBytes syncs.AtomicValue[[]byte] // of JSON
	// unpublishedDNSCache holds the entries most recently stored by
	// refreshUnpublishedDNS for the names listed in the unpublishedDNS flag.
	// handleBootstrapDNS only returns one of these in answer to a matching
	// "q" query; they are not part of the default response.
	unpublishedDNSCache syncs.AtomicValue[dnsEntryMap]
)
// Counters updated by handleBootstrapDNS.
var (
	// bootstrapDNSRequests counts every request the handler receives.
	bootstrapDNSRequests = expvar.NewInt("counter_bootstrap_dns_requests")
	// publishedDNSHits counts "q" queries for a name that is present in
	// dnsCache with at least one IP.
	publishedDNSHits = expvar.NewInt("counter_bootstrap_dns_published_hits")
	// publishedDNSMisses counts "q" queries for a name that is present in
	// dnsCache but has no IPs.
	publishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_published_misses")
	// unpublishedDNSHits counts "q" queries for a name that has at least one
	// IP in unpublishedDNSCache.
	unpublishedDNSHits = expvar.NewInt("counter_bootstrap_dns_unpublished_hits")
	// unpublishedDNSMisses counts "q" queries that were not answered from
	// unpublishedDNSCache and whose name is not present in dnsCache.
	unpublishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_misses")
)
  30. func refreshBootstrapDNSLoop() {
  31. if *bootstrapDNS == "" && *unpublishedDNS == "" {
  32. return
  33. }
  34. for {
  35. refreshBootstrapDNS()
  36. refreshUnpublishedDNS()
  37. time.Sleep(10 * time.Minute)
  38. }
  39. }
  40. func refreshBootstrapDNS() {
  41. if *bootstrapDNS == "" {
  42. return
  43. }
  44. ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
  45. defer cancel()
  46. dnsEntries := resolveList(ctx, strings.Split(*bootstrapDNS, ","))
  47. // Randomize the order of the IPs for each name to avoid the client biasing
  48. // to IPv6
  49. for k := range dnsEntries {
  50. ips := dnsEntries[k]
  51. slicesx.Shuffle(ips)
  52. dnsEntries[k] = ips
  53. }
  54. j, err := json.MarshalIndent(dnsEntries, "", "\t")
  55. if err != nil {
  56. // leave the old values in place
  57. return
  58. }
  59. dnsCache.Store(dnsEntries)
  60. dnsCacheBytes.Store(j)
  61. }
  62. func refreshUnpublishedDNS() {
  63. if *unpublishedDNS == "" {
  64. return
  65. }
  66. ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
  67. defer cancel()
  68. dnsEntries := resolveList(ctx, strings.Split(*unpublishedDNS, ","))
  69. unpublishedDNSCache.Store(dnsEntries)
  70. }
  71. func resolveList(ctx context.Context, names []string) dnsEntryMap {
  72. dnsEntries := make(dnsEntryMap)
  73. var r net.Resolver
  74. for _, name := range names {
  75. addrs, err := r.LookupIP(ctx, "ip", name)
  76. if err != nil {
  77. log.Printf("bootstrap DNS lookup %q: %v", name, err)
  78. continue
  79. }
  80. dnsEntries[name] = addrs
  81. }
  82. return dnsEntries
  83. }
  84. func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) {
  85. bootstrapDNSRequests.Add(1)
  86. w.Header().Set("Content-Type", "application/json")
  87. // Bootstrap DNS requests occur cross-regions, and are randomized per
  88. // request, so keeping a connection open is pointlessly expensive.
  89. w.Header().Set("Connection", "close")
  90. // Try answering a query from our hidden map first
  91. if q := r.URL.Query().Get("q"); q != "" {
  92. if ips, ok := unpublishedDNSCache.Load()[q]; ok && len(ips) > 0 {
  93. unpublishedDNSHits.Add(1)
  94. // Only return the specific query, not everything.
  95. m := dnsEntryMap{q: ips}
  96. j, err := json.MarshalIndent(m, "", "\t")
  97. if err == nil {
  98. w.Write(j)
  99. return
  100. }
  101. }
  102. // If we have a "q" query for a name in the published cache
  103. // list, then track whether that's a hit/miss.
  104. if m, ok := dnsCache.Load()[q]; ok {
  105. if len(m) > 0 {
  106. publishedDNSHits.Add(1)
  107. } else {
  108. publishedDNSMisses.Add(1)
  109. }
  110. } else {
  111. // If it wasn't in either cache, treat this as a query
  112. // for the unpublished cache, and thus a cache miss.
  113. unpublishedDNSMisses.Add(1)
  114. }
  115. }
  116. // Fall back to returning the public set of cached DNS names
  117. j := dnsCacheBytes.Load()
  118. w.Write(j)
  119. }