lighthouse.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. package nebula
  2. import (
  3. "fmt"
  4. "net"
  5. "sync"
  6. "time"
  7. "github.com/golang/protobuf/proto"
  8. "github.com/slackhq/nebula/cert"
  9. )
  10. type LightHouse struct {
  11. sync.RWMutex //Because we concurrently read and write to our maps
  12. amLighthouse bool
  13. myIp uint32
  14. punchConn *udpConn
  15. // Local cache of answers from light houses
  16. addrMap map[uint32][]udpAddr
  17. // staticList exists to avoid having a bool in each addrMap entry
  18. // since static should be rare
  19. staticList map[uint32]struct{}
  20. lighthouses map[uint32]struct{}
  21. interval int
  22. nebulaPort int
  23. punchBack bool
  24. punchDelay time.Duration
  25. }
  26. type EncWriter interface {
  27. SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
  28. SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
  29. }
  30. func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort int, pc *udpConn, punchBack bool, punchDelay time.Duration) *LightHouse {
  31. h := LightHouse{
  32. amLighthouse: amLighthouse,
  33. myIp: myIp,
  34. addrMap: make(map[uint32][]udpAddr),
  35. nebulaPort: nebulaPort,
  36. lighthouses: make(map[uint32]struct{}),
  37. staticList: make(map[uint32]struct{}),
  38. interval: interval,
  39. punchConn: pc,
  40. punchBack: punchBack,
  41. punchDelay: punchDelay,
  42. }
  43. for _, ip := range ips {
  44. h.lighthouses[ip] = struct{}{}
  45. }
  46. return &h
  47. }
  48. func (lh *LightHouse) ValidateLHStaticEntries() error {
  49. for lhIP, _ := range lh.lighthouses {
  50. if _, ok := lh.staticList[lhIP]; !ok {
  51. return fmt.Errorf("Lighthouse %s does not have a static_host_map entry", IntIp(lhIP))
  52. }
  53. }
  54. return nil
  55. }
  56. func (lh *LightHouse) Query(ip uint32, f EncWriter) ([]udpAddr, error) {
  57. if !lh.IsLighthouseIP(ip) {
  58. lh.QueryServer(ip, f)
  59. }
  60. lh.RLock()
  61. if v, ok := lh.addrMap[ip]; ok {
  62. lh.RUnlock()
  63. return v, nil
  64. }
  65. lh.RUnlock()
  66. return nil, fmt.Errorf("host %s not known, queries sent to lighthouses", IntIp(ip))
  67. }
  68. // This is asynchronous so no reply should be expected
  69. func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
  70. if !lh.amLighthouse {
  71. // Send a query to the lighthouses and hope for the best next time
  72. query, err := proto.Marshal(NewLhQueryByInt(ip))
  73. if err != nil {
  74. l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
  75. return
  76. }
  77. nb := make([]byte, 12, 12)
  78. out := make([]byte, mtu)
  79. for n := range lh.lighthouses {
  80. f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out)
  81. }
  82. }
  83. }
  84. // Query our local lighthouse cached results
  85. func (lh *LightHouse) QueryCache(ip uint32) []udpAddr {
  86. lh.RLock()
  87. if v, ok := lh.addrMap[ip]; ok {
  88. lh.RUnlock()
  89. return v
  90. }
  91. lh.RUnlock()
  92. return nil
  93. }
  94. func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
  95. // First we check the static mapping
  96. // and do nothing if it is there
  97. if _, ok := lh.staticList[vpnIP]; ok {
  98. return
  99. }
  100. lh.Lock()
  101. //l.Debugln(lh.addrMap)
  102. delete(lh.addrMap, vpnIP)
  103. l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
  104. lh.Unlock()
  105. }
  106. func (lh *LightHouse) AddRemote(vpnIP uint32, toIp *udpAddr, static bool) {
  107. // First we check if the sender thinks this is a static entry
  108. // and do nothing if it is not, but should be considered static
  109. if static == false {
  110. if _, ok := lh.staticList[vpnIP]; ok {
  111. return
  112. }
  113. }
  114. lh.Lock()
  115. for _, v := range lh.addrMap[vpnIP] {
  116. if v.Equals(toIp) {
  117. lh.Unlock()
  118. return
  119. }
  120. }
  121. //l.Debugf("Adding reply of %s as %s\n", IntIp(vpnIP), toIp)
  122. if static {
  123. lh.staticList[vpnIP] = struct{}{}
  124. }
  125. lh.addrMap[vpnIP] = append(lh.addrMap[vpnIP], *toIp)
  126. lh.Unlock()
  127. }
  128. func (lh *LightHouse) AddRemoteAndReset(vpnIP uint32, toIp *udpAddr) {
  129. if lh.amLighthouse {
  130. lh.DeleteVpnIP(vpnIP)
  131. lh.AddRemote(vpnIP, toIp, false)
  132. }
  133. }
  134. func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool {
  135. if _, ok := lh.lighthouses[vpnIP]; ok {
  136. return true
  137. }
  138. return false
  139. }
  140. // Quick generators for protobuf
  141. func NewLhQueryByIpString(VpnIp string) *NebulaMeta {
  142. return NewLhQueryByInt(ip2int(net.ParseIP(VpnIp)))
  143. }
  144. func NewLhQueryByInt(VpnIp uint32) *NebulaMeta {
  145. return &NebulaMeta{
  146. Type: NebulaMeta_HostQuery,
  147. Details: &NebulaMetaDetails{
  148. VpnIp: VpnIp,
  149. },
  150. }
  151. }
  152. func NewLhWhoami() *NebulaMeta {
  153. return &NebulaMeta{
  154. Type: NebulaMeta_HostWhoami,
  155. Details: &NebulaMetaDetails{},
  156. }
  157. }
  158. // End Quick generators for protobuf
  159. func NewIpAndPortFromUDPAddr(addr udpAddr) *IpAndPort {
  160. return &IpAndPort{Ip: udp2ipInt(&addr), Port: uint32(addr.Port)}
  161. }
  162. func NewIpAndPortsFromNetIps(ips []udpAddr) *[]*IpAndPort {
  163. var iap []*IpAndPort
  164. for _, e := range ips {
  165. // Only add IPs that aren't my VPN/tun IP
  166. iap = append(iap, NewIpAndPortFromUDPAddr(e))
  167. }
  168. return &iap
  169. }
  170. func (lh *LightHouse) LhUpdateWorker(f EncWriter) {
  171. if lh.amLighthouse || lh.interval == 0 {
  172. return
  173. }
  174. for {
  175. ipp := []*IpAndPort{}
  176. for _, e := range *localIps() {
  177. // Only add IPs that aren't my VPN/tun IP
  178. if ip2int(e) != lh.myIp {
  179. ipp = append(ipp, &IpAndPort{Ip: ip2int(e), Port: uint32(lh.nebulaPort)})
  180. //fmt.Println(e)
  181. }
  182. }
  183. m := &NebulaMeta{
  184. Type: NebulaMeta_HostUpdateNotification,
  185. Details: &NebulaMetaDetails{
  186. VpnIp: lh.myIp,
  187. IpAndPorts: ipp,
  188. },
  189. }
  190. nb := make([]byte, 12, 12)
  191. out := make([]byte, mtu)
  192. for vpnIp := range lh.lighthouses {
  193. mm, err := proto.Marshal(m)
  194. if err != nil {
  195. l.Debugf("Invalid marshal to update")
  196. }
  197. //l.Error("LIGHTHOUSE PACKET SEND", mm)
  198. f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out)
  199. }
  200. time.Sleep(time.Second * time.Duration(lh.interval))
  201. }
  202. }
  203. func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *cert.NebulaCertificate, f EncWriter) {
  204. n := &NebulaMeta{}
  205. err := proto.Unmarshal(p, n)
  206. if err != nil {
  207. l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
  208. Error("Failed to unmarshal lighthouse packet")
  209. //TODO: send recv_error?
  210. return
  211. }
  212. if n.Details == nil {
  213. l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
  214. Error("Invalid lighthouse update")
  215. //TODO: send recv_error?
  216. return
  217. }
  218. switch n.Type {
  219. case NebulaMeta_HostQuery:
  220. // Exit if we don't answer queries
  221. if !lh.amLighthouse {
  222. l.Debugln("I don't answer queries, but received from: ", rAddr)
  223. return
  224. }
  225. //l.Debugln("Got Query")
  226. ips, err := lh.Query(n.Details.VpnIp, f)
  227. if err != nil {
  228. //l.Debugf("Can't answer query %s from %s because error: %s", IntIp(n.Details.VpnIp), rAddr, err)
  229. return
  230. } else {
  231. iap := NewIpAndPortsFromNetIps(ips)
  232. answer := &NebulaMeta{
  233. Type: NebulaMeta_HostQueryReply,
  234. Details: &NebulaMetaDetails{
  235. VpnIp: n.Details.VpnIp,
  236. IpAndPorts: *iap,
  237. },
  238. }
  239. reply, err := proto.Marshal(answer)
  240. if err != nil {
  241. l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
  242. return
  243. }
  244. f.SendMessageToVpnIp(lightHouse, 0, vpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
  245. // This signals the other side to punch some zero byte udp packets
  246. ips, err = lh.Query(vpnIp, f)
  247. if err != nil {
  248. l.WithField("vpnIp", IntIp(vpnIp)).Debugln("Can't notify host to punch")
  249. return
  250. } else {
  251. //l.Debugln("Notify host to punch", iap)
  252. iap = NewIpAndPortsFromNetIps(ips)
  253. answer = &NebulaMeta{
  254. Type: NebulaMeta_HostPunchNotification,
  255. Details: &NebulaMetaDetails{
  256. VpnIp: vpnIp,
  257. IpAndPorts: *iap,
  258. },
  259. }
  260. reply, _ := proto.Marshal(answer)
  261. f.SendMessageToVpnIp(lightHouse, 0, n.Details.VpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
  262. }
  263. //fmt.Println(reply, remoteaddr)
  264. }
  265. case NebulaMeta_HostQueryReply:
  266. if !lh.IsLighthouseIP(vpnIp) {
  267. return
  268. }
  269. for _, a := range n.Details.IpAndPorts {
  270. //first := n.Details.IpAndPorts[0]
  271. ans := NewUDPAddr(a.Ip, uint16(a.Port))
  272. lh.AddRemote(n.Details.VpnIp, ans, false)
  273. }
  274. case NebulaMeta_HostUpdateNotification:
  275. //Simple check that the host sent this not someone else
  276. if n.Details.VpnIp != vpnIp {
  277. l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
  278. return
  279. }
  280. for _, a := range n.Details.IpAndPorts {
  281. ans := NewUDPAddr(a.Ip, uint16(a.Port))
  282. lh.AddRemote(n.Details.VpnIp, ans, false)
  283. }
  284. case NebulaMeta_HostMovedNotification:
  285. case NebulaMeta_HostPunchNotification:
  286. if !lh.IsLighthouseIP(vpnIp) {
  287. return
  288. }
  289. empty := []byte{0}
  290. for _, a := range n.Details.IpAndPorts {
  291. vpnPeer := NewUDPAddr(a.Ip, uint16(a.Port))
  292. go func() {
  293. time.Sleep(lh.punchDelay)
  294. lh.punchConn.WriteTo(empty, vpnPeer)
  295. }()
  296. l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
  297. }
  298. // This sends a nebula test packet to the host trying to contact us. In the case
  299. // of a double nat or other difficult scenario, this may help establish
  300. // a tunnel.
  301. if lh.punchBack {
  302. go func() {
  303. time.Sleep(time.Second * 5)
  304. l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
  305. f.SendMessageToVpnIp(test, testRequest, n.Details.VpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
  306. }()
  307. }
  308. }
  309. }
  310. /*
  311. func (f *Interface) sendPathCheck(ci *ConnectionState, endpoint *net.UDPAddr, counter int) {
  312. c := ci.messageCounter
  313. b := HeaderEncode(nil, Version, uint8(path_check), 0, ci.remoteIndex, c)
  314. ci.messageCounter++
  315. if ci.eKey != nil {
  316. msg := ci.eKey.EncryptDanger(b, nil, []byte(strconv.Itoa(counter)), c)
  317. //msg := ci.eKey.EncryptDanger(b, nil, []byte(fmt.Sprintf("%d", counter)), c)
  318. f.outside.WriteTo(msg, endpoint)
  319. l.Debugf("path_check sent, remote index: %d, pathCounter %d", ci.remoteIndex, counter)
  320. }
  321. }
  322. func (f *Interface) sendPathCheckReply(ci *ConnectionState, endpoint *net.UDPAddr, counter []byte) {
  323. c := ci.messageCounter
  324. b := HeaderEncode(nil, Version, uint8(path_check_reply), 0, ci.remoteIndex, c)
  325. ci.messageCounter++
  326. if ci.eKey != nil {
  327. msg := ci.eKey.EncryptDanger(b, nil, counter, c)
  328. f.outside.WriteTo(msg, endpoint)
  329. l.Debugln("path_check sent, remote index: ", ci.remoteIndex)
  330. }
  331. }
  332. */