Bläddra i källkod

tailcfg: add DiscoKey, unify some code, add some tests

Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 5 år sedan
förälder
incheckning
88c305c8af
2 ändrade filer med 83 tillägg och 41 borttagningar
  1. 32 40
      tailcfg/tailcfg.go
  2. 51 1
      tailcfg/tailcfg_test.go

+ 32 - 40
tailcfg/tailcfg.go

@@ -13,7 +13,9 @@ import (
 	"time"
 
 	"github.com/tailscale/wireguard-go/wgcfg"
+	"go4.org/mem"
 	"golang.org/x/oauth2"
+	"tailscale.com/types/key"
 	"tailscale.com/types/opt"
 	"tailscale.com/types/structs"
 )
@@ -38,6 +40,10 @@ type MachineKey [32]byte
 // NodeKey is the curve25519 public key for a node.
 type NodeKey [32]byte
 
+// DiscoKey is the curve25519 public key for path discovery key.
+// It's never written to disk or reused between network start-ups.
+type DiscoKey [32]byte
+
 type Group struct {
 	ID      GroupID
 	Name    string
@@ -127,6 +133,7 @@ type Node struct {
 	Key        NodeKey
 	KeyExpiry  time.Time
 	Machine    MachineKey
+	DiscoKey   DiscoKey
 	Addresses  []wgcfg.CIDR // IP addresses of this Node directly
 	AllowedIPs []wgcfg.CIDR // range of IP addresses to route to this node
 	Endpoints  []string     `json:",omitempty"` // IP+port (public via STUN, and local LANs)
@@ -519,59 +526,43 @@ type Debug struct {
 	LogHeapURL string `json:",omitempty"`
 }
 
-func (k MachineKey) String() string { return fmt.Sprintf("mkey:%x", k[:]) }
+func (k MachineKey) String() string                   { return fmt.Sprintf("mkey:%x", k[:]) }
+func (k MachineKey) MarshalText() ([]byte, error)     { return keyMarshalText("mkey:", k), nil }
+func (k *MachineKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "mkey:", text) }
 
-func (k MachineKey) MarshalText() ([]byte, error) {
-	buf := new(bytes.Buffer)
-	fmt.Fprintf(buf, "mkey:%x", k[:])
-	return buf.Bytes(), nil
+func keyMarshalText(prefix string, k [32]byte) []byte {
+	buf := bytes.NewBuffer(make([]byte, 0, len(prefix)+64))
+	fmt.Fprintf(buf, "%s%x", prefix, k[:])
+	return buf.Bytes()
 }
 
-func (k *MachineKey) UnmarshalText(text []byte) error {
-	s := string(text)
-	if !strings.HasPrefix(s, "mkey:") {
-		return errors.New(`MachineKey.UnmarshalText: missing prefix`)
+func keyUnmarshalText(dst []byte, prefix string, text []byte) error {
+	if len(text) < len(prefix) || string(text[:len(prefix)]) != prefix {
+		return fmt.Errorf("UnmarshalText: missing %q prefix", prefix)
 	}
-	s = strings.TrimPrefix(s, `mkey:`)
-	key, err := wgcfg.ParseHexKey(s)
+	pub, err := key.NewPublicFromHexMem(mem.B(text[len(prefix):]))
 	if err != nil {
-		return fmt.Errorf("MachineKey.UnmarhsalText: %v", err)
+		return fmt.Errorf("UnmarshalText: after %q: %v", prefix, err)
 	}
-	copy(k[:], key[:])
+	copy(dst[:], pub[:])
 	return nil
 }
 
-func (k NodeKey) String() string { return fmt.Sprintf("nodekey:%x", k[:]) }
+func (k NodeKey) ShortString() string { return (key.Public(k)).ShortString() }
 
-func (k NodeKey) ShortString() string {
-	pk := wgcfg.Key(k)
-	return pk.ShortString()
-}
+func (k NodeKey) String() string                   { return fmt.Sprintf("nodekey:%x", k[:]) }
+func (k NodeKey) MarshalText() ([]byte, error)     { return keyMarshalText("nodekey:", k), nil }
+func (k *NodeKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "nodekey:", text) }
 
-func (k NodeKey) MarshalText() ([]byte, error) {
-	buf := new(bytes.Buffer)
-	fmt.Fprintf(buf, "nodekey:%x", k[:])
-	return buf.Bytes(), nil
-}
+// IsZero reports whether k is the zero value.
+func (k NodeKey) IsZero() bool { return k == NodeKey{} }
 
-func (k *NodeKey) UnmarshalText(text []byte) error {
-	s := string(text)
-	if !strings.HasPrefix(s, "nodekey:") {
-		return errors.New(`Nodekey.UnmarshalText: missing prefix`)
-	}
-	s = strings.TrimPrefix(s, "nodekey:")
-	key, err := wgcfg.ParseHexKey(s)
-	if err != nil {
-		return fmt.Errorf("tailcfg.Ukey.UnmarhsalText: %v", err)
-	}
-	copy(k[:], key[:])
-	return nil
-}
+func (k DiscoKey) String() string                   { return fmt.Sprintf("discokey:%x", k[:]) }
+func (k DiscoKey) MarshalText() ([]byte, error)     { return keyMarshalText("discokey:", k), nil }
+func (k *DiscoKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "discokey:", text) }
 
-// IsZero reports whether k is the NodeKey zero value.
-func (k NodeKey) IsZero() bool {
-	return k == NodeKey{}
-}
+// IsZero reports whether k is the zero value.
+func (k DiscoKey) IsZero() bool { return k == DiscoKey{} }
 
 func (id ID) String() string           { return fmt.Sprintf("id:%x", int64(id)) }
 func (id UserID) String() string       { return fmt.Sprintf("userid:%x", int64(id)) }
@@ -593,6 +584,7 @@ func (n *Node) Equal(n2 *Node) bool {
 		n.Key == n2.Key &&
 		n.KeyExpiry.Equal(n2.KeyExpiry) &&
 		n.Machine == n2.Machine &&
+		n.DiscoKey == n2.DiscoKey &&
 		reflect.DeepEqual(n.Addresses, n2.Addresses) &&
 		reflect.DeepEqual(n.AllowedIPs, n2.AllowedIPs) &&
 		reflect.DeepEqual(n.Endpoints, n2.Endpoints) &&

+ 51 - 1
tailcfg/tailcfg_test.go

@@ -5,7 +5,9 @@
 package tailcfg
 
 import (
+	"encoding"
 	"reflect"
+	"strings"
 	"testing"
 	"time"
 
@@ -176,7 +178,7 @@ func TestHostinfoEqual(t *testing.T) {
 }
 
 func TestNodeEqual(t *testing.T) {
-	nodeHandles := []string{"ID", "Name", "User", "Key", "KeyExpiry", "Machine", "Addresses", "AllowedIPs", "Endpoints", "DERP", "Hostinfo", "Created", "LastSeen", "KeepAlive", "MachineAuthorized"}
+	nodeHandles := []string{"ID", "Name", "User", "Key", "KeyExpiry", "Machine", "DiscoKey", "Addresses", "AllowedIPs", "Endpoints", "DERP", "Hostinfo", "Created", "LastSeen", "KeepAlive", "MachineAuthorized"}
 	if have := fieldsOf(reflect.TypeOf(Node{})); !reflect.DeepEqual(have, nodeHandles) {
 		t.Errorf("Node.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
 			have, nodeHandles)
@@ -336,3 +338,51 @@ func TestNetInfoFields(t *testing.T) {
 			have, handled)
 	}
 }
+
+func TestMachineKeyMarshal(t *testing.T) {
+	var k1, k2 MachineKey
+	for i := range k1 {
+		k1[i] = byte(i)
+	}
+	testKey(t, "mkey:", k1, &k2)
+}
+
+func TestNodeKeyMarshal(t *testing.T) {
+	var k1, k2 NodeKey
+	for i := range k1 {
+		k1[i] = byte(i)
+	}
+	testKey(t, "nodekey:", k1, &k2)
+}
+
+func TestDiscoKeyMarshal(t *testing.T) {
+	var k1, k2 DiscoKey
+	for i := range k1 {
+		k1[i] = byte(i)
+	}
+	testKey(t, "discokey:", k1, &k2)
+}
+
+type keyIn interface {
+	String() string
+	MarshalText() ([]byte, error)
+}
+
+func testKey(t *testing.T, prefix string, in keyIn, out encoding.TextUnmarshaler) {
+	got, err := in.MarshalText()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if err := out.UnmarshalText(got); err != nil {
+		t.Fatal(err)
+	}
+	if s := in.String(); string(got) != s {
+		t.Errorf("MarshalText = %q != String %q", got, s)
+	}
+	if !strings.HasPrefix(string(got), prefix) {
+		t.Errorf("%q didn't start with prefix %q", got, prefix)
+	}
+	if reflect.ValueOf(out).Elem().Interface() != in {
+		t.Errorf("mismatch after unmarshal")
+	}
+}