key_test.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package ctxkey
  4. import (
  5. "context"
  6. "fmt"
  7. "io"
  8. "regexp"
  9. "testing"
  10. "time"
  11. qt "github.com/frankban/quicktest"
  12. )
  13. func TestKey(t *testing.T) {
  14. c := qt.New(t)
  15. ctx := context.Background()
  16. // Test keys with the same name as being distinct.
  17. k1 := New("same.Name", "")
  18. c.Assert(k1.String(), qt.Equals, "same.Name")
  19. k2 := New("same.Name", "")
  20. c.Assert(k2.String(), qt.Equals, "same.Name")
  21. c.Assert(k1 == k2, qt.Equals, false)
  22. ctx = k1.WithValue(ctx, "hello")
  23. c.Assert(k1.Has(ctx), qt.Equals, true)
  24. c.Assert(k1.Value(ctx), qt.Equals, "hello")
  25. c.Assert(k2.Has(ctx), qt.Equals, false)
  26. c.Assert(k2.Value(ctx), qt.Equals, "")
  27. ctx = k2.WithValue(ctx, "goodbye")
  28. c.Assert(k1.Has(ctx), qt.Equals, true)
  29. c.Assert(k1.Value(ctx), qt.Equals, "hello")
  30. c.Assert(k2.Has(ctx), qt.Equals, true)
  31. c.Assert(k2.Value(ctx), qt.Equals, "goodbye")
  32. // Test default value.
  33. k3 := New("mapreduce.Timeout", time.Hour)
  34. c.Assert(k3.Has(ctx), qt.Equals, false)
  35. c.Assert(k3.Value(ctx), qt.Equals, time.Hour)
  36. ctx = k3.WithValue(ctx, time.Minute)
  37. c.Assert(k3.Has(ctx), qt.Equals, true)
  38. c.Assert(k3.Value(ctx), qt.Equals, time.Minute)
  39. // Test incomparable value.
  40. k4 := New("slice", []int(nil))
  41. c.Assert(k4.Has(ctx), qt.Equals, false)
  42. c.Assert(k4.Value(ctx), qt.DeepEquals, []int(nil))
  43. ctx = k4.WithValue(ctx, []int{1, 2, 3})
  44. c.Assert(k4.Has(ctx), qt.Equals, true)
  45. c.Assert(k4.Value(ctx), qt.DeepEquals, []int{1, 2, 3})
  46. // Accessors should be allocation free.
  47. c.Assert(testing.AllocsPerRun(100, func() {
  48. k1.Value(ctx)
  49. k1.Has(ctx)
  50. k1.ValueOk(ctx)
  51. }), qt.Equals, 0.0)
  52. // Test keys that are created without New.
  53. var k5 Key[string]
  54. c.Assert(k5.String(), qt.Equals, "string")
  55. c.Assert(k1 == k5, qt.Equals, false) // should be different from key created by New
  56. c.Assert(k5.Has(ctx), qt.Equals, false)
  57. ctx = k5.WithValue(ctx, "fizz")
  58. c.Assert(k5.Value(ctx), qt.Equals, "fizz")
  59. var k6 Key[string]
  60. c.Assert(k6.String(), qt.Equals, "string")
  61. c.Assert(k5 == k6, qt.Equals, true)
  62. c.Assert(k6.Has(ctx), qt.Equals, true)
  63. ctx = k6.WithValue(ctx, "fizz")
  64. // Test interface value types.
  65. var k7 Key[any]
  66. c.Assert(k7.Has(ctx), qt.Equals, false)
  67. ctx = k7.WithValue(ctx, "whatever")
  68. c.Assert(k7.Value(ctx), qt.DeepEquals, "whatever")
  69. ctx = k7.WithValue(ctx, []int{1, 2, 3})
  70. c.Assert(k7.Value(ctx), qt.DeepEquals, []int{1, 2, 3})
  71. ctx = k7.WithValue(ctx, nil)
  72. c.Assert(k7.Has(ctx), qt.Equals, true)
  73. c.Assert(k7.Value(ctx), qt.DeepEquals, nil)
  74. k8 := New[error]("error", io.EOF)
  75. c.Assert(k8.Has(ctx), qt.Equals, false)
  76. c.Assert(k8.Value(ctx), qt.Equals, io.EOF)
  77. ctx = k8.WithValue(ctx, nil)
  78. c.Assert(k8.Value(ctx), qt.Equals, nil)
  79. c.Assert(k8.Has(ctx), qt.Equals, true)
  80. err := fmt.Errorf("read error: %w", io.ErrUnexpectedEOF)
  81. ctx = k8.WithValue(ctx, err)
  82. c.Assert(k8.Value(ctx), qt.Equals, err)
  83. c.Assert(k8.Has(ctx), qt.Equals, true)
  84. }
  85. func TestStringer(t *testing.T) {
  86. t.SkipNow() // TODO(https://go.dev/cl/555697): Enable this after fix is merged upstream.
  87. c := qt.New(t)
  88. ctx := context.Background()
  89. c.Assert(fmt.Sprint(New("foo.Bar", "").WithValue(ctx, "baz")), qt.Matches, regexp.MustCompile("foo.Bar.*baz"))
  90. c.Assert(fmt.Sprint(New("", []int{}).WithValue(ctx, []int{1, 2, 3})), qt.Matches, regexp.MustCompile(fmt.Sprintf("%[1]T.*%[1]v", []int{1, 2, 3})))
  91. c.Assert(fmt.Sprint(New("", 0).WithValue(ctx, 5)), qt.Matches, regexp.MustCompile("int.*5"))
  92. c.Assert(fmt.Sprint(Key[time.Duration]{}.WithValue(ctx, time.Hour)), qt.Matches, regexp.MustCompile(fmt.Sprintf("%[1]T.*%[1]v", time.Hour)))
  93. }