relay-utils.go 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
package controller

import (
	"fmt"
	"strings"
	"sync"

	"github.com/pkoukk/tiktoken-go"

	"one-api/common"
)
  8. var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
  9. func getTokenEncoder(model string) *tiktoken.Tiktoken {
  10. if tokenEncoder, ok := tokenEncoderMap[model]; ok {
  11. return tokenEncoder
  12. }
  13. tokenEncoder, err := tiktoken.EncodingForModel(model)
  14. if err != nil {
  15. common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
  16. tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
  17. if err != nil {
  18. common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
  19. }
  20. }
  21. tokenEncoderMap[model] = tokenEncoder
  22. return tokenEncoder
  23. }
  24. func countTokenMessages(messages []Message, model string) int {
  25. tokenEncoder := getTokenEncoder(model)
  26. // Reference:
  27. // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
  28. // https://github.com/pkoukk/tiktoken-go/issues/6
  29. //
  30. // Every message follows <|start|>{role/name}\n{content}<|end|>\n
  31. var tokensPerMessage int
  32. var tokensPerName int
  33. if strings.HasPrefix(model, "gpt-3.5") {
  34. tokensPerMessage = 4
  35. tokensPerName = -1 // If there's a name, the role is omitted
  36. } else if strings.HasPrefix(model, "gpt-4") {
  37. tokensPerMessage = 3
  38. tokensPerName = 1
  39. } else {
  40. tokensPerMessage = 3
  41. tokensPerName = 1
  42. }
  43. tokenNum := 0
  44. for _, message := range messages {
  45. tokenNum += tokensPerMessage
  46. tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil))
  47. tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil))
  48. if message.Name != nil {
  49. tokenNum += tokensPerName
  50. tokenNum += len(tokenEncoder.Encode(*message.Name, nil, nil))
  51. }
  52. }
  53. tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
  54. return tokenNum
  55. }
  56. func countTokenInput(input any, model string) int {
  57. switch input.(type) {
  58. case string:
  59. return countTokenText(input.(string), model)
  60. case []string:
  61. text := ""
  62. for _, s := range input.([]string) {
  63. text += s
  64. }
  65. return countTokenText(text, model)
  66. }
  67. return 0
  68. }
  69. func countTokenText(text string, model string) int {
  70. tokenEncoder := getTokenEncoder(model)
  71. token := tokenEncoder.Encode(text, nil, nil)
  72. return len(token)
  73. }
  74. func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
  75. openAIError := OpenAIError{
  76. Message: err.Error(),
  77. Type: "one_api_error",
  78. Code: code,
  79. }
  80. return &OpenAIErrorWithStatusCode{
  81. OpenAIError: openAIError,
  82. StatusCode: statusCode,
  83. }
  84. }