chat_test.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. package model_test
  2. import (
  3. "errors"
  4. "net/http"
  5. "testing"
  6. "github.com/labring/aiproxy/core/relay/model"
  7. "github.com/smartystreets/goconvey/convey"
  8. )
  9. func TestChatUsage(t *testing.T) {
  10. convey.Convey("ChatUsage", t, func() {
  11. convey.Convey("ToModelUsage", func() {
  12. u := model.ChatUsage{
  13. PromptTokens: 10,
  14. CompletionTokens: 20,
  15. TotalTokens: 30,
  16. WebSearchCount: 5,
  17. PromptTokensDetails: &model.PromptTokensDetails{
  18. CachedTokens: 5,
  19. CacheCreationTokens: 2,
  20. },
  21. CompletionTokensDetails: &model.CompletionTokensDetails{
  22. ReasoningTokens: 10,
  23. },
  24. }
  25. modelUsage := u.ToModelUsage()
  26. convey.So(int64(modelUsage.InputTokens), convey.ShouldEqual, 10)
  27. convey.So(int64(modelUsage.OutputTokens), convey.ShouldEqual, 20)
  28. convey.So(int64(modelUsage.TotalTokens), convey.ShouldEqual, 30)
  29. convey.So(int64(modelUsage.WebSearchCount), convey.ShouldEqual, 5)
  30. convey.So(int64(modelUsage.CachedTokens), convey.ShouldEqual, 5)
  31. convey.So(int64(modelUsage.CacheCreationTokens), convey.ShouldEqual, 2)
  32. convey.So(int64(modelUsage.ReasoningTokens), convey.ShouldEqual, 10)
  33. })
  34. convey.Convey("Add", func() {
  35. u1 := model.ChatUsage{
  36. PromptTokens: 10,
  37. CompletionTokens: 20,
  38. TotalTokens: 30,
  39. PromptTokensDetails: &model.PromptTokensDetails{
  40. CachedTokens: 5,
  41. },
  42. }
  43. u2 := model.ChatUsage{
  44. PromptTokens: 5,
  45. CompletionTokens: 5,
  46. TotalTokens: 10,
  47. PromptTokensDetails: &model.PromptTokensDetails{
  48. CachedTokens: 2,
  49. },
  50. }
  51. u1.Add(&u2)
  52. convey.So(u1.PromptTokens, convey.ShouldEqual, 15)
  53. convey.So(u1.CompletionTokens, convey.ShouldEqual, 25)
  54. convey.So(u1.TotalTokens, convey.ShouldEqual, 40)
  55. convey.So(u1.PromptTokensDetails.CachedTokens, convey.ShouldEqual, 7)
  56. // Add nil
  57. u1.Add(nil)
  58. convey.So(u1.TotalTokens, convey.ShouldEqual, 40)
  59. })
  60. convey.Convey("ToClaudeUsage", func() {
  61. u := model.ChatUsage{
  62. PromptTokens: 10,
  63. CompletionTokens: 20,
  64. PromptTokensDetails: &model.PromptTokensDetails{
  65. CachedTokens: 5,
  66. CacheCreationTokens: 2,
  67. },
  68. }
  69. cu := u.ToClaudeUsage()
  70. convey.So(cu.InputTokens, convey.ShouldEqual, 10)
  71. convey.So(cu.OutputTokens, convey.ShouldEqual, 20)
  72. convey.So(cu.CacheReadInputTokens, convey.ShouldEqual, 5)
  73. convey.So(cu.CacheCreationInputTokens, convey.ShouldEqual, 2)
  74. })
  75. })
  76. }
  77. func TestOpenAIError(t *testing.T) {
  78. convey.Convey("OpenAIError", t, func() {
  79. convey.Convey("NewOpenAIError", func() {
  80. err := model.OpenAIError{
  81. Message: "test error",
  82. Type: "test_type",
  83. Code: "test_code",
  84. }
  85. resp := model.NewOpenAIError(http.StatusBadRequest, err)
  86. convey.So(resp.StatusCode(), convey.ShouldEqual, http.StatusBadRequest)
  87. // The Error field is unexported or nested, but NewOpenAIError returns adaptor.Error interface (or struct?)
  88. // Let's check what adaptor.Error exposes.
  89. // It usually exposes Error() string.
  90. })
  91. convey.Convey("WrapperOpenAIError", func() {
  92. err := errors.New("base error")
  93. resp := model.WrapperOpenAIError(err, "code_123", http.StatusInternalServerError)
  94. convey.So(resp.StatusCode(), convey.ShouldEqual, http.StatusInternalServerError)
  95. })
  96. })
  97. }