| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- package controller
import (
	"fmt"
	"strings"
	"sync"

	"github.com/pkoukk/tiktoken-go"
	"one-api/common"
)
- var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
- func getTokenEncoder(model string) *tiktoken.Tiktoken {
- if tokenEncoder, ok := tokenEncoderMap[model]; ok {
- return tokenEncoder
- }
- tokenEncoder, err := tiktoken.EncodingForModel(model)
- if err != nil {
- common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
- tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
- if err != nil {
- common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
- }
- }
- tokenEncoderMap[model] = tokenEncoder
- return tokenEncoder
- }
- func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
- if common.ApproximateTokenEnabled {
- return int(float64(len(text)) * 0.38)
- }
- return len(tokenEncoder.Encode(text, nil, nil))
- }
- func countTokenMessages(messages []Message, model string) int {
- tokenEncoder := getTokenEncoder(model)
- // Reference:
- // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
- // https://github.com/pkoukk/tiktoken-go/issues/6
- //
- // Every message follows <|start|>{role/name}\n{content}<|end|>\n
- var tokensPerMessage int
- var tokensPerName int
- if model == "gpt-3.5-turbo-0301" {
- tokensPerMessage = 4
- tokensPerName = -1 // If there's a name, the role is omitted
- } else {
- tokensPerMessage = 3
- tokensPerName = 1
- }
- tokenNum := 0
- for _, message := range messages {
- tokenNum += tokensPerMessage
- tokenNum += getTokenNum(tokenEncoder, message.Content)
- tokenNum += getTokenNum(tokenEncoder, message.Role)
- if message.Name != nil {
- tokenNum += tokensPerName
- tokenNum += getTokenNum(tokenEncoder, *message.Name)
- }
- }
- tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
- return tokenNum
- }
- func countTokenInput(input any, model string) int {
- switch input.(type) {
- case string:
- return countTokenText(input.(string), model)
- case []string:
- text := ""
- for _, s := range input.([]string) {
- text += s
- }
- return countTokenText(text, model)
- }
- return 0
- }
- func countTokenText(text string, model string) int {
- tokenEncoder := getTokenEncoder(model)
- return getTokenNum(tokenEncoder, text)
- }
- func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
- openAIError := OpenAIError{
- Message: err.Error(),
- Type: "one_api_error",
- Code: code,
- }
- return &OpenAIErrorWithStatusCode{
- OpenAIError: openAIError,
- StatusCode: statusCode,
- }
- }
- func shouldDisableChannel(err *OpenAIError) bool {
- if !common.AutomaticDisableChannelEnabled {
- return false
- }
- if err == nil {
- return false
- }
- if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
- return true
- }
- return false
- }
|