mem_test.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. package reqlimit_test
  2. import (
  3. "fmt"
  4. "sync"
  5. "testing"
  6. "time"
  7. "github.com/labring/aiproxy/core/common/reqlimit"
  8. )
  9. func TestNewInMemoryRateLimiter(t *testing.T) {
  10. rl := reqlimit.NewInMemoryRecord()
  11. if rl == nil {
  12. t.Fatal("NewInMemoryRateLimiter should return a non-nil instance")
  13. }
  14. }
  15. func TestPushRequestBasic(t *testing.T) {
  16. rl := reqlimit.NewInMemoryRecord()
  17. normalCount, overCount, secondCount := rl.PushRequest(10, 60*time.Second, 1, "group1", "model1")
  18. if normalCount != 1 {
  19. t.Errorf("Expected normalCount to be 1, got %d", normalCount)
  20. }
  21. if overCount != 0 {
  22. t.Errorf("Expected overCount to be 0, got %d", overCount)
  23. }
  24. if secondCount != 1 {
  25. t.Errorf("Expected secondCount to be 1, got %d", secondCount)
  26. }
  27. }
  28. func TestPushRequestRateLimit(t *testing.T) {
  29. rl := reqlimit.NewInMemoryRecord()
  30. maxReq := int64(2)
  31. duration := 60 * time.Second
  32. for i := range 4 {
  33. normalCount, overCount, _ := rl.PushRequest(maxReq, duration, 1, "group1", "model1")
  34. switch {
  35. case i < 2:
  36. if normalCount != int64(i+1) {
  37. t.Errorf("Request %d: expected normalCount %d, got %d", i+1, i+1, normalCount)
  38. }
  39. if overCount != 0 {
  40. t.Errorf("Request %d: expected overCount 0, got %d", i+1, overCount)
  41. }
  42. case i == 2:
  43. if normalCount != 3 {
  44. t.Errorf("Request %d: expected normalCount 3, got %d", i+1, normalCount)
  45. }
  46. if overCount != 0 {
  47. t.Errorf("Request %d: expected overCount 0, got %d", i+1, overCount)
  48. }
  49. case i == 3:
  50. if normalCount != 3 {
  51. t.Errorf("Request %d: expected normalCount 3, got %d", i+1, normalCount)
  52. }
  53. if overCount != 1 {
  54. t.Errorf("Request %d: expected overCount 1, got %d", i+1, overCount)
  55. }
  56. }
  57. }
  58. }
  59. func TestPushRequestUnlimited(t *testing.T) {
  60. rl := reqlimit.NewInMemoryRecord()
  61. for i := range 5 {
  62. normalCount, overCount, _ := rl.PushRequest(0, 60*time.Second, 1, "group1", "model1")
  63. if normalCount != int64(i+1) {
  64. t.Errorf("Request %d: expected normalCount %d, got %d", i+1, i+1, normalCount)
  65. }
  66. if overCount != 0 {
  67. t.Errorf("Request %d: expected overCount 0, got %d", i+1, overCount)
  68. }
  69. }
  70. }
  71. func TestGetRequest(t *testing.T) {
  72. rl := reqlimit.NewInMemoryRecord()
  73. rl.PushRequest(10, 60*time.Second, 1, "group1", "model1")
  74. rl.PushRequest(10, 60*time.Second, 1, "group1", "model2")
  75. rl.PushRequest(10, 60*time.Second, 1, "group2", "model1")
  76. totalCount, secondCount := rl.GetRequest(60*time.Second, "group1", "model1")
  77. if totalCount != 1 {
  78. t.Errorf("Expected totalCount 1, got %d", totalCount)
  79. }
  80. if secondCount != 1 {
  81. t.Errorf("Expected secondCount 1, got %d", secondCount)
  82. }
  83. totalCount, _ = rl.GetRequest(60*time.Second, "*", "*")
  84. if totalCount != 3 {
  85. t.Errorf("Expected totalCount 3 for wildcard query, got %d", totalCount)
  86. }
  87. totalCount, _ = rl.GetRequest(60*time.Second, "group1", "*")
  88. if totalCount != 2 {
  89. t.Errorf("Expected totalCount 2 for group1 wildcard, got %d", totalCount)
  90. }
  91. }
  92. func TestMultipleGroupsAndModels(t *testing.T) {
  93. rl := reqlimit.NewInMemoryRecord()
  94. groups := []string{"group1", "group2", "group3"}
  95. models := []string{"model1", "model2"}
  96. for _, group := range groups {
  97. for _, model := range models {
  98. rl.PushRequest(10, 60*time.Second, 1, group, model)
  99. }
  100. }
  101. totalCount, _ := rl.GetRequest(60*time.Second, "*", "*")
  102. expected := len(groups) * len(models)
  103. if totalCount != int64(expected) {
  104. t.Errorf("Expected totalCount %d, got %d", expected, totalCount)
  105. }
  106. }
  107. func TestTimeWindowCleanup(t *testing.T) {
  108. rl := reqlimit.NewInMemoryRecord()
  109. rl.PushRequest(10, 2*time.Second, 1, "group1", "model1")
  110. totalCount, _ := rl.GetRequest(2*time.Second, "group1", "model1")
  111. if totalCount != 1 {
  112. t.Errorf("Expected totalCount 1, got %d", totalCount)
  113. }
  114. time.Sleep(3 * time.Second)
  115. totalCount, _ = rl.GetRequest(2*time.Second, "group1", "model1")
  116. if totalCount != 0 {
  117. t.Errorf("Expected totalCount 0 after cleanup, got %d", totalCount)
  118. }
  119. }
  120. func TestConcurrentAccess(t *testing.T) {
  121. rl := reqlimit.NewInMemoryRecord()
  122. const (
  123. numGoroutines = 100
  124. requestsPerGoroutine = 10
  125. )
  126. var wg sync.WaitGroup
  127. wg.Add(numGoroutines)
  128. for i := range numGoroutines {
  129. go func(_ int) {
  130. defer wg.Done()
  131. for range requestsPerGoroutine {
  132. rl.PushRequest(0, 60*time.Second, 1, "group1", "model1")
  133. }
  134. }(i)
  135. }
  136. wg.Wait()
  137. totalCount, _ := rl.GetRequest(60*time.Second, "group1", "model1")
  138. expected := int64(numGoroutines * requestsPerGoroutine)
  139. if totalCount != expected {
  140. t.Errorf("Expected totalCount %d, got %d", expected, totalCount)
  141. }
  142. }
  143. func TestConcurrentDifferentKeys(t *testing.T) {
  144. rl := reqlimit.NewInMemoryRecord()
  145. const numGoroutines = 50
  146. var wg sync.WaitGroup
  147. wg.Add(numGoroutines)
  148. for i := range numGoroutines {
  149. go func(id int) {
  150. defer wg.Done()
  151. group := fmt.Sprintf("group%d", id%5)
  152. model := fmt.Sprintf("model%d", id%3)
  153. rl.PushRequest(10, 60*time.Second, 1, group, model)
  154. }(i)
  155. }
  156. wg.Wait()
  157. // 验证总数
  158. totalCount, _ := rl.GetRequest(60*time.Second, "*", "*")
  159. if totalCount != int64(numGoroutines) {
  160. t.Errorf("Expected totalCount %d, got %d", numGoroutines, totalCount)
  161. }
  162. }
  163. func TestRateLimitWithOverflow(t *testing.T) {
  164. rl := reqlimit.NewInMemoryRecord()
  165. maxReq := 5
  166. duration := 60 * time.Second
  167. for i := range 10 {
  168. normalCount, overCount, _ := rl.PushRequest(int64(maxReq), duration, 1, "group1", "model1")
  169. if i < maxReq {
  170. if normalCount != int64(i+1) || overCount != 0 {
  171. t.Errorf("Request %d: expected normal=%d, over=0, got normal=%d, over=%d",
  172. i+1, i+1, normalCount, overCount)
  173. }
  174. } else {
  175. expectedOver := int64(i - maxReq)
  176. if normalCount != int64(maxReq+1) || overCount != expectedOver {
  177. t.Errorf("Request %d: expected normal=5, over=%d, got normal=%d, over=%d",
  178. i+1, expectedOver, normalCount, overCount)
  179. }
  180. }
  181. }
  182. }
  183. func TestEmptyQueries(t *testing.T) {
  184. rl := reqlimit.NewInMemoryRecord()
  185. totalCount, secondCount := rl.GetRequest(60*time.Second, "*", "*")
  186. if totalCount != 0 || secondCount != 0 {
  187. t.Errorf("Expected empty results, got total=%d, second=%d", totalCount, secondCount)
  188. }
  189. totalCount, secondCount = rl.GetRequest(60*time.Second, "nonexistent", "model")
  190. if totalCount != 0 || secondCount != 0 {
  191. t.Errorf(
  192. "Expected empty results for nonexistent key, got total=%d, second=%d",
  193. totalCount,
  194. secondCount,
  195. )
  196. }
  197. }
  198. func BenchmarkPushRequest(b *testing.B) {
  199. rl := reqlimit.NewInMemoryRecord()
  200. b.ResetTimer()
  201. b.RunParallel(func(pb *testing.PB) {
  202. i := 0
  203. for pb.Next() {
  204. group := fmt.Sprintf("group%d", i%10)
  205. model := fmt.Sprintf("model%d", i%5)
  206. rl.PushRequest(100, 60*time.Second, 1, group, model)
  207. i++
  208. }
  209. })
  210. }
  211. func BenchmarkGetRequest(b *testing.B) {
  212. rl := reqlimit.NewInMemoryRecord()
  213. for i := range 100 {
  214. group := fmt.Sprintf("group%d", i%10)
  215. model := fmt.Sprintf("model%d", i%5)
  216. rl.PushRequest(100, 60*time.Second, 1, group, model)
  217. }
  218. b.ResetTimer()
  219. b.RunParallel(func(pb *testing.PB) {
  220. i := 0
  221. for pb.Next() {
  222. group := fmt.Sprintf("group%d", i%10)
  223. model := fmt.Sprintf("model%d", i%5)
  224. rl.GetRequest(60*time.Second, group, model)
  225. i++
  226. }
  227. })
  228. }