| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219 |
- // Copyright (c) Tailscale Inc & contributors
- // SPDX-License-Identifier: BSD-3-Clause
- package main
- import (
- "context"
- "encoding/binary"
- "encoding/json"
- "expvar"
- "log"
- "math/rand/v2"
- "net"
- "net/http"
- "net/netip"
- "strconv"
- "strings"
- "sync/atomic"
- "time"
- "tailscale.com/syncs"
- "tailscale.com/util/mak"
- "tailscale.com/util/slicesx"
- )
// refreshTimeout bounds a single refresh of a DNS name list: it is the
// deadline on the context shared by every lookup that refresh performs.
const refreshTimeout = 60 * time.Second
// dnsEntryMap is the result of resolving a comma-separated list of DNS
// names (see resolveList).
type dnsEntryMap struct {
	// IPs maps each successfully resolved DNS name to its addresses.
	IPs map[string][]net.IP

	// Percent maps a DNS name to its rollout fraction in [0, 1]
	// ("foo.com" => 0.5 for 50%). It is consulted when deciding whether to
	// serve an unpublished name to a client; names missing from the map read
	// as 0.
	Percent map[string]float64
}
- var (
- dnsCache atomic.Pointer[dnsEntryMap]
- dnsCacheBytes syncs.AtomicValue[[]byte] // of JSON
- unpublishedDNSCache atomic.Pointer[dnsEntryMap]
- bootstrapLookupMap syncs.Map[string, bool]
- )
// Exported counters for the bootstrap DNS endpoint.
var (
	// bootstrapDNSRequests counts every request to the endpoint.
	bootstrapDNSRequests = expvar.NewInt("counter_bootstrap_dns_requests")

	// The remaining counters classify requests that carry a "q" query
	// parameter by which cache (published or unpublished) could answer
	// them; the percent-miss counter tracks unpublished names withheld
	// because the client fell outside the rollout percentage.
	publishedDNSHits            = expvar.NewInt("counter_bootstrap_dns_published_hits")
	publishedDNSMisses          = expvar.NewInt("counter_bootstrap_dns_published_misses")
	unpublishedDNSHits          = expvar.NewInt("counter_bootstrap_dns_unpublished_hits")
	unpublishedDNSMisses        = expvar.NewInt("counter_bootstrap_dns_unpublished_misses")
	unpublishedDNSPercentMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_percent_misses")
)
- func init() {
- expvar.Publish("counter_bootstrap_dns_queried_domains", expvar.Func(func() any {
- return bootstrapLookupMap.Len()
- }))
- }
- func refreshBootstrapDNSLoop() {
- if *bootstrapDNS == "" && *unpublishedDNS == "" {
- return
- }
- for {
- refreshBootstrapDNS()
- refreshUnpublishedDNS()
- time.Sleep(10 * time.Minute)
- }
- }
- func refreshBootstrapDNS() {
- if *bootstrapDNS == "" {
- return
- }
- ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
- defer cancel()
- dnsEntries := resolveList(ctx, *bootstrapDNS)
- // Randomize the order of the IPs for each name to avoid the client biasing
- // to IPv6
- for _, vv := range dnsEntries.IPs {
- slicesx.Shuffle(vv)
- }
- j, err := json.MarshalIndent(dnsEntries.IPs, "", "\t")
- if err != nil {
- // leave the old values in place
- return
- }
- dnsCache.Store(dnsEntries)
- dnsCacheBytes.Store(j)
- }
- func refreshUnpublishedDNS() {
- if *unpublishedDNS == "" {
- return
- }
- ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
- defer cancel()
- dnsEntries := resolveList(ctx, *unpublishedDNS)
- unpublishedDNSCache.Store(dnsEntries)
- }
- // resolveList takes a comma-separated list of DNS names to resolve.
- //
- // If an entry contains a slash, it's two DNS names: the first is the one to
- // resolve and the second is that of a TXT recording containing the rollout
- // percentage in range "0".."100". If the TXT record doesn't exist or is
- // malformed, the percentage is 0. If the TXT record is not provided (there's no
- // slash), then the percentage is 100.
- func resolveList(ctx context.Context, list string) *dnsEntryMap {
- ents := strings.Split(list, ",")
- ret := &dnsEntryMap{}
- var r net.Resolver
- for _, ent := range ents {
- name, txtName, _ := strings.Cut(ent, "/")
- addrs, err := r.LookupIP(ctx, "ip", name)
- if err != nil {
- log.Printf("bootstrap DNS lookup %q: %v", name, err)
- continue
- }
- mak.Set(&ret.IPs, name, addrs)
- if txtName == "" {
- mak.Set(&ret.Percent, name, 1.0)
- continue
- }
- vals, err := r.LookupTXT(ctx, txtName)
- if err != nil {
- log.Printf("bootstrap DNS lookup %q: %v", txtName, err)
- continue
- }
- for _, v := range vals {
- if v, err := strconv.Atoi(v); err == nil && v >= 0 && v <= 100 {
- mak.Set(&ret.Percent, name, float64(v)/100)
- }
- }
- }
- return ret
- }
- func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) {
- bootstrapDNSRequests.Add(1)
- w.Header().Set("Content-Type", "application/json")
- // Bootstrap DNS requests occur cross-regions, and are randomized per
- // request, so keeping a connection open is pointlessly expensive.
- w.Header().Set("Connection", "close")
- // Try answering a query from our hidden map first
- if q := r.URL.Query().Get("q"); q != "" {
- bootstrapLookupMap.Store(q, true)
- if bootstrapLookupMap.Len() > 500 { // defensive
- bootstrapLookupMap.Clear()
- }
- if m := unpublishedDNSCache.Load(); m != nil && len(m.IPs[q]) > 0 {
- unpublishedDNSHits.Add(1)
- percent := m.Percent[q]
- if remoteAddrMatchesPercent(r.RemoteAddr, percent) {
- // Only return the specific query, not everything.
- m := map[string][]net.IP{q: m.IPs[q]}
- j, err := json.MarshalIndent(m, "", "\t")
- if err == nil {
- w.Write(j)
- return
- }
- } else {
- unpublishedDNSPercentMisses.Add(1)
- }
- }
- // If we have a "q" query for a name in the published cache
- // list, then track whether that's a hit/miss.
- m := dnsCache.Load()
- var inPub bool
- var ips []net.IP
- if m != nil {
- ips, inPub = m.IPs[q]
- }
- if inPub {
- if len(ips) > 0 {
- publishedDNSHits.Add(1)
- } else {
- publishedDNSMisses.Add(1)
- }
- } else {
- // If it wasn't in either cache, treat this as a query
- // for the unpublished cache, and thus a cache miss.
- unpublishedDNSMisses.Add(1)
- }
- }
- // Fall back to returning the public set of cached DNS names
- j := dnsCacheBytes.Load()
- w.Write(j)
- }
- // percent is [0.0, 1.0].
- func remoteAddrMatchesPercent(remoteAddr string, percent float64) bool {
- if percent == 0 {
- return false
- }
- if percent == 1 {
- return true
- }
- reqIPStr, _, err := net.SplitHostPort(remoteAddr)
- if err != nil {
- return false
- }
- reqIP, err := netip.ParseAddr(reqIPStr)
- if err != nil {
- return false
- }
- if reqIP.IsLoopback() {
- // For local testing.
- return rand.Float64() < 0.5
- }
- reqIP16 := reqIP.As16()
- rndSrc := rand.NewPCG(binary.LittleEndian.Uint64(reqIP16[:8]), binary.LittleEndian.Uint64(reqIP16[8:]))
- rnd := rand.New(rndSrc)
- return percent > rnd.Float64()
- }
|