channel_affinity_template_test.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. package service
  2. import (
  3. "fmt"
  4. "net/http"
  5. "net/http/httptest"
  6. "strings"
  7. "testing"
  8. "time"
  9. relaycommon "github.com/QuantumNous/new-api/relay/common"
  10. "github.com/QuantumNous/new-api/setting/operation_setting"
  11. "github.com/gin-gonic/gin"
  12. "github.com/stretchr/testify/require"
  13. )
  14. func buildChannelAffinityTemplateContextForTest(meta channelAffinityMeta) *gin.Context {
  15. rec := httptest.NewRecorder()
  16. ctx, _ := gin.CreateTestContext(rec)
  17. setChannelAffinityContext(ctx, meta)
  18. return ctx
  19. }
  20. func TestApplyChannelAffinityOverrideTemplate_NoTemplate(t *testing.T) {
  21. ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
  22. RuleName: "rule-no-template",
  23. })
  24. base := map[string]interface{}{
  25. "temperature": 0.7,
  26. }
  27. merged, applied := ApplyChannelAffinityOverrideTemplate(ctx, base)
  28. require.False(t, applied)
  29. require.Equal(t, base, merged)
  30. }
  31. func TestApplyChannelAffinityOverrideTemplate_MergeTemplate(t *testing.T) {
  32. ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
  33. RuleName: "rule-with-template",
  34. ParamTemplate: map[string]interface{}{
  35. "temperature": 0.2,
  36. "top_p": 0.95,
  37. },
  38. UsingGroup: "default",
  39. ModelName: "gpt-4.1",
  40. RequestPath: "/v1/responses",
  41. KeySourceType: "gjson",
  42. KeySourcePath: "prompt_cache_key",
  43. KeyHint: "abcd...wxyz",
  44. KeyFingerprint: "abcd1234",
  45. })
  46. base := map[string]interface{}{
  47. "temperature": 0.7,
  48. "max_tokens": 2000,
  49. }
  50. merged, applied := ApplyChannelAffinityOverrideTemplate(ctx, base)
  51. require.True(t, applied)
  52. require.Equal(t, 0.7, merged["temperature"])
  53. require.Equal(t, 0.95, merged["top_p"])
  54. require.Equal(t, 2000, merged["max_tokens"])
  55. require.Equal(t, 0.7, base["temperature"])
  56. anyInfo, ok := ctx.Get(ginKeyChannelAffinityLogInfo)
  57. require.True(t, ok)
  58. info, ok := anyInfo.(map[string]interface{})
  59. require.True(t, ok)
  60. overrideInfoAny, ok := info["override_template"]
  61. require.True(t, ok)
  62. overrideInfo, ok := overrideInfoAny.(map[string]interface{})
  63. require.True(t, ok)
  64. require.Equal(t, true, overrideInfo["applied"])
  65. require.Equal(t, "rule-with-template", overrideInfo["rule_name"])
  66. require.EqualValues(t, 2, overrideInfo["param_override_keys"])
  67. }
  68. func TestApplyChannelAffinityOverrideTemplate_MergeOperations(t *testing.T) {
  69. ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
  70. RuleName: "rule-with-ops-template",
  71. ParamTemplate: map[string]interface{}{
  72. "operations": []map[string]interface{}{
  73. {
  74. "mode": "pass_headers",
  75. "value": []string{"Originator"},
  76. },
  77. },
  78. },
  79. })
  80. base := map[string]interface{}{
  81. "temperature": 0.7,
  82. "operations": []map[string]interface{}{
  83. {
  84. "path": "model",
  85. "mode": "trim_prefix",
  86. "value": "openai/",
  87. },
  88. },
  89. }
  90. merged, applied := ApplyChannelAffinityOverrideTemplate(ctx, base)
  91. require.True(t, applied)
  92. require.Equal(t, 0.7, merged["temperature"])
  93. opsAny, ok := merged["operations"]
  94. require.True(t, ok)
  95. ops, ok := opsAny.([]interface{})
  96. require.True(t, ok)
  97. require.Len(t, ops, 2)
  98. firstOp, ok := ops[0].(map[string]interface{})
  99. require.True(t, ok)
  100. require.Equal(t, "pass_headers", firstOp["mode"])
  101. secondOp, ok := ops[1].(map[string]interface{})
  102. require.True(t, ok)
  103. require.Equal(t, "trim_prefix", secondOp["mode"])
  104. }
  105. func TestShouldSkipRetryAfterChannelAffinityFailure(t *testing.T) {
  106. tests := []struct {
  107. name string
  108. ctx func() *gin.Context
  109. want bool
  110. }{
  111. {
  112. name: "nil context",
  113. ctx: func() *gin.Context {
  114. return nil
  115. },
  116. want: false,
  117. },
  118. {
  119. name: "explicit skip retry flag in context",
  120. ctx: func() *gin.Context {
  121. ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
  122. RuleName: "rule-explicit-flag",
  123. SkipRetry: false,
  124. UsingGroup: "default",
  125. ModelName: "gpt-5",
  126. })
  127. ctx.Set(ginKeyChannelAffinitySkipRetry, true)
  128. return ctx
  129. },
  130. want: true,
  131. },
  132. {
  133. name: "fallback to matched rule meta",
  134. ctx: func() *gin.Context {
  135. return buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
  136. RuleName: "rule-skip-retry",
  137. SkipRetry: true,
  138. UsingGroup: "default",
  139. ModelName: "gpt-5",
  140. })
  141. },
  142. want: true,
  143. },
  144. {
  145. name: "no flag and no skip retry meta",
  146. ctx: func() *gin.Context {
  147. return buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
  148. RuleName: "rule-no-skip-retry",
  149. SkipRetry: false,
  150. UsingGroup: "default",
  151. ModelName: "gpt-5",
  152. })
  153. },
  154. want: false,
  155. },
  156. }
  157. for _, tt := range tests {
  158. t.Run(tt.name, func(t *testing.T) {
  159. require.Equal(t, tt.want, ShouldSkipRetryAfterChannelAffinityFailure(tt.ctx()))
  160. })
  161. }
  162. }
  163. func TestChannelAffinityHitCodexTemplatePassHeadersEffective(t *testing.T) {
  164. gin.SetMode(gin.TestMode)
  165. setting := operation_setting.GetChannelAffinitySetting()
  166. require.NotNil(t, setting)
  167. var codexRule *operation_setting.ChannelAffinityRule
  168. for i := range setting.Rules {
  169. rule := &setting.Rules[i]
  170. if strings.EqualFold(strings.TrimSpace(rule.Name), "codex cli trace") {
  171. codexRule = rule
  172. break
  173. }
  174. }
  175. require.NotNil(t, codexRule)
  176. affinityValue := fmt.Sprintf("pc-hit-%d", time.Now().UnixNano())
  177. cacheKeySuffix := buildChannelAffinityCacheKeySuffix(*codexRule, "gpt-5", "default", affinityValue)
  178. cache := getChannelAffinityCache()
  179. require.NoError(t, cache.SetWithTTL(cacheKeySuffix, 9527, time.Minute))
  180. t.Cleanup(func() {
  181. _, _ = cache.DeleteMany([]string{cacheKeySuffix})
  182. })
  183. rec := httptest.NewRecorder()
  184. ctx, _ := gin.CreateTestContext(rec)
  185. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(fmt.Sprintf(`{"prompt_cache_key":"%s"}`, affinityValue)))
  186. ctx.Request.Header.Set("Content-Type", "application/json")
  187. channelID, found := GetPreferredChannelByAffinity(ctx, "gpt-5", "default")
  188. require.True(t, found)
  189. require.Equal(t, 9527, channelID)
  190. baseOverride := map[string]interface{}{
  191. "temperature": 0.2,
  192. }
  193. mergedOverride, applied := ApplyChannelAffinityOverrideTemplate(ctx, baseOverride)
  194. require.True(t, applied)
  195. require.Equal(t, 0.2, mergedOverride["temperature"])
  196. info := &relaycommon.RelayInfo{
  197. RequestHeaders: map[string]string{
  198. "Originator": "Codex CLI",
  199. "Session_id": "sess-123",
  200. "User-Agent": "codex-cli-test",
  201. },
  202. ChannelMeta: &relaycommon.ChannelMeta{
  203. ParamOverride: mergedOverride,
  204. HeadersOverride: map[string]interface{}{
  205. "X-Static": "legacy-static",
  206. },
  207. },
  208. }
  209. _, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-5"}`), info)
  210. require.NoError(t, err)
  211. require.True(t, info.UseRuntimeHeadersOverride)
  212. require.Equal(t, "legacy-static", info.RuntimeHeadersOverride["x-static"])
  213. require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["originator"])
  214. require.Equal(t, "sess-123", info.RuntimeHeadersOverride["session_id"])
  215. require.Equal(t, "codex-cli-test", info.RuntimeHeadersOverride["user-agent"])
  216. _, exists := info.RuntimeHeadersOverride["x-codex-beta-features"]
  217. require.False(t, exists)
  218. _, exists = info.RuntimeHeadersOverride["x-codex-turn-metadata"]
  219. require.False(t, exists)
  220. }