relay-utils.go

package controller

import (
	"fmt"
	"github.com/pkoukk/tiktoken-go"
	"one-api/common"
)

var stopFinishReason = "stop"

var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
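
// getTokenEncoder returns a cached tiktoken encoder for the given model.
// Unknown models fall back to the gpt-3.5-turbo encoder, and the result is
// stored in tokenEncoderMap so each encoding is only loaded once.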
func getTokenEncoder(model string) *tiktoken.Tiktoken {
	if tokenEncoder, ok := tokenEncoderMap[model]; ok {
		return tokenEncoder
	}
	tokenEncoder, err := tiktoken.EncodingForModel(model)
	if err != nil {
		common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
		tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
		if err != nil {
			common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
		}
	}
	tokenEncoderMap[model] = tokenEncoder
	return tokenEncoder
}
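
// getTokenNum returns the token count of text. When approximate counting is
// enabled, it skips the encoder and estimates roughly 0.38 tokens per byte.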
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
	if common.ApproximateTokenEnabled {
		return int(float64(len(text)) * 0.38)
	}
	return len(tokenEncoder.Encode(text, nil, nil))
}
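
// countTokenMessages counts the prompt tokens of a chat request: a fixed
// per-message overhead plus the encoded role, content, and optional name of
// each message, plus 3 tokens that prime the assistant's reply.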
func countTokenMessages(messages []Message, model string) int {
	tokenEncoder := getTokenEncoder(model)
	// Reference:
	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
	// https://github.com/pkoukk/tiktoken-go/issues/6
	//
	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
	var tokensPerMessage int
	var tokensPerName int
	if model == "gpt-3.5-turbo-0301" {
		tokensPerMessage = 4
		tokensPerName = -1 // If there's a name, the role is omitted
	} else {
		tokensPerMessage = 3
		tokensPerName = 1
	}
	tokenNum := 0
	for _, message := range messages {
		tokenNum += tokensPerMessage
		tokenNum += getTokenNum(tokenEncoder, message.Content)
		tokenNum += getTokenNum(tokenEncoder, message.Role)
		if message.Name != nil {
			tokenNum += tokensPerName
			tokenNum += getTokenNum(tokenEncoder, *message.Name)
		}
	}
	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
	return tokenNum
}
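
// countTokenInput counts tokens for an input that may be either a single
// string or a slice of strings; any other type counts as zero.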
func countTokenInput(input any, model string) int {
	switch v := input.(type) {
	case string:
		return countTokenText(v, model)
	case []string:
		text := ""
		for _, s := range v {
			text += s
		}
		return countTokenText(text, model)
	}
	return 0
}
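
// countTokenText counts the tokens of plain text using the encoder for the
// given model.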
func countTokenText(text string, model string) int {
	tokenEncoder := getTokenEncoder(model)
	return getTokenNum(tokenEncoder, text)
}
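
// errorWrapper converts a Go error into an OpenAI-style error response with
// the given error code and HTTP status code.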
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
	openAIError := OpenAIError{
		Message: err.Error(),
		Type:    "one_api_error",
		Code:    code,
	}
	return &OpenAIErrorWithStatusCode{
		OpenAIError: openAIError,
		StatusCode:  statusCode,
	}
}
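
// shouldDisableChannel reports whether an upstream error should cause the
// channel to be disabled automatically: exhausted quota, an invalid API key,
// or a deactivated account, and only when automatic disabling is enabled.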
func shouldDisableChannel(err *OpenAIError) bool {
	if !common.AutomaticDisableChannelEnabled {
		return false
	}
	if err == nil {
		return false
	}
	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
		return true
	}
	return false
}