sign_supported_test.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. //go:build windows && cgo
  4. package controlclient
  5. import (
  6. "crypto"
  7. "crypto/x509"
  8. "crypto/x509/pkix"
  9. "errors"
  10. "reflect"
  11. "testing"
  12. "time"
  13. "github.com/tailscale/certstore"
  14. )
  15. const (
  16. testRootCommonName = "testroot"
  17. testRootSubject = "CN=testroot"
  18. )
  19. type testIdentity struct {
  20. chain []*x509.Certificate
  21. }
  22. func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate {
  23. return []*x509.Certificate{
  24. {
  25. NotBefore: notBefore,
  26. NotAfter: notAfter,
  27. PublicKeyAlgorithm: x509.RSA,
  28. },
  29. {
  30. Subject: pkix.Name{
  31. CommonName: rootCommonName,
  32. },
  33. PublicKeyAlgorithm: x509.RSA,
  34. },
  35. }
  36. }
  37. func (t *testIdentity) Certificate() (*x509.Certificate, error) {
  38. return t.chain[0], nil
  39. }
  40. func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) {
  41. return t.chain, nil
  42. }
  43. func (t *testIdentity) Signer() (crypto.Signer, error) {
  44. return nil, errors.New("not implemented")
  45. }
  46. func (t *testIdentity) Delete() error {
  47. return errors.New("not implemented")
  48. }
  49. func (t *testIdentity) Close() {}
  50. func TestSelectIdentityFromSlice(t *testing.T) {
  51. var times []time.Time
  52. for _, ts := range []string{
  53. "2000-01-01T00:00:00Z",
  54. "2001-01-01T00:00:00Z",
  55. "2002-01-01T00:00:00Z",
  56. "2003-01-01T00:00:00Z",
  57. } {
  58. tm, err := time.Parse(time.RFC3339, ts)
  59. if err != nil {
  60. t.Fatal(err)
  61. }
  62. times = append(times, tm)
  63. }
  64. tests := []struct {
  65. name string
  66. subject string
  67. ids []certstore.Identity
  68. now time.Time
  69. // wantIndex is an index into ids, or -1 for nil.
  70. wantIndex int
  71. }{
  72. {
  73. name: "single unexpired identity",
  74. subject: testRootSubject,
  75. ids: []certstore.Identity{
  76. &testIdentity{
  77. chain: makeChain(testRootCommonName, times[0], times[2]),
  78. },
  79. },
  80. now: times[1],
  81. wantIndex: 0,
  82. },
  83. {
  84. name: "single expired identity",
  85. subject: testRootSubject,
  86. ids: []certstore.Identity{
  87. &testIdentity{
  88. chain: makeChain(testRootCommonName, times[0], times[1]),
  89. },
  90. },
  91. now: times[2],
  92. wantIndex: -1,
  93. },
  94. {
  95. name: "unrelated ids",
  96. subject: testRootSubject,
  97. ids: []certstore.Identity{
  98. &testIdentity{
  99. chain: makeChain("something", times[0], times[2]),
  100. },
  101. &testIdentity{
  102. chain: makeChain(testRootCommonName, times[0], times[2]),
  103. },
  104. &testIdentity{
  105. chain: makeChain("else", times[0], times[2]),
  106. },
  107. },
  108. now: times[1],
  109. wantIndex: 1,
  110. },
  111. {
  112. name: "expired with unrelated ids",
  113. subject: testRootSubject,
  114. ids: []certstore.Identity{
  115. &testIdentity{
  116. chain: makeChain("something", times[0], times[3]),
  117. },
  118. &testIdentity{
  119. chain: makeChain(testRootCommonName, times[0], times[1]),
  120. },
  121. &testIdentity{
  122. chain: makeChain("else", times[0], times[3]),
  123. },
  124. },
  125. now: times[2],
  126. wantIndex: -1,
  127. },
  128. {
  129. name: "one expired",
  130. subject: testRootSubject,
  131. ids: []certstore.Identity{
  132. &testIdentity{
  133. chain: makeChain(testRootCommonName, times[0], times[1]),
  134. },
  135. &testIdentity{
  136. chain: makeChain(testRootCommonName, times[1], times[3]),
  137. },
  138. },
  139. now: times[2],
  140. wantIndex: 1,
  141. },
  142. {
  143. name: "two certs both unexpired",
  144. subject: testRootSubject,
  145. ids: []certstore.Identity{
  146. &testIdentity{
  147. chain: makeChain(testRootCommonName, times[0], times[3]),
  148. },
  149. &testIdentity{
  150. chain: makeChain(testRootCommonName, times[1], times[3]),
  151. },
  152. },
  153. now: times[2],
  154. wantIndex: 1,
  155. },
  156. {
  157. name: "two unexpired one expired",
  158. subject: testRootSubject,
  159. ids: []certstore.Identity{
  160. &testIdentity{
  161. chain: makeChain(testRootCommonName, times[0], times[3]),
  162. },
  163. &testIdentity{
  164. chain: makeChain(testRootCommonName, times[1], times[3]),
  165. },
  166. &testIdentity{
  167. chain: makeChain(testRootCommonName, times[0], times[1]),
  168. },
  169. },
  170. now: times[2],
  171. wantIndex: 1,
  172. },
  173. }
  174. for _, tt := range tests {
  175. t.Run(tt.name, func(t *testing.T) {
  176. gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now)
  177. if gotId == nil && gotChain != nil {
  178. t.Error("id is nil: got non-nil chain, want nil chain")
  179. return
  180. }
  181. if gotId != nil && gotChain == nil {
  182. t.Error("id is not nil: got nil chain, want non-nil chain")
  183. return
  184. }
  185. if tt.wantIndex == -1 {
  186. if gotId != nil {
  187. t.Error("got non-nil id, want nil id")
  188. }
  189. return
  190. }
  191. if gotId == nil {
  192. t.Error("got nil id, want non-nil id")
  193. return
  194. }
  195. if gotId != tt.ids[tt.wantIndex] {
  196. found := -1
  197. for i := range tt.ids {
  198. if tt.ids[i] == gotId {
  199. found = i
  200. break
  201. }
  202. }
  203. if found == -1 {
  204. t.Errorf("got unknown id, want id at index %v", tt.wantIndex)
  205. } else {
  206. t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex)
  207. }
  208. }
  209. tid, ok := tt.ids[tt.wantIndex].(*testIdentity)
  210. if !ok {
  211. t.Error("got non-testIdentity, want testIdentity")
  212. return
  213. }
  214. if !reflect.DeepEqual(tid.chain, gotChain) {
  215. t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex)
  216. }
  217. })
  218. }
  219. }