| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155 | 
							- package nebula
 
- import (
 
- 	"fmt"
 
- 	"net"
 
- 	"strconv"
 
- 	"strings"
 
- 	"sync"
 
- 	"github.com/miekg/dns"
 
- 	"github.com/sirupsen/logrus"
 
- 	"github.com/slackhq/nebula/config"
 
- 	"github.com/slackhq/nebula/iputil"
 
- )
 
- // This whole thing should be rewritten to use context
 
- var dnsR *dnsRecords
 
- var dnsServer *dns.Server
 
- var dnsAddr string
 
// dnsRecords holds the names served by the DNS responder plus the host map
// used to answer certificate (TXT) questions.
type dnsRecords struct {
	// The embedded RWMutex guards dnsMap: Query takes the read lock and Add
	// takes the write lock.
	sync.RWMutex
	// dnsMap maps lowercased host names to the record text returned by Query;
	// parseQuery serves that text as the address of an A record.
	dnsMap  map[string]string
	// hostMap is used by QueryCert to find a host's certificate, and by
	// parseQuery for the VPN network that may issue TXT queries.
	hostMap *HostMap
}
 
- func newDnsRecords(hostMap *HostMap) *dnsRecords {
 
- 	return &dnsRecords{
 
- 		dnsMap:  make(map[string]string),
 
- 		hostMap: hostMap,
 
- 	}
 
- }
 
- func (d *dnsRecords) Query(data string) string {
 
- 	d.RLock()
 
- 	defer d.RUnlock()
 
- 	if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
 
- 		return r
 
- 	}
 
- 	return ""
 
- }
 
- func (d *dnsRecords) QueryCert(data string) string {
 
- 	ip := net.ParseIP(data[:len(data)-1])
 
- 	if ip == nil {
 
- 		return ""
 
- 	}
 
- 	iip := iputil.Ip2VpnIp(ip)
 
- 	hostinfo := d.hostMap.QueryVpnIp(iip)
 
- 	if hostinfo == nil {
 
- 		return ""
 
- 	}
 
- 	q := hostinfo.GetCert()
 
- 	if q == nil {
 
- 		return ""
 
- 	}
 
- 	cert := q.Details
 
- 	c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAFter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
 
- 	return c
 
- }
 
- func (d *dnsRecords) Add(host, data string) {
 
- 	d.Lock()
 
- 	defer d.Unlock()
 
- 	d.dnsMap[strings.ToLower(host)] = data
 
- }
 
- func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
 
- 	for _, q := range m.Question {
 
- 		switch q.Qtype {
 
- 		case dns.TypeA:
 
- 			l.Debugf("Query for A %s", q.Name)
 
- 			ip := dnsR.Query(q.Name)
 
- 			if ip != "" {
 
- 				rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
 
- 				if err == nil {
 
- 					m.Answer = append(m.Answer, rr)
 
- 				}
 
- 			}
 
- 		case dns.TypeTXT:
 
- 			a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
 
- 			b := net.ParseIP(a)
 
- 			// We don't answer these queries from non nebula nodes or localhost
 
- 			//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
 
- 			if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
 
- 				return
 
- 			}
 
- 			l.Debugf("Query for TXT %s", q.Name)
 
- 			ip := dnsR.QueryCert(q.Name)
 
- 			if ip != "" {
 
- 				rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
 
- 				if err == nil {
 
- 					m.Answer = append(m.Answer, rr)
 
- 				}
 
- 			}
 
- 		}
 
- 	}
 
- }
 
- func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
 
- 	m := new(dns.Msg)
 
- 	m.SetReply(r)
 
- 	m.Compress = false
 
- 	switch r.Opcode {
 
- 	case dns.OpcodeQuery:
 
- 		parseQuery(l, m, w)
 
- 	}
 
- 	w.WriteMsg(m)
 
- }
 
- func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
 
- 	dnsR = newDnsRecords(hostMap)
 
- 	// attach request handler func
 
- 	dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
 
- 		handleDnsRequest(l, w, r)
 
- 	})
 
- 	c.RegisterReloadCallback(func(c *config.C) {
 
- 		reloadDns(l, c)
 
- 	})
 
- 	return func() {
 
- 		startDns(l, c)
 
- 	}
 
- }
 
- func getDnsServerAddr(c *config.C) string {
 
- 	return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
 
- }
 
- func startDns(l *logrus.Logger, c *config.C) {
 
- 	dnsAddr = getDnsServerAddr(c)
 
- 	dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
 
- 	l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
 
- 	err := dnsServer.ListenAndServe()
 
- 	defer dnsServer.Shutdown()
 
- 	if err != nil {
 
- 		l.Errorf("Failed to start server: %s\n ", err.Error())
 
- 	}
 
- }
 
- func reloadDns(l *logrus.Logger, c *config.C) {
 
- 	if dnsAddr == getDnsServerAddr(c) {
 
- 		l.Debug("No DNS server config change detected")
 
- 		return
 
- 	}
 
- 	l.Debug("Restarting DNS server")
 
- 	dnsServer.Shutdown()
 
- 	go startDns(l, c)
 
- }
 
 
  |