relay-utils.go

package controller

import (
    "fmt"

    "github.com/pkoukk/tiktoken-go"

    "one-api/common"
)

// tokenEncoderMap caches one tiktoken encoder per model name, so the encoder
// lookup only has to happen once per model.
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}

func getTokenEncoder(model string) *tiktoken.Tiktoken {
    if tokenEncoder, ok := tokenEncoderMap[model]; ok {
        return tokenEncoder
    }
    tokenEncoder, err := tiktoken.EncodingForModel(model)
    if err != nil {
        // Unknown model: log the failure and fall back to the gpt-3.5-turbo encoder.
        common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
        tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
        if err != nil {
            common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
        }
    }
    tokenEncoderMap[model] = tokenEncoder
    return tokenEncoder
}

func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
    if common.ApproximateTokenEnabled {
        // Cheap heuristic: roughly 0.38 tokens per character instead of a real encode.
        return int(float64(len(text)) * 0.38)
    }
    return len(tokenEncoder.Encode(text, nil, nil))
}
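
// Illustrative only: with the approximation enabled, a 1000-character prompt is
// estimated at int(1000*0.38) = 380 tokens, regardless of the model's tokenizer.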

// countTokenMessages estimates the prompt token count for a list of chat messages.
func countTokenMessages(messages []Message, model string) int {
    tokenEncoder := getTokenEncoder(model)
    // Reference:
    // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
    // https://github.com/pkoukk/tiktoken-go/issues/6
    //
    // Every message follows <|start|>{role/name}\n{content}<|end|>\n
    var tokensPerMessage int
    var tokensPerName int
    if model == "gpt-3.5-turbo-0301" {
        tokensPerMessage = 4
        tokensPerName = -1 // If there's a name, the role is omitted
    } else {
        tokensPerMessage = 3
        tokensPerName = 1
    }
    tokenNum := 0
    for _, message := range messages {
        tokenNum += tokensPerMessage
        tokenNum += getTokenNum(tokenEncoder, message.Content)
        tokenNum += getTokenNum(tokenEncoder, message.Role)
        if message.Name != nil {
            tokenNum += tokensPerName
            tokenNum += getTokenNum(tokenEncoder, *message.Name)
        }
    }
    tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
    return tokenNum
}
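
// Worked example (illustrative, assuming the non-0301 defaults above and an
// encoder where "user" and "Hello" each encode to a single token):
// a single message {Role: "user", Content: "Hello"} counts as
//   tokensPerMessage (3) + tokens("user") (1) + tokens("Hello") (1) + reply priming (3) = 8.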

// countTokenInput counts tokens for an input that may be a single string or a
// slice of strings; unsupported types count as zero.
func countTokenInput(input any, model string) int {
    switch v := input.(type) {
    case string:
        return countTokenText(v, model)
    case []string:
        text := ""
        for _, s := range v {
            text += s
        }
        return countTokenText(text, model)
    }
    return 0
}

func countTokenText(text string, model string) int {
    tokenEncoder := getTokenEncoder(model)
    return getTokenNum(tokenEncoder, text)
}

// errorWrapper converts an internal error into an OpenAI-style error body plus
// the HTTP status code the relay should return.
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
    openAIError := OpenAIError{
        Message: err.Error(),
        Type:    "one_api_error",
        Code:    code,
    }
    return &OpenAIErrorWithStatusCode{
        OpenAIError: openAIError,
        StatusCode:  statusCode,
    }
}
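
// Illustrative call site (hypothetical code and status values, not from this file):
//   return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
// which a relay handler would then serialize back to the client as an OpenAI-style error.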

// shouldDisableChannel reports whether an upstream error indicates the channel
// itself is broken (exhausted quota, invalid key, deactivated account) and should
// be disabled automatically, provided that feature is enabled.
func shouldDisableChannel(err *OpenAIError) bool {
    if !common.AutomaticDisableChannelEnabled {
        return false
    }
    if err == nil {
        return false
    }
    if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
        return true
    }
    return false
}