store_test.go 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package ipn
  4. import (
  5. "bytes"
  6. "iter"
  7. "sync"
  8. "testing"
  9. "tailscale.com/util/mak"
  10. )
  11. type memStore struct {
  12. mu sync.Mutex
  13. writes int
  14. m map[StateKey][]byte
  15. }
  16. func (s *memStore) ReadState(k StateKey) ([]byte, error) {
  17. s.mu.Lock()
  18. defer s.mu.Unlock()
  19. return bytes.Clone(s.m[k]), nil
  20. }
  21. func (s *memStore) WriteState(k StateKey, v []byte) error {
  22. s.mu.Lock()
  23. defer s.mu.Unlock()
  24. mak.Set(&s.m, k, bytes.Clone(v))
  25. s.writes++
  26. return nil
  27. }
  28. func (s *memStore) All() iter.Seq2[StateKey, []byte] {
  29. return func(yield func(StateKey, []byte) bool) {
  30. s.mu.Lock()
  31. defer s.mu.Unlock()
  32. for k, v := range s.m {
  33. if !yield(k, v) {
  34. break
  35. }
  36. }
  37. }
  38. }
  39. func TestWriteState(t *testing.T) {
  40. var ss StateStore = new(memStore)
  41. WriteState(ss, "foo", []byte("bar"))
  42. WriteState(ss, "foo", []byte("bar"))
  43. got, err := ss.ReadState("foo")
  44. if err != nil {
  45. t.Fatal(err)
  46. }
  47. if want := []byte("bar"); !bytes.Equal(got, want) {
  48. t.Errorf("got %q; want %q", got, want)
  49. }
  50. if got, want := ss.(*memStore).writes, 1; got != want {
  51. t.Errorf("got %d writes; want %d", got, want)
  52. }
  53. }