// connection_manager_test.go (11 KB)

  1. package nebula
  2. import (
  3. "context"
  4. "crypto/ed25519"
  5. "crypto/rand"
  6. "net/netip"
  7. "testing"
  8. "time"
  9. "github.com/flynn/noise"
  10. "github.com/slackhq/nebula/cert"
  11. "github.com/slackhq/nebula/config"
  12. "github.com/slackhq/nebula/test"
  13. "github.com/slackhq/nebula/udp"
  14. "github.com/stretchr/testify/assert"
  15. )
  16. func newTestLighthouse() *LightHouse {
  17. lh := &LightHouse{
  18. l: test.NewLogger(),
  19. addrMap: map[netip.Addr]*RemoteList{},
  20. queryChan: make(chan netip.Addr, 10),
  21. }
  22. lighthouses := map[netip.Addr]struct{}{}
  23. staticList := map[netip.Addr]struct{}{}
  24. lh.lighthouses.Store(&lighthouses)
  25. lh.staticList.Store(&staticList)
  26. return lh
  27. }
  28. func Test_NewConnectionManagerTest(t *testing.T) {
  29. l := test.NewLogger()
  30. //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
  31. vpncidr := netip.MustParsePrefix("172.1.1.1/24")
  32. localrange := netip.MustParsePrefix("10.1.1.1/24")
  33. vpnIp := netip.MustParseAddr("172.1.1.2")
  34. preferredRanges := []netip.Prefix{localrange}
  35. // Very incomplete mock objects
  36. hostMap := newHostMap(l, vpncidr)
  37. hostMap.preferredRanges.Store(&preferredRanges)
  38. cs := &CertState{
  39. RawCertificate: []byte{},
  40. PrivateKey: []byte{},
  41. Certificate: &dummyCert{},
  42. RawCertificateNoKey: []byte{},
  43. }
  44. lh := newTestLighthouse()
  45. ifce := &Interface{
  46. hostMap: hostMap,
  47. inside: &test.NoopTun{},
  48. outside: &udp.NoopConn{},
  49. firewall: &Firewall{},
  50. lightHouse: lh,
  51. pki: &PKI{},
  52. handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
  53. l: l,
  54. }
  55. ifce.pki.cs.Store(cs)
  56. // Create manager
  57. ctx, cancel := context.WithCancel(context.Background())
  58. defer cancel()
  59. punchy := NewPunchyFromConfig(l, config.NewC(l))
  60. nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
  61. p := []byte("")
  62. nb := make([]byte, 12, 12)
  63. out := make([]byte, mtu)
  64. // Add an ip we have established a connection w/ to hostmap
  65. hostinfo := &HostInfo{
  66. vpnIp: vpnIp,
  67. localIndexId: 1099,
  68. remoteIndexId: 9901,
  69. }
  70. hostinfo.ConnectionState = &ConnectionState{
  71. myCert: &dummyCert{},
  72. H: &noise.HandshakeState{},
  73. }
  74. nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
  75. // We saw traffic out to vpnIp
  76. nc.Out(hostinfo.localIndexId)
  77. nc.In(hostinfo.localIndexId)
  78. assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
  79. assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
  80. assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  81. assert.Contains(t, nc.out, hostinfo.localIndexId)
  82. // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
  83. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  84. assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
  85. assert.NotContains(t, nc.out, hostinfo.localIndexId)
  86. assert.NotContains(t, nc.in, hostinfo.localIndexId)
  87. // Do another traffic check tick, this host should be pending deletion now
  88. nc.Out(hostinfo.localIndexId)
  89. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  90. assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
  91. assert.NotContains(t, nc.out, hostinfo.localIndexId)
  92. assert.NotContains(t, nc.in, hostinfo.localIndexId)
  93. assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  94. assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
  95. // Do a final traffic check tick, the host should now be removed
  96. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  97. assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
  98. assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
  99. assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  100. }
  101. func Test_NewConnectionManagerTest2(t *testing.T) {
  102. l := test.NewLogger()
  103. //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
  104. vpncidr := netip.MustParsePrefix("172.1.1.1/24")
  105. localrange := netip.MustParsePrefix("10.1.1.1/24")
  106. vpnIp := netip.MustParseAddr("172.1.1.2")
  107. preferredRanges := []netip.Prefix{localrange}
  108. // Very incomplete mock objects
  109. hostMap := newHostMap(l, vpncidr)
  110. hostMap.preferredRanges.Store(&preferredRanges)
  111. cs := &CertState{
  112. RawCertificate: []byte{},
  113. PrivateKey: []byte{},
  114. Certificate: &dummyCert{},
  115. RawCertificateNoKey: []byte{},
  116. }
  117. lh := newTestLighthouse()
  118. ifce := &Interface{
  119. hostMap: hostMap,
  120. inside: &test.NoopTun{},
  121. outside: &udp.NoopConn{},
  122. firewall: &Firewall{},
  123. lightHouse: lh,
  124. pki: &PKI{},
  125. handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
  126. l: l,
  127. }
  128. ifce.pki.cs.Store(cs)
  129. // Create manager
  130. ctx, cancel := context.WithCancel(context.Background())
  131. defer cancel()
  132. punchy := NewPunchyFromConfig(l, config.NewC(l))
  133. nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
  134. p := []byte("")
  135. nb := make([]byte, 12, 12)
  136. out := make([]byte, mtu)
  137. // Add an ip we have established a connection w/ to hostmap
  138. hostinfo := &HostInfo{
  139. vpnIp: vpnIp,
  140. localIndexId: 1099,
  141. remoteIndexId: 9901,
  142. }
  143. hostinfo.ConnectionState = &ConnectionState{
  144. myCert: &dummyCert{},
  145. H: &noise.HandshakeState{},
  146. }
  147. nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
  148. // We saw traffic out to vpnIp
  149. nc.Out(hostinfo.localIndexId)
  150. nc.In(hostinfo.localIndexId)
  151. assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp)
  152. assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
  153. assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  154. // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
  155. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  156. assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
  157. assert.NotContains(t, nc.out, hostinfo.localIndexId)
  158. assert.NotContains(t, nc.in, hostinfo.localIndexId)
  159. // Do another traffic check tick, this host should be pending deletion now
  160. nc.Out(hostinfo.localIndexId)
  161. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  162. assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
  163. assert.NotContains(t, nc.out, hostinfo.localIndexId)
  164. assert.NotContains(t, nc.in, hostinfo.localIndexId)
  165. assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  166. assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
  167. // We saw traffic, should no longer be pending deletion
  168. nc.In(hostinfo.localIndexId)
  169. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  170. assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
  171. assert.NotContains(t, nc.out, hostinfo.localIndexId)
  172. assert.NotContains(t, nc.in, hostinfo.localIndexId)
  173. assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  174. assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
  175. }
  176. // Check if we can disconnect the peer.
  177. // Validate if the peer's certificate is invalid (expired, etc.)
  178. // Disconnect only if disconnectInvalid: true is set.
  179. func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
  180. now := time.Now()
  181. l := test.NewLogger()
  182. vpncidr := netip.MustParsePrefix("172.1.1.1/24")
  183. localrange := netip.MustParsePrefix("10.1.1.1/24")
  184. vpnIp := netip.MustParseAddr("172.1.1.2")
  185. preferredRanges := []netip.Prefix{localrange}
  186. hostMap := newHostMap(l, vpncidr)
  187. hostMap.preferredRanges.Store(&preferredRanges)
  188. // Generate keys for CA and peer's cert.
  189. pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
  190. tbs := &cert.TBSCertificate{
  191. Version: 1,
  192. Name: "ca",
  193. IsCA: true,
  194. NotBefore: now,
  195. NotAfter: now.Add(1 * time.Hour),
  196. PublicKey: pubCA,
  197. }
  198. caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA)
  199. assert.NoError(t, err)
  200. ncp := cert.NewCAPool()
  201. assert.NoError(t, ncp.AddCA(caCert))
  202. pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
  203. tbs = &cert.TBSCertificate{
  204. Version: 1,
  205. Name: "host",
  206. Networks: []netip.Prefix{vpncidr},
  207. NotBefore: now,
  208. NotAfter: now.Add(60 * time.Second),
  209. PublicKey: pubCrt,
  210. }
  211. peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA)
  212. assert.NoError(t, err)
  213. cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
  214. cs := &CertState{
  215. RawCertificate: []byte{},
  216. PrivateKey: []byte{},
  217. Certificate: &dummyCert{},
  218. RawCertificateNoKey: []byte{},
  219. }
  220. lh := newTestLighthouse()
  221. ifce := &Interface{
  222. hostMap: hostMap,
  223. inside: &test.NoopTun{},
  224. outside: &udp.NoopConn{},
  225. firewall: &Firewall{},
  226. lightHouse: lh,
  227. handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
  228. l: l,
  229. pki: &PKI{},
  230. }
  231. ifce.pki.cs.Store(cs)
  232. ifce.pki.caPool.Store(ncp)
  233. ifce.disconnectInvalid.Store(true)
  234. // Create manager
  235. ctx, cancel := context.WithCancel(context.Background())
  236. defer cancel()
  237. punchy := NewPunchyFromConfig(l, config.NewC(l))
  238. nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
  239. ifce.connectionManager = nc
  240. hostinfo := &HostInfo{
  241. vpnIp: vpnIp,
  242. ConnectionState: &ConnectionState{
  243. myCert: &dummyCert{},
  244. peerCert: cachedPeerCert,
  245. H: &noise.HandshakeState{},
  246. },
  247. }
  248. nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
  249. // Move ahead 45s.
  250. // Check if to disconnect with invalid certificate.
  251. // Should be alive.
  252. nextTick := now.Add(45 * time.Second)
  253. invalid := nc.isInvalidCertificate(nextTick, hostinfo)
  254. assert.False(t, invalid)
  255. // Move ahead 61s.
  256. // Check if to disconnect with invalid certificate.
  257. // Should be disconnected.
  258. nextTick = now.Add(61 * time.Second)
  259. invalid = nc.isInvalidCertificate(nextTick, hostinfo)
  260. assert.True(t, invalid)
  261. }
  262. type dummyCert struct {
  263. version cert.Version
  264. curve cert.Curve
  265. groups []string
  266. isCa bool
  267. issuer string
  268. name string
  269. networks []netip.Prefix
  270. notAfter time.Time
  271. notBefore time.Time
  272. publicKey []byte
  273. signature []byte
  274. unsafeNetworks []netip.Prefix
  275. }
  276. func (d *dummyCert) Version() cert.Version {
  277. return d.version
  278. }
  279. func (d *dummyCert) Curve() cert.Curve {
  280. return d.curve
  281. }
  282. func (d *dummyCert) Groups() []string {
  283. return d.groups
  284. }
  285. func (d *dummyCert) IsCA() bool {
  286. return d.isCa
  287. }
  288. func (d *dummyCert) Issuer() string {
  289. return d.issuer
  290. }
  291. func (d *dummyCert) Name() string {
  292. return d.name
  293. }
  294. func (d *dummyCert) Networks() []netip.Prefix {
  295. return d.networks
  296. }
  297. func (d *dummyCert) NotAfter() time.Time {
  298. return d.notAfter
  299. }
  300. func (d *dummyCert) NotBefore() time.Time {
  301. return d.notBefore
  302. }
  303. func (d *dummyCert) PublicKey() []byte {
  304. return d.publicKey
  305. }
  306. func (d *dummyCert) Signature() []byte {
  307. return d.signature
  308. }
  309. func (d *dummyCert) UnsafeNetworks() []netip.Prefix {
  310. return d.unsafeNetworks
  311. }
  312. func (d *dummyCert) MarshalForHandshakes() ([]byte, error) {
  313. return nil, nil
  314. }
  315. func (d *dummyCert) Sign(curve cert.Curve, key []byte) error {
  316. return nil
  317. }
  318. func (d *dummyCert) CheckSignature(key []byte) bool {
  319. return true
  320. }
  321. func (d *dummyCert) Expired(t time.Time) bool {
  322. return false
  323. }
  324. func (d *dummyCert) CheckRootConstraints(signer cert.Certificate) error {
  325. return nil
  326. }
  327. func (d *dummyCert) VerifyPrivateKey(curve cert.Curve, key []byte) error {
  328. return nil
  329. }
  330. func (d *dummyCert) String() string {
  331. return ""
  332. }
  333. func (d *dummyCert) Marshal() ([]byte, error) {
  334. return nil, nil
  335. }
  336. func (d *dummyCert) MarshalPEM() ([]byte, error) {
  337. return nil, nil
  338. }
  339. func (d *dummyCert) Fingerprint() (string, error) {
  340. return "", nil
  341. }
  342. func (d *dummyCert) MarshalJSON() ([]byte, error) {
  343. return nil, nil
  344. }
  345. func (d *dummyCert) Copy() cert.Certificate {
  346. return d
  347. }