patch_test.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755
  1. package patch_test
  2. import (
  3. "testing"
  4. "github.com/bytedance/sonic"
  5. "github.com/labring/aiproxy/core/relay/meta"
  6. "github.com/labring/aiproxy/core/relay/plugin/patch"
  7. "github.com/stretchr/testify/assert"
  8. "github.com/stretchr/testify/require"
  9. )
  10. func TestNew(t *testing.T) {
  11. plugin := patch.NewPatchPlugin()
  12. assert.NotNil(t, plugin)
  13. assert.True(t, len(patch.DefaultPredefinedPatches) > 0)
  14. }
  15. func TestApplyPatches_DeepSeekMaxTokensLimit(t *testing.T) {
  16. plugin := patch.NewPatchPlugin()
  17. config := &patch.Config{}
  18. testCases := []struct {
  19. name string
  20. input map[string]any
  21. actualModel string
  22. expectedMaxTokens int
  23. shouldModify bool
  24. }{
  25. {
  26. name: "deepseek model with high max_tokens",
  27. input: map[string]any{
  28. "model": "deepseek-chat",
  29. "max_tokens": 20000,
  30. },
  31. actualModel: "deepseek-chat",
  32. expectedMaxTokens: 8192,
  33. shouldModify: true,
  34. },
  35. {
  36. name: "deepseek model with high max_tokens",
  37. input: map[string]any{
  38. "model": "deepseek-v3",
  39. "max_tokens": 20000,
  40. },
  41. actualModel: "deepseek-v3",
  42. expectedMaxTokens: 16384,
  43. shouldModify: true,
  44. },
  45. {
  46. name: "deepseek model with low max_tokens",
  47. input: map[string]any{
  48. "model": "deepseek-chat",
  49. "max_tokens": 8000,
  50. },
  51. actualModel: "deepseek-chat",
  52. expectedMaxTokens: 8000,
  53. shouldModify: false,
  54. },
  55. {
  56. name: "non-deepseek model",
  57. input: map[string]any{
  58. "model": "gpt-4",
  59. "max_tokens": 20000,
  60. },
  61. actualModel: "gpt-4",
  62. expectedMaxTokens: 20000,
  63. shouldModify: false,
  64. },
  65. }
  66. for _, tc := range testCases {
  67. t.Run(tc.name, func(t *testing.T) {
  68. inputBytes, err := sonic.Marshal(tc.input)
  69. require.NoError(t, err)
  70. meta := &meta.Meta{ActualModel: tc.actualModel}
  71. outputBytes, modified, err := plugin.ApplyPatches(inputBytes, meta, config)
  72. require.NoError(t, err)
  73. assert.Equal(t, tc.shouldModify, modified)
  74. var output map[string]any
  75. err = sonic.Unmarshal(outputBytes, &output)
  76. require.NoError(t, err)
  77. if maxTokens, exists := output["max_tokens"]; exists {
  78. maxTokensFloat, ok := maxTokens.(float64)
  79. require.True(t, ok, "max_tokens should be float64")
  80. assert.Equal(t, tc.expectedMaxTokens, int(maxTokensFloat))
  81. }
  82. })
  83. }
  84. }
  85. func TestApplyPatches_GPT5MaxTokensConversion(t *testing.T) {
  86. plugin := patch.NewPatchPlugin()
  87. config := &patch.Config{}
  88. testCases := []struct {
  89. name string
  90. input map[string]any
  91. actualModel string
  92. expectedMaxCompletionTokens int
  93. shouldHaveMaxTokens bool
  94. shouldModify bool
  95. shouldHaveMaxCompletionTokens bool
  96. }{
  97. {
  98. name: "gpt-5 model with max_tokens",
  99. input: map[string]any{
  100. "model": "gpt-5",
  101. "max_tokens": 4000,
  102. "temperature": 0.7,
  103. },
  104. actualModel: "gpt-5",
  105. expectedMaxCompletionTokens: 4000,
  106. shouldHaveMaxTokens: false,
  107. shouldModify: true,
  108. shouldHaveMaxCompletionTokens: true,
  109. },
  110. {
  111. name: "gpt-5 model without max_tokens",
  112. input: map[string]any{
  113. "model": "gpt-5",
  114. "temperature": 0.7,
  115. },
  116. actualModel: "gpt-5",
  117. shouldHaveMaxTokens: false,
  118. shouldModify: true,
  119. shouldHaveMaxCompletionTokens: false,
  120. },
  121. {
  122. name: "gpt-4 model with max_tokens",
  123. input: map[string]any{
  124. "model": "gpt-4",
  125. "max_tokens": 4000,
  126. },
  127. actualModel: "gpt-4",
  128. shouldHaveMaxTokens: true,
  129. shouldModify: false,
  130. shouldHaveMaxCompletionTokens: false,
  131. },
  132. }
  133. for _, tc := range testCases {
  134. t.Run(tc.name, func(t *testing.T) {
  135. inputBytes, err := sonic.Marshal(tc.input)
  136. require.NoError(t, err)
  137. meta := &meta.Meta{ActualModel: tc.actualModel}
  138. outputBytes, modified, err := plugin.ApplyPatches(inputBytes, meta, config)
  139. require.NoError(t, err)
  140. assert.Equal(t, tc.shouldModify, modified)
  141. var output map[string]any
  142. err = sonic.Unmarshal(outputBytes, &output)
  143. require.NoError(t, err)
  144. if tc.shouldHaveMaxCompletionTokens {
  145. maxCompletionTokens, ok := output["max_completion_tokens"].(float64)
  146. require.True(t, ok, "max_completion_tokens should be float64")
  147. assert.Equal(
  148. t,
  149. tc.expectedMaxCompletionTokens,
  150. int(maxCompletionTokens),
  151. )
  152. } else {
  153. _, hasMaxCompletionTokens := output["max_completion_tokens"]
  154. assert.False(t, hasMaxCompletionTokens, "max_completion_tokens should not exist")
  155. }
  156. _, hasMaxTokens := output["max_tokens"]
  157. assert.Equal(t, tc.shouldHaveMaxTokens, hasMaxTokens)
  158. })
  159. }
  160. }
  161. func TestCustomUserPatches(t *testing.T) {
  162. plugin := patch.NewPatchPlugin()
  163. config := &patch.Config{
  164. UserPatches: []patch.PatchRule{
  165. {
  166. Name: "test_temperature_limit",
  167. Conditions: []patch.PatchCondition{
  168. {
  169. Key: "model",
  170. Operator: patch.OperatorContains,
  171. Value: "test",
  172. },
  173. },
  174. Operations: []patch.PatchOperation{
  175. {
  176. Op: patch.OpLimit,
  177. Key: "temperature",
  178. Value: 1.0,
  179. },
  180. },
  181. },
  182. {
  183. Name: "add_default_top_p",
  184. Conditions: []patch.PatchCondition{
  185. {
  186. Key: "top_p",
  187. Operator: patch.OperatorNotExists,
  188. Value: "",
  189. },
  190. },
  191. Operations: []patch.PatchOperation{
  192. {
  193. Op: patch.OpAdd,
  194. Key: "top_p",
  195. Value: 0.9,
  196. },
  197. },
  198. },
  199. },
  200. }
  201. // Test temperature limit
  202. input := map[string]any{
  203. "model": "test-model",
  204. "temperature": 1.5,
  205. }
  206. inputBytes, err := sonic.Marshal(input)
  207. require.NoError(t, err)
  208. meta := &meta.Meta{ActualModel: "test-model"}
  209. outputBytes, modified, err := plugin.ApplyPatches(inputBytes, meta, config)
  210. require.NoError(t, err)
  211. assert.True(t, modified)
  212. var output map[string]any
  213. err = sonic.Unmarshal(outputBytes, &output)
  214. require.NoError(t, err)
  215. assert.Equal(t, 1.0, output["temperature"])
  216. assert.Equal(t, 0.9, output["top_p"]) // Should be added
  217. }
  218. func TestNestedFieldOperations(t *testing.T) {
  219. plugin := patch.NewPatchPlugin()
  220. config := &patch.Config{
  221. UserPatches: []patch.PatchRule{
  222. {
  223. Name: "nested_operations",
  224. Operations: []patch.PatchOperation{
  225. {
  226. Op: patch.OpSet,
  227. Key: "parameters.max_tokens",
  228. Value: 2000,
  229. },
  230. {
  231. Op: patch.OpSet,
  232. Key: "metadata.version",
  233. Value: "1.0",
  234. },
  235. },
  236. },
  237. },
  238. }
  239. input := map[string]any{
  240. "model": "test",
  241. "parameters": map[string]any{
  242. "temperature": 0.7,
  243. },
  244. }
  245. inputBytes, err := sonic.Marshal(input)
  246. require.NoError(t, err)
  247. meta := &meta.Meta{ActualModel: "test"}
  248. outputBytes, modified, err := plugin.ApplyPatches(inputBytes, meta, config)
  249. require.NoError(t, err)
  250. assert.True(t, modified)
  251. var output map[string]any
  252. err = sonic.Unmarshal(outputBytes, &output)
  253. require.NoError(t, err)
  254. // Check nested field access
  255. params, ok := output["parameters"].(map[string]any)
  256. require.True(t, ok)
  257. maxTokens, ok := params["max_tokens"].(float64)
  258. require.True(t, ok, "max_tokens should be float64")
  259. assert.Equal(t, 2000, int(maxTokens))
  260. assert.Equal(t, 0.7, params["temperature"])
  261. metadata, ok := output["metadata"].(map[string]any)
  262. require.True(t, ok)
  263. assert.Equal(t, "1.0", metadata["version"])
  264. }
  265. func TestPlaceholderResolution(t *testing.T) {
  266. plugin := patch.NewPatchPlugin()
  267. config := &patch.Config{
  268. UserPatches: []patch.PatchRule{
  269. {
  270. Name: "placeholder_test",
  271. Conditions: []patch.PatchCondition{
  272. {
  273. Key: "max_tokens",
  274. Operator: patch.OperatorExists,
  275. },
  276. },
  277. Operations: []patch.PatchOperation{
  278. {
  279. Op: patch.OpSet,
  280. Key: "max_completion_tokens",
  281. Value: "{{max_tokens}}",
  282. },
  283. {
  284. Op: patch.OpDelete,
  285. Key: "max_tokens",
  286. },
  287. },
  288. },
  289. },
  290. }
  291. input := map[string]any{
  292. "model": "test",
  293. "max_tokens": 3000,
  294. }
  295. inputBytes, err := sonic.Marshal(input)
  296. require.NoError(t, err)
  297. meta := &meta.Meta{ActualModel: "test"}
  298. outputBytes, modified, err := plugin.ApplyPatches(inputBytes, meta, config)
  299. require.NoError(t, err)
  300. assert.True(t, modified)
  301. var output map[string]any
  302. err = sonic.Unmarshal(outputBytes, &output)
  303. require.NoError(t, err)
  304. maxCompletionTokens, ok := output["max_completion_tokens"].(float64)
  305. require.True(t, ok, "max_completion_tokens should be float64")
  306. assert.Equal(t, 3000, int(maxCompletionTokens))
  307. _, hasMaxTokens := output["max_tokens"]
  308. assert.False(t, hasMaxTokens)
  309. }
  310. func TestOperators(t *testing.T) {
  311. plugin := patch.NewPatchPlugin()
  312. config := &patch.Config{
  313. UserPatches: []patch.PatchRule{
  314. {
  315. Name: "operator_tests",
  316. Conditions: []patch.PatchCondition{
  317. {
  318. Key: "model",
  319. Operator: patch.OperatorRegex,
  320. Value: "^gpt-[0-9]$",
  321. },
  322. },
  323. Operations: []patch.PatchOperation{
  324. {
  325. Op: patch.OpSet,
  326. Key: "matched",
  327. Value: true,
  328. },
  329. },
  330. },
  331. },
  332. }
  333. testCases := []struct {
  334. model string
  335. shouldMatch bool
  336. }{
  337. {"gpt-4", true},
  338. {"gpt-3", true},
  339. {"gpt-4o", false},
  340. {"claude-3", false},
  341. }
  342. for _, tc := range testCases {
  343. t.Run(tc.model, func(t *testing.T) {
  344. input := map[string]any{"model": tc.model}
  345. inputBytes, err := sonic.Marshal(input)
  346. require.NoError(t, err)
  347. meta := &meta.Meta{ActualModel: tc.model}
  348. outputBytes, modified, err := plugin.ApplyPatches(inputBytes, meta, config)
  349. require.NoError(t, err)
  350. assert.Equal(t, tc.shouldMatch, modified)
  351. if tc.shouldMatch {
  352. var output map[string]any
  353. err = sonic.Unmarshal(outputBytes, &output)
  354. require.NoError(t, err)
  355. matched, ok := output["matched"].(bool)
  356. require.True(t, ok, "matched should be bool")
  357. assert.True(t, matched)
  358. }
  359. })
  360. }
  361. }
  362. func TestInvalidJSON(t *testing.T) {
  363. plugin := patch.NewPatchPlugin()
  364. config := &patch.Config{}
  365. invalidJSON := []byte(`{"invalid": json}`)
  366. meta := &meta.Meta{ActualModel: "test"}
  367. outputBytes, modified, err := plugin.ApplyPatches(invalidJSON, meta, config)
  368. require.NoError(t, err)
  369. assert.False(t, modified)
  370. assert.Equal(t, invalidJSON, outputBytes)
  371. }
  372. func TestConvertRequest(t *testing.T) {
  373. // Skip this test since it requires database initialization
  374. // The functionality is already tested in other unit tests
  375. t.Skip("Skipping integration test - requires database setup")
  376. }
  377. func TestToFloat64(t *testing.T) {
  378. testCases := []struct {
  379. input any
  380. expected float64
  381. hasError bool
  382. }{
  383. {float64(3.14), 3.14, false},
  384. {float32(2.5), 2.5, false},
  385. {int(42), 42.0, false},
  386. {int32(100), 100.0, false},
  387. {int64(200), 200.0, false},
  388. {"123.45", 123.45, false},
  389. {"invalid", 0, true},
  390. {true, 0, true},
  391. }
  392. for _, tc := range testCases {
  393. result, err := patch.ToFloat64(tc.input)
  394. if tc.hasError {
  395. assert.Error(t, err)
  396. } else {
  397. assert.NoError(t, err)
  398. assert.Equal(t, tc.expected, result)
  399. }
  400. }
  401. }
  402. func TestConditionLogicOperators(t *testing.T) {
  403. plugin := patch.NewPatchPlugin()
  404. testCases := []struct {
  405. name string
  406. config *patch.Config
  407. input map[string]any
  408. actualModel string
  409. shouldModify bool
  410. }{
  411. {
  412. name: "OR logic - one condition matches",
  413. config: &patch.Config{
  414. UserPatches: []patch.PatchRule{
  415. {
  416. Name: "or_logic_test",
  417. ConditionLogic: patch.LogicOr,
  418. Conditions: []patch.PatchCondition{
  419. {
  420. Key: "model",
  421. Operator: patch.OperatorEquals,
  422. Value: "gpt-4",
  423. },
  424. {
  425. Key: "temperature",
  426. Operator: patch.OperatorGreaterThan,
  427. Value: "1.5",
  428. },
  429. },
  430. Operations: []patch.PatchOperation{
  431. {
  432. Op: patch.OpSet,
  433. Key: "modified",
  434. Value: true,
  435. },
  436. },
  437. },
  438. },
  439. },
  440. input: map[string]any{
  441. "model": "claude-3",
  442. "temperature": 2.0,
  443. },
  444. actualModel: "claude-3",
  445. shouldModify: true,
  446. },
  447. {
  448. name: "OR logic - no condition matches",
  449. config: &patch.Config{
  450. UserPatches: []patch.PatchRule{
  451. {
  452. Name: "or_logic_test_no_match",
  453. ConditionLogic: patch.LogicOr,
  454. Conditions: []patch.PatchCondition{
  455. {
  456. Key: "model",
  457. Operator: patch.OperatorEquals,
  458. Value: "gpt-4",
  459. },
  460. {
  461. Key: "temperature",
  462. Operator: patch.OperatorGreaterThan,
  463. Value: "1.5",
  464. },
  465. },
  466. Operations: []patch.PatchOperation{
  467. {
  468. Op: patch.OpSet,
  469. Key: "modified",
  470. Value: true,
  471. },
  472. },
  473. },
  474. },
  475. },
  476. input: map[string]any{
  477. "model": "claude-3",
  478. "temperature": 1.0,
  479. },
  480. actualModel: "claude-3",
  481. shouldModify: false,
  482. },
  483. {
  484. name: "AND logic (default) - all conditions match",
  485. config: &patch.Config{
  486. UserPatches: []patch.PatchRule{
  487. {
  488. Name: "and_logic_test",
  489. Conditions: []patch.PatchCondition{
  490. {
  491. Key: "model",
  492. Operator: patch.OperatorContains,
  493. Value: "gpt",
  494. },
  495. {
  496. Key: "temperature",
  497. Operator: patch.OperatorLessThan,
  498. Value: "1.5",
  499. },
  500. },
  501. Operations: []patch.PatchOperation{
  502. {
  503. Op: patch.OpSet,
  504. Key: "modified",
  505. Value: true,
  506. },
  507. },
  508. },
  509. },
  510. },
  511. input: map[string]any{
  512. "model": "gpt-4",
  513. "temperature": 1.0,
  514. },
  515. actualModel: "gpt-4",
  516. shouldModify: true,
  517. },
  518. {
  519. name: "AND logic (default) - one condition fails",
  520. config: &patch.Config{
  521. UserPatches: []patch.PatchRule{
  522. {
  523. Name: "and_logic_test_fail",
  524. Conditions: []patch.PatchCondition{
  525. {
  526. Key: "model",
  527. Operator: patch.OperatorContains,
  528. Value: "gpt",
  529. },
  530. {
  531. Key: "temperature",
  532. Operator: patch.OperatorLessThan,
  533. Value: "1.5",
  534. },
  535. },
  536. Operations: []patch.PatchOperation{
  537. {
  538. Op: patch.OpSet,
  539. Key: "modified",
  540. Value: true,
  541. },
  542. },
  543. },
  544. },
  545. },
  546. input: map[string]any{
  547. "model": "gpt-4",
  548. "temperature": 2.0,
  549. },
  550. actualModel: "gpt-4",
  551. shouldModify: false,
  552. },
  553. }
  554. for _, tc := range testCases {
  555. t.Run(tc.name, func(t *testing.T) {
  556. inputBytes, err := sonic.Marshal(tc.input)
  557. require.NoError(t, err)
  558. meta := &meta.Meta{ActualModel: tc.actualModel}
  559. outputBytes, modified, err := plugin.ApplyPatches(inputBytes, meta, tc.config)
  560. require.NoError(t, err)
  561. assert.Equal(t, tc.shouldModify, modified)
  562. if tc.shouldModify {
  563. var output map[string]any
  564. err = sonic.Unmarshal(outputBytes, &output)
  565. require.NoError(t, err)
  566. assert.Equal(t, output["modified"], true)
  567. }
  568. })
  569. }
  570. }
  571. func TestConditionNegation(t *testing.T) {
  572. plugin := patch.NewPatchPlugin()
  573. testCases := []struct {
  574. name string
  575. config *patch.Config
  576. input map[string]any
  577. actualModel string
  578. shouldModify bool
  579. }{
  580. {
  581. name: "negate condition - should match when negated",
  582. config: &patch.Config{
  583. UserPatches: []patch.PatchRule{
  584. {
  585. Name: "negate_test",
  586. Conditions: []patch.PatchCondition{
  587. {
  588. Key: "model",
  589. Operator: patch.OperatorEquals,
  590. Value: "gpt-4",
  591. Negate: true,
  592. },
  593. },
  594. Operations: []patch.PatchOperation{
  595. {
  596. Op: patch.OpSet,
  597. Key: "modified",
  598. Value: true,
  599. },
  600. },
  601. },
  602. },
  603. },
  604. input: map[string]any{
  605. "model": "claude-3",
  606. },
  607. actualModel: "claude-3",
  608. shouldModify: true,
  609. },
  610. {
  611. name: "negate condition - should not match when negated",
  612. config: &patch.Config{
  613. UserPatches: []patch.PatchRule{
  614. {
  615. Name: "negate_test_no_match",
  616. Conditions: []patch.PatchCondition{
  617. {
  618. Key: "model",
  619. Operator: patch.OperatorEquals,
  620. Value: "gpt-4",
  621. Negate: true,
  622. },
  623. },
  624. Operations: []patch.PatchOperation{
  625. {
  626. Op: patch.OpSet,
  627. Key: "modified",
  628. Value: true,
  629. },
  630. },
  631. },
  632. },
  633. },
  634. input: map[string]any{
  635. "model": "gpt-4",
  636. },
  637. actualModel: "gpt-4",
  638. shouldModify: false,
  639. },
  640. {
  641. name: "OR with negation - complex logic",
  642. config: &patch.Config{
  643. UserPatches: []patch.PatchRule{
  644. {
  645. Name: "or_with_negate",
  646. ConditionLogic: patch.LogicOr,
  647. Conditions: []patch.PatchCondition{
  648. {
  649. Key: "model",
  650. Operator: patch.OperatorEquals,
  651. Value: "gpt-4",
  652. },
  653. {
  654. Key: "temperature",
  655. Operator: patch.OperatorExists,
  656. Negate: true, // NOT exists
  657. },
  658. },
  659. Operations: []patch.PatchOperation{
  660. {
  661. Op: patch.OpSet,
  662. Key: "modified",
  663. Value: true,
  664. },
  665. },
  666. },
  667. },
  668. },
  669. input: map[string]any{
  670. "model": "claude-3",
  671. // no temperature field
  672. },
  673. actualModel: "claude-3",
  674. shouldModify: true, // Should match because temperature doesn't exist (negated exists)
  675. },
  676. }
  677. for _, tc := range testCases {
  678. t.Run(tc.name, func(t *testing.T) {
  679. inputBytes, err := sonic.Marshal(tc.input)
  680. require.NoError(t, err)
  681. meta := &meta.Meta{ActualModel: tc.actualModel}
  682. outputBytes, modified, err := plugin.ApplyPatches(inputBytes, meta, tc.config)
  683. require.NoError(t, err)
  684. assert.Equal(t, tc.shouldModify, modified)
  685. if tc.shouldModify {
  686. var output map[string]any
  687. err = sonic.Unmarshal(outputBytes, &output)
  688. require.NoError(t, err)
  689. assert.Equal(t, output["modified"], true)
  690. }
  691. })
  692. }
  693. }