|
|
@@ -0,0 +1,238 @@
|
|
|
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
|
|
+// Use of this source code is governed by a BSD-style
|
|
|
+// license that can be found in the LICENSE file.
|
|
|
+
|
|
|
+//go:build windows && cgo
|
|
|
+// +build windows,cgo
|
|
|
+
|
|
|
+package controlclient
|
|
|
+
|
|
|
+import (
|
|
|
+ "crypto"
|
|
|
+ "crypto/x509"
|
|
|
+ "crypto/x509/pkix"
|
|
|
+ "errors"
|
|
|
+ "reflect"
|
|
|
+ "testing"
|
|
|
+ "time"
|
|
|
+
|
|
|
+ "github.com/tailscale/certstore"
|
|
|
+)
|
|
|
+
|
|
|
+const (
|
|
|
+ testRootCommonName = "testroot"
|
|
|
+ testRootSubject = "CN=testroot"
|
|
|
+)
|
|
|
+
|
|
|
+type testIdentity struct {
|
|
|
+ chain []*x509.Certificate
|
|
|
+}
|
|
|
+
|
|
|
+func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate {
|
|
|
+ return []*x509.Certificate{
|
|
|
+ {
|
|
|
+ NotBefore: notBefore,
|
|
|
+ NotAfter: notAfter,
|
|
|
+ PublicKeyAlgorithm: x509.RSA,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Subject: pkix.Name{
|
|
|
+ CommonName: rootCommonName,
|
|
|
+ },
|
|
|
+ PublicKeyAlgorithm: x509.RSA,
|
|
|
+ },
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (t *testIdentity) Certificate() (*x509.Certificate, error) {
|
|
|
+ return t.chain[0], nil
|
|
|
+}
|
|
|
+
|
|
|
+func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) {
|
|
|
+ return t.chain, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (t *testIdentity) Signer() (crypto.Signer, error) {
|
|
|
+ return nil, errors.New("not implemented")
|
|
|
+}
|
|
|
+
|
|
|
+func (t *testIdentity) Delete() error {
|
|
|
+ return errors.New("not implemented")
|
|
|
+}
|
|
|
+
|
|
|
+func (t *testIdentity) Close() {}
|
|
|
+
|
|
|
+func TestSelectIdentityFromSlice(t *testing.T) {
|
|
|
+ var times []time.Time
|
|
|
+ for _, ts := range []string{
|
|
|
+ "2000-01-01T00:00:00Z",
|
|
|
+ "2001-01-01T00:00:00Z",
|
|
|
+ "2002-01-01T00:00:00Z",
|
|
|
+ "2003-01-01T00:00:00Z",
|
|
|
+ } {
|
|
|
+ tm, err := time.Parse(time.RFC3339, ts)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ times = append(times, tm)
|
|
|
+ }
|
|
|
+
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ subject string
|
|
|
+ ids []certstore.Identity
|
|
|
+ now time.Time
|
|
|
+ // wantIndex is an index into ids, or -1 for nil.
|
|
|
+ wantIndex int
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "single unexpired identity",
|
|
|
+ subject: testRootSubject,
|
|
|
+ ids: []certstore.Identity{
|
|
|
+ &testIdentity{
|
|
|
+ chain: makeChain(testRootCommonName, times[0], times[2]),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ now: times[1],
|
|
|
+ wantIndex: 0,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "single expired identity",
|
|
|
+ subject: testRootSubject,
|
|
|
+ ids: []certstore.Identity{
|
|
|
+ &testIdentity{
|
|
|
+ chain: makeChain(testRootCommonName, times[0], times[1]),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ now: times[2],
|
|
|
+ wantIndex: -1,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "unrelated ids",
|
|
|
+ subject: testRootSubject,
|
|
|
+ ids: []certstore.Identity{
|
|
|
+ &testIdentity{
|
|
|
+ chain: makeChain("something", times[0], times[2]),
|
|
|
+ },
|
|
|
+ &testIdentity{
|
|
|
+ chain: makeChain(testRootCommonName, times[0], times[2]),
|
|
|
+ },
|
|
|
+ &testIdentity{
|
|
|
+ chain: makeChain("else", times[0], times[2]),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ now: times[1],
|
|
|
+ wantIndex: 1,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "expired with unrelated ids",
|
|
|
+ subject: testRootSubject,
|
|
|
+ ids: []certstore.Identity{
|
|
|
+ &testIdentity{
|
|
|
+ chain: makeChain("something", times[0], times[3]),
|
|
|
+ },
|
|
|
+ &testIdentity{
|
|
|
+ chain: makeChain(testRootCommonName, times[0], times[1]),
|
|
|
+ },
|
|
|
+ &testIdentity{
|
|
|
+ chain: makeChain("else", times[0], times[3]),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ now: times[2],
|
|
|
+ wantIndex: -1,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "one expired",
|
|
|
+ subject: testRootSubject,
|
|
|
+ ids: []certstore.Identity{
|
|
|
+ &testIdentity{
|
|
|
+ chain: makeChain(testRootCommonName, times[0], times[1]),
|
|
|
+ },
|
|
|
+ &testIdentity{
|
|
|
+ chain: makeChain(testRootCommonName, times[1], times[3]),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ now: times[2],
|
|
|
+ wantIndex: 1,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "two certs both unexpired",
|
|
|
+ subject: testRootSubject,
|
|
|
+ ids: []certstore.Identity{
|
|
|
+ &testIdentity{
|
|
|
+ chain: makeChain(testRootCommonName, times[0], times[3]),
|
|
|
+ },
|
|
|
+ &testIdentity{
|
|
|
+ chain: makeChain(testRootCommonName, times[1], times[3]),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ now: times[2],
|
|
|
+ wantIndex: 1,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "two unexpired one expired",
|
|
|
+ subject: testRootSubject,
|
|
|
+ ids: []certstore.Identity{
|
|
|
+ &testIdentity{
|
|
|
+ chain: makeChain(testRootCommonName, times[0], times[3]),
|
|
|
+ },
|
|
|
+ &testIdentity{
|
|
|
+ chain: makeChain(testRootCommonName, times[1], times[3]),
|
|
|
+ },
|
|
|
+ &testIdentity{
|
|
|
+ chain: makeChain(testRootCommonName, times[0], times[1]),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ now: times[2],
|
|
|
+ wantIndex: 1,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ for _, tt := range tests {
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now)
|
|
|
+
|
|
|
+ if gotId == nil && gotChain != nil {
|
|
|
+ t.Error("id is nil: got non-nil chain, want nil chain")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if gotId != nil && gotChain == nil {
|
|
|
+ t.Error("id is not nil: got nil chain, want non-nil chain")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if tt.wantIndex == -1 {
|
|
|
+ if gotId != nil {
|
|
|
+ t.Error("got non-nil id, want nil id")
|
|
|
+ }
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if gotId == nil {
|
|
|
+ t.Error("got nil id, want non-nil id")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if gotId != tt.ids[tt.wantIndex] {
|
|
|
+ found := -1
|
|
|
+ for i := range tt.ids {
|
|
|
+ if tt.ids[i] == gotId {
|
|
|
+ found = i
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if found == -1 {
|
|
|
+ t.Errorf("got unknown id, want id at index %v", tt.wantIndex)
|
|
|
+ } else {
|
|
|
+ t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ tid, ok := tt.ids[tt.wantIndex].(*testIdentity)
|
|
|
+ if !ok {
|
|
|
+ t.Error("got non-testIdentity, want testIdentity")
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if !reflect.DeepEqual(tid.chain, gotChain) {
|
|
|
+ t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|