value_test.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package rate
  4. import (
  5. "flag"
  6. "math"
  7. "reflect"
  8. "testing"
  9. "time"
  10. qt "github.com/frankban/quicktest"
  11. "github.com/google/go-cmp/cmp/cmpopts"
  12. "tailscale.com/tstime/mono"
  13. "tailscale.com/util/must"
  14. )
  15. const (
  16. min = mono.Time(time.Minute)
  17. sec = mono.Time(time.Second)
  18. msec = mono.Time(time.Millisecond)
  19. usec = mono.Time(time.Microsecond)
  20. nsec = mono.Time(time.Nanosecond)
  21. val = 1.0e6
  22. )
  23. var longNumericalStabilityTest = flag.Bool("long-numerical-stability-test", false, "")
  24. func TestValue(t *testing.T) {
  25. // When performing many small calculations, the accuracy of the
  26. // result can drift due to accumulated errors in the calculation.
  27. // Verify that the result is correct even with many small updates.
  28. // See https://en.wikipedia.org/wiki/Numerical_stability.
  29. t.Run("NumericalStability", func(t *testing.T) {
  30. step := usec
  31. if *longNumericalStabilityTest {
  32. step = nsec
  33. }
  34. numStep := int(sec / step)
  35. c := qt.New(t)
  36. var v Value
  37. var now mono.Time
  38. for range numStep {
  39. v.addNow(now, float64(step))
  40. now += step
  41. }
  42. c.Assert(v.rateNow(now), qt.CmpEquals(cmpopts.EquateApprox(1e-6, 0)), 1e9/2)
  43. })
  44. halfLives := []struct {
  45. name string
  46. period time.Duration
  47. }{
  48. {"½s", time.Second / 2},
  49. {"1s", time.Second},
  50. {"2s", 2 * time.Second},
  51. }
  52. for _, halfLife := range halfLives {
  53. t.Run(halfLife.name+"/SpikeDecay", func(t *testing.T) {
  54. testValueSpikeDecay(t, halfLife.period, false)
  55. })
  56. t.Run(halfLife.name+"/SpikeDecayAddZero", func(t *testing.T) {
  57. testValueSpikeDecay(t, halfLife.period, true)
  58. })
  59. t.Run(halfLife.name+"/HighThenLow", func(t *testing.T) {
  60. testValueHighThenLow(t, halfLife.period)
  61. })
  62. t.Run(halfLife.name+"/LowFrequency", func(t *testing.T) {
  63. testLowFrequency(t, halfLife.period)
  64. })
  65. }
  66. }
  67. // testValueSpikeDecay starts with a target rate and ensure that it
  68. // exponentially decays according to the half-life formula.
  69. func testValueSpikeDecay(t *testing.T, halfLife time.Duration, addZero bool) {
  70. c := qt.New(t)
  71. v := Value{HalfLife: halfLife}
  72. v.addNow(0, val*v.normalizedIntegral())
  73. var now mono.Time
  74. var prevRate float64
  75. step := 100 * msec
  76. wantHalfRate := float64(val)
  77. for now < 10*sec {
  78. // Adding zero for every time-step will repeatedly trigger the
  79. // computation to decay the value, which may cause the result
  80. // to become more numerically unstable.
  81. if addZero {
  82. v.addNow(now, 0)
  83. }
  84. currRate := v.rateNow(now)
  85. t.Logf("%0.1fs:\t%0.3f", time.Duration(now).Seconds(), currRate)
  86. // At every multiple of a half-life period,
  87. // the current rate should be half the value of what
  88. // it was at the last half-life period.
  89. if time.Duration(now)%halfLife == 0 {
  90. c.Assert(currRate, qt.CmpEquals(cmpopts.EquateApprox(1e-12, 0)), wantHalfRate)
  91. wantHalfRate = currRate / 2
  92. }
  93. // Without any newly added events,
  94. // the rate should be decaying over time.
  95. if now > 0 && prevRate < currRate {
  96. t.Errorf("%v: rate is not decaying: %0.1f < %0.1f", time.Duration(now), prevRate, currRate)
  97. }
  98. if currRate < 0 {
  99. t.Errorf("%v: rate too low: %0.1f < %0.1f", time.Duration(now), currRate, 0.0)
  100. }
  101. prevRate = currRate
  102. now += step
  103. }
  104. }
  105. // testValueHighThenLow targets a steady-state rate that is high,
  106. // then switches to a target steady-state rate that is low.
  107. func testValueHighThenLow(t *testing.T, halfLife time.Duration) {
  108. c := qt.New(t)
  109. v := Value{HalfLife: halfLife}
  110. var now mono.Time
  111. var prevRate float64
  112. var wantRate float64
  113. const step = 10 * msec
  114. const stepsPerSecond = int(sec / step)
  115. // Target a higher steady-state rate.
  116. wantRate = 2 * val
  117. wantHalfRate := float64(0.0)
  118. eventsPerStep := wantRate / float64(stepsPerSecond)
  119. for now < 10*sec {
  120. currRate := v.rateNow(now)
  121. v.addNow(now, eventsPerStep)
  122. t.Logf("%0.1fs:\t%0.3f", time.Duration(now).Seconds(), currRate)
  123. // At every multiple of a half-life period,
  124. // the current rate should be half-way more towards
  125. // the target rate relative to before.
  126. if time.Duration(now)%halfLife == 0 {
  127. c.Assert(currRate, qt.CmpEquals(cmpopts.EquateApprox(0.1, 0)), wantHalfRate)
  128. wantHalfRate += (wantRate - currRate) / 2
  129. }
  130. // Rate should approach wantRate from below,
  131. // but never exceed it.
  132. if now > 0 && prevRate > currRate {
  133. t.Errorf("%v: rate is not growing: %0.1f > %0.1f", time.Duration(now), prevRate, currRate)
  134. }
  135. if currRate > 1.01*wantRate {
  136. t.Errorf("%v: rate too high: %0.1f > %0.1f", time.Duration(now), currRate, wantRate)
  137. }
  138. prevRate = currRate
  139. now += step
  140. }
  141. c.Assert(prevRate, qt.CmpEquals(cmpopts.EquateApprox(0.05, 0)), wantRate)
  142. // Target a lower steady-state rate.
  143. wantRate = val / 3
  144. wantHalfRate = prevRate
  145. eventsPerStep = wantRate / float64(stepsPerSecond)
  146. for now < 20*sec {
  147. currRate := v.rateNow(now)
  148. v.addNow(now, eventsPerStep)
  149. t.Logf("%0.1fs:\t%0.3f", time.Duration(now).Seconds(), currRate)
  150. // At every multiple of a half-life period,
  151. // the current rate should be half-way more towards
  152. // the target rate relative to before.
  153. if time.Duration(now)%halfLife == 0 {
  154. c.Assert(currRate, qt.CmpEquals(cmpopts.EquateApprox(0.1, 0)), wantHalfRate)
  155. wantHalfRate += (wantRate - currRate) / 2
  156. }
  157. // Rate should approach wantRate from above,
  158. // but never exceed it.
  159. if now > 10*sec && prevRate < currRate {
  160. t.Errorf("%v: rate is not decaying: %0.1f < %0.1f", time.Duration(now), prevRate, currRate)
  161. }
  162. if currRate < 0.99*wantRate {
  163. t.Errorf("%v: rate too low: %0.1f < %0.1f", time.Duration(now), currRate, wantRate)
  164. }
  165. prevRate = currRate
  166. now += step
  167. }
  168. c.Assert(prevRate, qt.CmpEquals(cmpopts.EquateApprox(0.15, 0)), wantRate)
  169. }
  170. // testLowFrequency fires an event at a frequency much slower than
  171. // the specified half-life period. While the average rate over time
  172. // should be accurate, the standard deviation gets worse.
  173. func testLowFrequency(t *testing.T, halfLife time.Duration) {
  174. v := Value{HalfLife: halfLife}
  175. var now mono.Time
  176. var rates []float64
  177. for now < 20*min {
  178. if now%(10*sec) == 0 {
  179. v.addNow(now, 1) // 1 event every 10 seconds
  180. }
  181. now += 50 * msec
  182. rates = append(rates, v.rateNow(now))
  183. now += 50 * msec
  184. }
  185. mean, stddev := stats(rates)
  186. c := qt.New(t)
  187. c.Assert(mean, qt.CmpEquals(cmpopts.EquateApprox(0.001, 0)), 0.1)
  188. t.Logf("mean:%v stddev:%v", mean, stddev)
  189. }
  190. func stats(fs []float64) (mean, stddev float64) {
  191. for _, rate := range fs {
  192. mean += rate
  193. }
  194. mean /= float64(len(fs))
  195. for _, rate := range fs {
  196. stddev += (rate - mean) * (rate - mean)
  197. }
  198. stddev = math.Sqrt(stddev / float64(len(fs)))
  199. return mean, stddev
  200. }
  201. // BenchmarkValue benchmarks the cost of Value.Add,
  202. // which is called often and makes extensive use of floating-point math.
  203. func BenchmarkValue(b *testing.B) {
  204. b.ReportAllocs()
  205. v := Value{HalfLife: time.Second}
  206. for range b.N {
  207. v.Add(1)
  208. }
  209. }
  210. func TestValueMarshal(t *testing.T) {
  211. now := mono.Now()
  212. tests := []struct {
  213. val *Value
  214. str string
  215. }{
  216. {val: &Value{}, str: `{}`},
  217. {val: &Value{HalfLife: 5 * time.Minute}, str: `{"halfLife":"` + (5 * time.Minute).String() + `"}`},
  218. {val: &Value{value: 12345, updated: now}, str: `{"value":12345,"updated":` + string(must.Get(now.MarshalJSON())) + `}`},
  219. }
  220. for _, tt := range tests {
  221. str := string(must.Get(tt.val.MarshalJSON()))
  222. if str != tt.str {
  223. t.Errorf("string mismatch: got %v, want %v", str, tt.str)
  224. }
  225. var val Value
  226. must.Do(val.UnmarshalJSON([]byte(str)))
  227. if !reflect.DeepEqual(&val, tt.val) {
  228. t.Errorf("value mismatch: %+v, want %+v", &val, tt.val)
  229. }
  230. }
  231. }