state_test.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. //go:build !ts_omit_tailnetlock
  4. package tka
  5. import (
  6. "bytes"
  7. "encoding/hex"
  8. "errors"
  9. "testing"
  10. "github.com/fxamacker/cbor/v2"
  11. "github.com/google/go-cmp/cmp"
  12. "github.com/google/go-cmp/cmp/cmpopts"
  13. )
  14. func fromHex(in string) []byte {
  15. out, err := hex.DecodeString(in)
  16. if err != nil {
  17. panic(err)
  18. }
  19. return out
  20. }
  21. func hashFromHex(in string) *AUMHash {
  22. var out AUMHash
  23. copy(out[:], fromHex(in))
  24. return &out
  25. }
  26. func TestCloneState(t *testing.T) {
  27. tcs := []struct {
  28. Name string
  29. State State
  30. }{
  31. {
  32. "Empty",
  33. State{},
  34. },
  35. {
  36. "Key",
  37. State{
  38. Keys: []Key{{Kind: Key25519, Votes: 2, Public: []byte{5, 6, 7, 8}, Meta: map[string]string{"a": "b"}}},
  39. },
  40. },
  41. {
  42. "StateID",
  43. State{
  44. StateID1: 42,
  45. StateID2: 22,
  46. },
  47. },
  48. {
  49. "DisablementSecrets",
  50. State{
  51. DisablementSecrets: [][]byte{
  52. {1, 2, 3, 4},
  53. {5, 6, 7, 8},
  54. },
  55. },
  56. },
  57. }
  58. for _, tc := range tcs {
  59. t.Run(tc.Name, func(t *testing.T) {
  60. if diff := cmp.Diff(tc.State, tc.State.Clone()); diff != "" {
  61. t.Errorf("output state differs (-want, +got):\n%s", diff)
  62. }
  63. // Make sure the cloned State is the same even after
  64. // an encode + decode into + from CBOR.
  65. t.Run("cbor", func(t *testing.T) {
  66. out := bytes.NewBuffer(nil)
  67. encoder, err := cbor.CTAP2EncOptions().EncMode()
  68. if err != nil {
  69. t.Fatal(err)
  70. }
  71. if err := encoder.NewEncoder(out).Encode(tc.State.Clone()); err != nil {
  72. t.Fatal(err)
  73. }
  74. var decodedState State
  75. if err := cbor.Unmarshal(out.Bytes(), &decodedState); err != nil {
  76. t.Fatalf("Unmarshal failed: %v", err)
  77. }
  78. if diff := cmp.Diff(tc.State, decodedState); diff != "" {
  79. t.Errorf("decoded state differs (-want, +got):\n%s", diff)
  80. }
  81. })
  82. })
  83. }
  84. }
  85. func TestApplyUpdatesChain(t *testing.T) {
  86. intOne := uint(1)
  87. tcs := []struct {
  88. Name string
  89. Updates []AUM
  90. Start State
  91. End State
  92. }{
  93. {
  94. "AddKey",
  95. []AUM{{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}},
  96. State{},
  97. State{
  98. Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}},
  99. LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"),
  100. },
  101. },
  102. {
  103. "RemoveKey",
  104. []AUM{{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03")}},
  105. State{
  106. Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}},
  107. LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"),
  108. },
  109. State{
  110. LastAUMHash: hashFromHex("15d65756abfafbb592279503f40759898590c9c59056be1e2e9f02684c15ba4b"),
  111. },
  112. },
  113. {
  114. "UpdateKey",
  115. []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1, 2, 3, 4}, Votes: &intOne, Meta: map[string]string{"a": "b"}, PrevAUMHash: fromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03")}},
  116. State{
  117. Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}},
  118. LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"),
  119. },
  120. State{
  121. LastAUMHash: hashFromHex("d55458a9c3ed6997439ba5a18b9b62d2c6e5e0c1bb4c61409e92a1281a3b458d"),
  122. Keys: []Key{{Kind: Key25519, Votes: 1, Meta: map[string]string{"a": "b"}, Public: []byte{1, 2, 3, 4}}},
  123. },
  124. },
  125. {
  126. "ChainedKeyUpdates",
  127. []AUM{
  128. {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}},
  129. {MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("f09bda3bb7cf6756ea9adc25770aede4b3ca8142949d6ef5ca0add29af912fd4")},
  130. },
  131. State{
  132. Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}},
  133. },
  134. State{
  135. Keys: []Key{{Kind: Key25519, Public: []byte{5, 6, 7, 8}}},
  136. LastAUMHash: hashFromHex("218165fe5f757304b9deaff4ac742890364f5f509e533c74e80e0ce35e44ee1d"),
  137. },
  138. },
  139. {
  140. "Checkpoint",
  141. []AUM{
  142. {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}},
  143. {MessageKind: AUMCheckpoint, State: &State{
  144. Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}},
  145. }, PrevAUMHash: fromHex("f09bda3bb7cf6756ea9adc25770aede4b3ca8142949d6ef5ca0add29af912fd4")},
  146. },
  147. State{DisablementSecrets: [][]byte{{1, 2, 3, 4}}},
  148. State{
  149. Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}},
  150. LastAUMHash: hashFromHex("57343671da5eea3cfb502954e976e8028bffd3540b50a043b2a65a8d8d8217d0"),
  151. },
  152. },
  153. }
  154. for _, tc := range tcs {
  155. t.Run(tc.Name, func(t *testing.T) {
  156. state := tc.Start
  157. for i := range tc.Updates {
  158. var err error
  159. // t.Logf("update[%d] start-state = %+v", i, state)
  160. state, err = state.applyVerifiedAUM(tc.Updates[i])
  161. if err != nil {
  162. t.Fatalf("Apply message[%d] failed: %v", i, err)
  163. }
  164. // t.Logf("update[%d] end-state = %+v", i, state)
  165. updateHash := tc.Updates[i].Hash()
  166. if got, want := *state.LastAUMHash, updateHash[:]; !bytes.Equal(got[:], want) {
  167. t.Errorf("expected state.LastAUMHash = %x (update %d), got %x", want, i, got)
  168. }
  169. }
  170. if diff := cmp.Diff(tc.End, state, cmpopts.EquateEmpty()); diff != "" {
  171. t.Errorf("output state differs (+got, -want):\n%s", diff)
  172. }
  173. })
  174. }
  175. }
  176. func TestApplyUpdateErrors(t *testing.T) {
  177. tooLargeVotes := uint(99999)
  178. tcs := []struct {
  179. Name string
  180. Updates []AUM
  181. Start State
  182. Error error
  183. }{
  184. {
  185. "AddKey exists",
  186. []AUM{{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}},
  187. State{Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}},
  188. errors.New("key already exists"),
  189. },
  190. {
  191. "RemoveKey notfound",
  192. []AUM{{MessageKind: AUMRemoveKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}},
  193. State{},
  194. ErrNoSuchKey,
  195. },
  196. {
  197. "UpdateKey notfound",
  198. []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1}}},
  199. State{},
  200. ErrNoSuchKey,
  201. },
  202. {
  203. "UpdateKey now fails validation",
  204. []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1}, Votes: &tooLargeVotes}},
  205. State{Keys: []Key{{Kind: Key25519, Public: []byte{1}}}},
  206. errors.New("updated key fails validation: excessive key weight: 99999 > 4096"),
  207. },
  208. {
  209. "Bad lastAUMHash",
  210. []AUM{
  211. {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}},
  212. {MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("1234")},
  213. },
  214. State{
  215. Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}},
  216. },
  217. errors.New("parent AUMHash mismatch"),
  218. },
  219. {
  220. "Bad StateID",
  221. []AUM{{MessageKind: AUMCheckpoint, State: &State{StateID1: 1}}},
  222. State{Keys: []Key{{Kind: Key25519, Public: []byte{1}}}, StateID1: 42},
  223. errors.New("checkpointed state has an incorrect stateID"),
  224. },
  225. }
  226. for _, tc := range tcs {
  227. t.Run(tc.Name, func(t *testing.T) {
  228. state := tc.Start
  229. for i := range tc.Updates {
  230. var err error
  231. // t.Logf("update[%d] start-state = %+v", i, state)
  232. state, err = state.applyVerifiedAUM(tc.Updates[i])
  233. if err != nil {
  234. if err.Error() != tc.Error.Error() {
  235. t.Errorf("state[%d].Err = %v, want %v", i, err, tc.Error)
  236. } else {
  237. return
  238. }
  239. }
  240. // t.Logf("update[%d] end-state = %+v", i, state)
  241. }
  242. t.Errorf("did not error, expected %v", tc.Error)
  243. })
  244. }
  245. }