tiktoken_test.go 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. package tiktoken_test
  2. import (
  3. "testing"
  4. "github.com/labring/aiproxy/core/common/tiktoken"
  5. "github.com/smartystreets/goconvey/convey"
  6. )
  7. func TestGetTokenEncoder(t *testing.T) {
  8. convey.Convey("GetTokenEncoder", t, func() {
  9. convey.Convey("should get encoder for gpt-4o", func() {
  10. enc := tiktoken.GetTokenEncoder("gpt-4o")
  11. convey.So(enc, convey.ShouldNotBeNil)
  12. })
  13. convey.Convey("should get encoder for gpt-3.5-turbo", func() {
  14. enc := tiktoken.GetTokenEncoder("gpt-3.5-turbo")
  15. convey.So(enc, convey.ShouldNotBeNil)
  16. })
  17. convey.Convey("should return default encoder for unknown model", func() {
  18. enc := tiktoken.GetTokenEncoder("unknown-model")
  19. convey.So(enc, convey.ShouldNotBeNil)
  20. // Should default to gpt-4o encoder (o200k_base)
  21. ids, _, _ := enc.Encode("hello")
  22. convey.So(len(ids), convey.ShouldBeGreaterThan, 0)
  23. })
  24. convey.Convey("should cache encoders", func() {
  25. enc1 := tiktoken.GetTokenEncoder("gpt-4")
  26. enc2 := tiktoken.GetTokenEncoder("gpt-4")
  27. convey.So(enc1, convey.ShouldEqual, enc2)
  28. })
  29. })
  30. }
  31. func TestEncoding(t *testing.T) {
  32. convey.Convey("Encoding", t, func() {
  33. convey.Convey("should encode correctly", func() {
  34. enc := tiktoken.GetTokenEncoder("gpt-3.5-turbo")
  35. text := "hello world"
  36. ids, _, err := enc.Encode(text)
  37. convey.So(err, convey.ShouldBeNil)
  38. convey.So(len(ids), convey.ShouldBeGreaterThan, 0)
  39. decoded, err := enc.Decode(ids)
  40. convey.So(err, convey.ShouldBeNil)
  41. convey.So(decoded, convey.ShouldEqual, text)
  42. })
  43. })
  44. }