dns_server.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. package nebula
  2. import (
  3. "fmt"
  4. "net"
  5. "net/netip"
  6. "strconv"
  7. "strings"
  8. "sync"
  9. "github.com/miekg/dns"
  10. "github.com/sirupsen/logrus"
  11. "github.com/slackhq/nebula/config"
  12. )
  13. // This whole thing should be rewritten to use context
  14. var dnsR *dnsRecords
  15. var dnsServer *dns.Server
  16. var dnsAddr string
  17. type dnsRecords struct {
  18. sync.RWMutex
  19. dnsMap map[string]string
  20. hostMap *HostMap
  21. }
  22. func newDnsRecords(hostMap *HostMap) *dnsRecords {
  23. return &dnsRecords{
  24. dnsMap: make(map[string]string),
  25. hostMap: hostMap,
  26. }
  27. }
  28. func (d *dnsRecords) Query(data string) string {
  29. d.RLock()
  30. defer d.RUnlock()
  31. if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
  32. return r
  33. }
  34. return ""
  35. }
  36. func (d *dnsRecords) QueryCert(data string) string {
  37. ip, err := netip.ParseAddr(data[:len(data)-1])
  38. if err != nil {
  39. return ""
  40. }
  41. hostinfo := d.hostMap.QueryVpnIp(ip)
  42. if hostinfo == nil {
  43. return ""
  44. }
  45. q := hostinfo.GetCert()
  46. if q == nil {
  47. return ""
  48. }
  49. b, err := q.Certificate.MarshalJSON()
  50. if err != nil {
  51. return ""
  52. }
  53. return string(b)
  54. }
  55. func (d *dnsRecords) Add(host, data string) {
  56. d.Lock()
  57. defer d.Unlock()
  58. d.dnsMap[strings.ToLower(host)] = data
  59. }
  60. func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
  61. for _, q := range m.Question {
  62. switch q.Qtype {
  63. case dns.TypeA:
  64. l.Debugf("Query for A %s", q.Name)
  65. ip := dnsR.Query(q.Name)
  66. if ip != "" {
  67. rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
  68. if err == nil {
  69. m.Answer = append(m.Answer, rr)
  70. }
  71. }
  72. case dns.TypeTXT:
  73. a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
  74. b, err := netip.ParseAddr(a)
  75. if err != nil {
  76. return
  77. }
  78. // We don't answer these queries from non nebula nodes or localhost
  79. //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
  80. if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
  81. return
  82. }
  83. l.Debugf("Query for TXT %s", q.Name)
  84. ip := dnsR.QueryCert(q.Name)
  85. if ip != "" {
  86. rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
  87. if err == nil {
  88. m.Answer = append(m.Answer, rr)
  89. }
  90. }
  91. }
  92. }
  93. if len(m.Answer) == 0 {
  94. m.Rcode = dns.RcodeNameError
  95. }
  96. }
  97. func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
  98. m := new(dns.Msg)
  99. m.SetReply(r)
  100. m.Compress = false
  101. switch r.Opcode {
  102. case dns.OpcodeQuery:
  103. parseQuery(l, m, w)
  104. }
  105. w.WriteMsg(m)
  106. }
  107. func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
  108. dnsR = newDnsRecords(hostMap)
  109. // attach request handler func
  110. dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
  111. handleDnsRequest(l, w, r)
  112. })
  113. c.RegisterReloadCallback(func(c *config.C) {
  114. reloadDns(l, c)
  115. })
  116. return func() {
  117. startDns(l, c)
  118. }
  119. }
  120. func getDnsServerAddr(c *config.C) string {
  121. dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
  122. // Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
  123. if dnsHost == "[::]" {
  124. dnsHost = "::"
  125. }
  126. return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
  127. }
  128. func startDns(l *logrus.Logger, c *config.C) {
  129. dnsAddr = getDnsServerAddr(c)
  130. dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
  131. l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
  132. err := dnsServer.ListenAndServe()
  133. defer dnsServer.Shutdown()
  134. if err != nil {
  135. l.Errorf("Failed to start server: %s\n ", err.Error())
  136. }
  137. }
  138. func reloadDns(l *logrus.Logger, c *config.C) {
  139. if dnsAddr == getDnsServerAddr(c) {
  140. l.Debug("No DNS server config change detected")
  141. return
  142. }
  143. l.Debug("Restarting DNS server")
  144. dnsServer.Shutdown()
  145. go startDns(l, c)
  146. }