Browse Source

try to make certificate addition/removal reloadable in some cases (#1468)

* try to make certificate addition/removal reloadable in some cases

* very spicy change to respond to handshakes with cert versions we cannot match with a cert that we can indeed match

* even spicier change to rehandshake if we detect our cert is lower-version than our peer, and we have a newer-version cert available

* make tryRehandshake easier to understand
Jack Doan 4 weeks ago
parent
commit
01909f4715
7 changed files with 526 additions and 80 deletions
  1. 27 0
      cert_test/cert.go
  2. 58 22
      connection_manager.go
  3. 103 0
      e2e/helpers_test.go
  4. 263 0
      e2e/tunnels_test.go
  5. 25 16
      handshake_ix.go
  6. 6 5
      handshake_manager.go
  7. 44 37
      pki.go

+ 27 - 0
cert_test/cert.go

@@ -114,6 +114,33 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
 	return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
 }
 
+func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
+	nc := &cert.TBSCertificate{
+		Version:        v,
+		Curve:          c.Curve(),
+		Name:           c.Name(),
+		Networks:       c.Networks(),
+		UnsafeNetworks: c.UnsafeNetworks(),
+		Groups:         c.Groups(),
+		NotBefore:      time.Unix(c.NotBefore().Unix(), 0),
+		NotAfter:       time.Unix(c.NotAfter().Unix(), 0),
+		PublicKey:      c.PublicKey(),
+		IsCA:           false,
+	}
+
+	c, err := nc.Sign(ca, ca.Curve(), key)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return c, pem
+}
+
 func X25519Keypair() ([]byte, []byte) {
 	privkey := make([]byte, 32)
 	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {

+ 58 - 22
connection_manager.go

@@ -354,7 +354,6 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
 
 		if mainHostInfo {
 			decision = tryRehandshake
-
 		} else {
 			if cm.shouldSwapPrimary(hostinfo) {
 				decision = swapPrimary
@@ -461,6 +460,10 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
 	}
 
 	crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
+	if crt == nil {
+		//my cert was reloaded away. We should definitely swap from this tunnel
+		return true
+	}
 	// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
 	// settle down.
 	return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
@@ -475,31 +478,34 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
 	cm.hostMap.Unlock()
 }
 
-// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
-// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
-// check and return true.
+// isInvalidCertificate decides if we should destroy a tunnel.
+// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
+// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
 func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
 	remoteCert := hostinfo.GetCert()
 	if remoteCert == nil {
-		return false
+		return false //don't tear down tunnels for handshakes in progress
 	}
 
 	caPool := cm.intf.pki.GetCAPool()
 	err := caPool.VerifyCachedCertificate(now, remoteCert)
 	if err == nil {
-		return false
-	}
-
-	if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
+		return false //cert is still valid! yay!
+	} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
 		// Block listed certificates should always be disconnected
+		hostinfo.logger(cm.l).WithError(err).
+			WithField("fingerprint", remoteCert.Fingerprint).
+			Info("Remote certificate is blocked, tearing down the tunnel")
+		return true
+	} else if cm.intf.disconnectInvalid.Load() {
+		hostinfo.logger(cm.l).WithError(err).
+			WithField("fingerprint", remoteCert.Fingerprint).
+			Info("Remote certificate is no longer valid, tearing down the tunnel")
+		return true
+	} else {
+		//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
 		return false
 	}
-
-	hostinfo.logger(cm.l).WithError(err).
-		WithField("fingerprint", remoteCert.Fingerprint).
-		Info("Remote certificate is no longer valid, tearing down the tunnel")
-
-	return true
 }
 
 func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
@@ -530,15 +536,45 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
 func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
 	cs := cm.intf.pki.getCertState()
 	curCrt := hostinfo.ConnectionState.myCert
-	myCrt := cs.getCertificate(curCrt.Version())
-	if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
-		// The current tunnel is using the latest certificate and version, no need to rehandshake.
+	curCrtVersion := curCrt.Version()
+	myCrt := cs.getCertificate(curCrtVersion)
+	if myCrt == nil {
+		cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
+			WithField("version", curCrtVersion).
+			WithField("reason", "local certificate removed").
+			Info("Re-handshaking with remote")
+		cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
 		return
 	}
+	peerCrt := hostinfo.ConnectionState.peerCert
+	if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
+		// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
+		if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
+			cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
+				WithField("version", curCrtVersion).
+				WithField("peerVersion", peerCrt.Certificate.Version()).
+				WithField("reason", "local certificate version lower than peer, attempting to correct").
+				Info("Re-handshaking with remote")
+			cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
+				hh.initiatingVersionOverride = peerCrt.Certificate.Version()
+			})
+			return
+		}
+	}
+	if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
+		cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
+			WithField("reason", "local certificate is not current").
+			Info("Re-handshaking with remote")
 
-	cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
-		WithField("reason", "local certificate is not current").
-		Info("Re-handshaking with remote")
+		cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
+		return
+	}
+	if curCrtVersion < cs.initiatingVersion {
+		cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
+			WithField("reason", "current cert version < pki.initiatingVersion").
+			Info("Re-handshaking with remote")
 
-	cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
+		cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
+		return
+	}
 }

+ 103 - 0
e2e/helpers_test.go

@@ -129,6 +129,109 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
 	return control, vpnNetworks, udpAddr, c
 }
 
+// newServer creates a nebula instance with fewer assumptions
+func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
+	l := NewTestLogger()
+
+	vpnNetworks := certs[len(certs)-1].Networks()
+
+	var udpAddr netip.AddrPort
+	if vpnNetworks[0].Addr().Is4() {
+		budpIp := vpnNetworks[0].Addr().As4()
+		budpIp[1] -= 128
+		udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
+	} else {
+		budpIp := vpnNetworks[0].Addr().As16()
+		// beef for funsies
+		budpIp[2] = 190
+		budpIp[3] = 239
+		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
+	}
+
+	caStr := ""
+	for _, ca := range caCrt {
+		x, err := ca.MarshalPEM()
+		if err != nil {
+			panic(err)
+		}
+		caStr += string(x)
+	}
+	certStr := ""
+	for _, c := range certs {
+		x, err := c.MarshalPEM()
+		if err != nil {
+			panic(err)
+		}
+		certStr += string(x)
+	}
+
+	mc := m{
+		"pki": m{
+			"ca":   caStr,
+			"cert": certStr,
+			"key":  string(key),
+		},
+		//"tun": m{"disabled": true},
+		"firewall": m{
+			"outbound": []m{{
+				"proto": "any",
+				"port":  "any",
+				"host":  "any",
+			}},
+			"inbound": []m{{
+				"proto": "any",
+				"port":  "any",
+				"host":  "any",
+			}},
+		},
+		//"handshakes": m{
+		//	"try_interval": "1s",
+		//},
+		"listen": m{
+			"host": udpAddr.Addr().String(),
+			"port": udpAddr.Port(),
+		},
+		"logging": m{
+			"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
+			"level":            l.Level.String(),
+		},
+		"timers": m{
+			"pending_deletion_interval": 2,
+			"connection_alive_interval": 2,
+		},
+	}
+
+	if overrides != nil {
+		final := m{}
+		err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
+		if err != nil {
+			panic(err)
+		}
+		err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
+		if err != nil {
+			panic(err)
+		}
+		mc = final
+	}
+
+	cb, err := yaml.Marshal(mc)
+	if err != nil {
+		panic(err)
+	}
+
+	c := config.NewC(l)
+	cStr := string(cb)
+	c.LoadString(cStr)
+
+	control, err := nebula.Main(c, false, "e2e-test", l, nil)
+
+	if err != nil {
+		panic(err)
+	}
+
+	return control, vpnNetworks, udpAddr, c
+}
+
 type doneCb func()
 
 func deadline(t *testing.T, seconds time.Duration) doneCb {

+ 263 - 0
e2e/tunnels_test.go

@@ -4,12 +4,16 @@
 package e2e
 
 import (
+	"fmt"
+	"net/netip"
 	"testing"
 	"time"
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert_test"
 	"github.com/slackhq/nebula/e2e/router"
+	"github.com/stretchr/testify/assert"
+	"gopkg.in/yaml.v3"
 )
 
 func TestDropInactiveTunnels(t *testing.T) {
@@ -55,3 +59,262 @@ func TestDropInactiveTunnels(t *testing.T) {
 	myControl.Stop()
 	theirControl.Stop()
 }
+
+func TestCertUpgrade(t *testing.T) {
+	// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
+	// under ideal conditions
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	caB, err := ca.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+	ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+
+	ca2B, err := ca2.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+	caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
+
+	myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
+	_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
+
+	theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
+	theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
+
+	myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
+
+	// Share our underlay information
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	r.Log("Assert the tunnel between me and them works")
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+	r.Log("yay")
+	//todo ???
+	time.Sleep(1 * time.Second)
+	r.FlushAll()
+
+	mc := m{
+		"pki": m{
+			"ca":   caStr,
+			"cert": string(myCert2Pem),
+			"key":  string(myPrivKey),
+		},
+		//"tun": m{"disabled": true},
+		"firewall": myC.Settings["firewall"],
+		//"handshakes": m{
+		//	"try_interval": "1s",
+		//},
+		"listen":  myC.Settings["listen"],
+		"logging": myC.Settings["logging"],
+		"timers":  myC.Settings["timers"],
+	}
+
+	cb, err := yaml.Marshal(mc)
+	if err != nil {
+		panic(err)
+	}
+
+	r.Logf("reload new v2-only config")
+	err = myC.ReloadConfigString(string(cb))
+	assert.NoError(t, err)
+	r.Log("yay, spin until their sees it")
+	waitStart := time.Now()
+	for {
+		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+		c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
+		if c == nil {
+			r.Log("nil")
+		} else {
+			version := c.Cert.Version()
+			r.Logf("version %d", version)
+			if version == cert.Version2 {
+				break
+			}
+		}
+		since := time.Since(waitStart)
+		if since > time.Second*10 {
+			t.Fatal("Cert should be new by now")
+		}
+		time.Sleep(time.Second)
+	}
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+
+	myControl.Stop()
+	theirControl.Stop()
+}
+
+func TestCertDowngrade(t *testing.T) {
+	// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
+	// under ideal conditions
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	caB, err := ca.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+	ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+
+	ca2B, err := ca2.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+	caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
+
+	myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
+	myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
+
+	theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
+	theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
+
+	myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
+
+	// Share our underlay information
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	r.Log("Assert the tunnel between me and them works")
+	//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
+	//r.Log("yay")
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+	r.Log("yay")
+	//todo ???
+	time.Sleep(1 * time.Second)
+	r.FlushAll()
+
+	mc := m{
+		"pki": m{
+			"ca":   caStr,
+			"cert": string(myCertPem),
+			"key":  string(myPrivKey),
+		},
+		"firewall": myC.Settings["firewall"],
+		"listen":   myC.Settings["listen"],
+		"logging":  myC.Settings["logging"],
+		"timers":   myC.Settings["timers"],
+	}
+
+	cb, err := yaml.Marshal(mc)
+	if err != nil {
+		panic(err)
+	}
+
+	r.Logf("reload new v1-only config")
+	err = myC.ReloadConfigString(string(cb))
+	assert.NoError(t, err)
+	r.Log("yay, spin until their sees it")
+	waitStart := time.Now()
+	for {
+		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+		c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
+		c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
+		if c == nil || c2 == nil {
+			r.Log("nil")
+		} else {
+			version := c.Cert.Version()
+			theirVersion := c2.Cert.Version()
+			r.Logf("version %d,%d", version, theirVersion)
+			if version == cert.Version1 {
+				break
+			}
+		}
+		since := time.Since(waitStart)
+		if since > time.Second*5 {
+			r.Log("it is unusual that the cert is not new yet, but not a failure yet")
+		}
+		if since > time.Second*10 {
+			r.Log("wtf")
+			t.Fatal("Cert should be new by now")
+		}
+		time.Sleep(time.Second)
+	}
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+
+	myControl.Stop()
+	theirControl.Stop()
+}
+
+func TestCertMismatchCorrection(t *testing.T) {
+	// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
+	// under ideal conditions
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+
+	myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
+	myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
+
+	theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
+	theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
+
+	myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
+
+	// Share our underlay information
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	r.Log("Assert the tunnel between me and them works")
+	//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
+	//r.Log("yay")
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+	r.Log("yay")
+	//todo ???
+	time.Sleep(1 * time.Second)
+	r.FlushAll()
+
+	waitStart := time.Now()
+	for {
+		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+		c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
+		c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
+		if c == nil || c2 == nil {
+			r.Log("nil")
+		} else {
+			version := c.Cert.Version()
+			theirVersion := c2.Cert.Version()
+			r.Logf("version %d,%d", version, theirVersion)
+			if version == theirVersion {
+				break
+			}
+		}
+		since := time.Since(waitStart)
+		if since > time.Second*5 {
+			r.Log("wtf")
+		}
+		if since > time.Second*10 {
+			r.Log("wtf")
+			t.Fatal("Cert should be new by now")
+		}
+		time.Sleep(time.Second)
+	}
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+
+	myControl.Stop()
+	theirControl.Stop()
+}

+ 25 - 16
handshake_ix.go

@@ -23,13 +23,17 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 		return false
 	}
 
-	// If we're connecting to a v6 address we must use a v2 cert
 	cs := f.pki.getCertState()
 	v := cs.initiatingVersion
-	for _, a := range hh.hostinfo.vpnAddrs {
-		if a.Is6() {
-			v = cert.Version2
-			break
+	if hh.initiatingVersionOverride != cert.VersionPre1 {
+		v = hh.initiatingVersionOverride
+	} else if v < cert.Version2 {
+		// If we're connecting to a v6 address we should encourage use of a V2 cert
+		for _, a := range hh.hostinfo.vpnAddrs {
+			if a.Is6() {
+				v = cert.Version2
+				break
+			}
 		}
 	}
 
@@ -48,6 +52,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
 			WithField("certVersion", v).
 			Error("Unable to handshake with host because no certificate handshake bytes is available")
+		return false
 	}
 
 	ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
@@ -103,6 +108,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
 			WithField("certVersion", cs.initiatingVersion).
 			Error("Unable to handshake with host because no certificate is available")
+		return
 	}
 
 	ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
@@ -143,8 +149,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
 	if err != nil {
-		fp, err := rc.Fingerprint()
-		if err != nil {
+		fp, fperr := rc.Fingerprint()
+		if fperr != nil {
 			fp = "<error generating certificate fingerprint>"
 		}
 
@@ -163,16 +169,19 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	if remoteCert.Certificate.Version() != ci.myCert.Version() {
 		// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
-		rc := cs.getCertificate(remoteCert.Certificate.Version())
-		if rc == nil {
-			f.l.WithError(err).WithField("udpAddr", addr).
-				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
-				Info("Unable to handshake with host due to missing certificate version")
-			return
+		myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
+		if myCertOtherVersion == nil {
+			if f.l.Level >= logrus.DebugLevel {
+				f.l.WithError(err).WithFields(m{
+					"udpAddr":   addr,
+					"handshake": m{"stage": 1, "style": "ix_psk0"},
+					"cert":      remoteCert,
+				}).Debug("Might be unable to handshake with host due to missing certificate version")
+			}
+		} else {
+			// Record the certificate we are actually using
+			ci.myCert = myCertOtherVersion
 		}
-
-		// Record the certificate we are actually using
-		ci.myCert = rc
 	}
 
 	if len(remoteCert.Certificate.Networks()) == 0 {

+ 6 - 5
handshake_manager.go

@@ -68,11 +68,12 @@ type HandshakeManager struct {
 type HandshakeHostInfo struct {
 	sync.Mutex
 
-	startTime   time.Time        // Time that we first started trying with this handshake
-	ready       bool             // Is the handshake ready
-	counter     int64            // How many attempts have we made so far
-	lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
-	packetStore []*cachedPacket  // A set of packets to be transmitted once the handshake completes
+	startTime                 time.Time        // Time that we first started trying with this handshake
+	ready                     bool             // Is the handshake ready
+	initiatingVersionOverride cert.Version     // Should we use a non-default cert version for this handshake?
+	counter                   int64            // How many attempts have we made so far
+	lastRemotes               []netip.AddrPort // Remotes that we sent to during the previous attempt
+	packetStore               []*cachedPacket  // A set of packets to be transmitted once the handshake completes
 
 	hostinfo *HostInfo
 }

+ 44 - 37
pki.go

@@ -100,55 +100,62 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
 		currentState := p.cs.Load()
 		if newState.v1Cert != nil {
 			if currentState.v1Cert == nil {
-				return util.NewContextualError("v1 certificate was added, restart required", nil, err)
+				//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
+			} else {
+				// did IP in cert change? if so, don't set
+				if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
+					return util.NewContextualError(
+						"Networks in new cert was different from old",
+						m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
+						nil,
+					)
+				}
+
+				if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
+					return util.NewContextualError(
+						"Curve in new v1 cert was different from old",
+						m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
+						nil,
+					)
+				}
 			}
-
-			// did IP in cert change? if so, don't set
-			if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
-				return util.NewContextualError(
-					"Networks in new cert was different from old",
-					m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
-					nil,
-				)
-			}
-
-			if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
-				return util.NewContextualError(
-					"Curve in new cert was different from old",
-					m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
-					nil,
-				)
-			}
-
-		} else if currentState.v1Cert != nil {
-			//TODO: CERT-V2 we should be able to tear this down
-			return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
 		}
 
 		if newState.v2Cert != nil {
 			if currentState.v2Cert == nil {
-				return util.NewContextualError("v2 certificate was added, restart required", nil, err)
+				//adding certs is fine, actually
+			} else {
+				// did IP in cert change? if so, don't set
+				if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
+					return util.NewContextualError(
+						"Networks in new cert was different from old",
+						m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
+						nil,
+					)
+				}
+
+				if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
+					return util.NewContextualError(
+						"Curve in new cert was different from old",
+						m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
+						nil,
+					)
+				}
 			}
 
-			// did IP in cert change? if so, don't set
-			if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
-				return util.NewContextualError(
-					"Networks in new cert was different from old",
-					m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
-					nil,
-				)
+		} else if currentState.v2Cert != nil {
+			//newState.v1Cert is non-nil bc empty certstates aren't permitted
+			if newState.v1Cert == nil {
+				return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
 			}
-
-			if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
+			//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
+			if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
 				return util.NewContextualError(
-					"Curve in new cert was different from old",
-					m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
+					"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
+					m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
 					nil,
 				)
 			}
-
-		} else if currentState.v2Cert != nil {
-			return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
 		}
 
 		// Cipher cant be hot swapped so just leave it at what it was before