- package controller
- import (
- "encoding/json"
- "fmt"
- "github.com/gin-gonic/gin"
- "github.com/pkoukk/tiktoken-go"
- "io"
- "net/http"
- "one-api/common"
- "strconv"
- "strings"
- "unicode/utf8"
- )
// stopFinishReason is the finish_reason value reported when a completion
// ends normally (used by relay handlers elsewhere in this package).
var stopFinishReason = "stop"

// tokenEncoderMap won't grow after initialization
// (populated once by InitTokenEncoders; writing to it at request time
// would be a concurrent map write).
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}

// defaultTokenEncoder is the gpt-3.5-turbo encoder, used as the fallback
// for models without a dedicated encoder. Set in InitTokenEncoders.
var defaultTokenEncoder *tiktoken.Tiktoken
- func InitTokenEncoders() {
- common.SysLog("initializing token encoders")
- gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
- if err != nil {
- common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
- }
- defaultTokenEncoder = gpt35TokenEncoder
- gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
- if err != nil {
- common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
- }
- for model, _ := range common.ModelRatio {
- if strings.HasPrefix(model, "gpt-3.5") {
- tokenEncoderMap[model] = gpt35TokenEncoder
- } else if strings.HasPrefix(model, "gpt-4") {
- tokenEncoderMap[model] = gpt4TokenEncoder
- } else {
- tokenEncoderMap[model] = nil
- }
- }
- common.SysLog("token encoders initialized")
- }
- func getTokenEncoder(model string) *tiktoken.Tiktoken {
- tokenEncoder, ok := tokenEncoderMap[model]
- if ok && tokenEncoder != nil {
- return tokenEncoder
- }
- if ok {
- tokenEncoder, err := tiktoken.EncodingForModel(model)
- if err != nil {
- common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
- tokenEncoder = defaultTokenEncoder
- }
- tokenEncoderMap[model] = tokenEncoder
- return tokenEncoder
- }
- return defaultTokenEncoder
- }
- func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
- return len(tokenEncoder.Encode(text, nil, nil))
- }
// countTokenMessages estimates the prompt token count for a list of chat
// messages, following OpenAI's published per-message/per-name overhead
// scheme. Returns an error only when a message's Content is neither a
// JSON array of media parts nor a JSON string.
func countTokenMessages(messages []Message, model string) (int, error) {
	//recover when panic
	tokenEncoder := getTokenEncoder(model)
	// Reference:
	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
	// https://github.com/pkoukk/tiktoken-go/issues/6
	//
	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
	var tokensPerMessage int
	var tokensPerName int
	if model == "gpt-3.5-turbo-0301" {
		tokensPerMessage = 4
		tokensPerName = -1 // If there's a name, the role is omitted
	} else {
		tokensPerMessage = 3
		tokensPerName = 1
	}
	tokenNum := 0
	for _, message := range messages {
		tokenNum += tokensPerMessage
		tokenNum += getTokenNum(tokenEncoder, message.Role)
		// Content is raw JSON: try the multimodal array form first, and
		// fall back to a plain string if that fails.
		var arrayContent []MediaMessage
		if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
			var stringContent string
			if err := json.Unmarshal(message.Content, &stringContent); err != nil {
				// Content is neither array nor string — malformed request.
				return 0, err
			} else {
				tokenNum += getTokenNum(tokenEncoder, stringContent)
				if message.Name != nil {
					tokenNum += tokensPerName
					tokenNum += getTokenNum(tokenEncoder, *message.Name)
				}
			}
		} else {
			for _, m := range arrayContent {
				if m.Type == "image_url" {
					//TODO: getImageToken
					// Flat estimate per image until real image token
					// accounting is implemented.
					tokenNum += 1000
				} else {
					tokenNum += getTokenNum(tokenEncoder, m.Text)
				}
			}
		}
	}
	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
	return tokenNum, nil
}
- func countTokenInput(input any, model string) int {
- switch input.(type) {
- case string:
- return countTokenText(input.(string), model)
- case []string:
- text := ""
- for _, s := range input.([]string) {
- text += s
- }
- return countTokenText(text, model)
- }
- return 0
- }
- func countAudioToken(text string, model string) int {
- if strings.HasPrefix(model, "tts") {
- return utf8.RuneCountInString(text)
- } else {
- return countTokenText(text, model)
- }
- }
- func countTokenText(text string, model string) int {
- tokenEncoder := getTokenEncoder(model)
- return getTokenNum(tokenEncoder, text)
- }
- func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
- text := err.Error()
- // 定义一个正则表达式匹配URL
- if strings.Contains(text, "Post") {
- text = "请求上游地址失败"
- }
- //避免暴露内部错误
- openAIError := OpenAIError{
- Message: text,
- Type: "one_api_error",
- Code: code,
- }
- return &OpenAIErrorWithStatusCode{
- OpenAIError: openAIError,
- StatusCode: statusCode,
- }
- }
- func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
- if !common.AutomaticDisableChannelEnabled {
- return false
- }
- if err == nil {
- return false
- }
- if statusCode == http.StatusUnauthorized {
- return true
- }
- if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
- return true
- }
- return false
- }
- func setEventStreamHeaders(c *gin.Context) {
- c.Writer.Header().Set("Content-Type", "text/event-stream")
- c.Writer.Header().Set("Cache-Control", "no-cache")
- c.Writer.Header().Set("Connection", "keep-alive")
- c.Writer.Header().Set("Transfer-Encoding", "chunked")
- c.Writer.Header().Set("X-Accel-Buffering", "no")
- }
- func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
- openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
- StatusCode: resp.StatusCode,
- OpenAIError: OpenAIError{
- Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
- Type: "upstream_error",
- Code: "bad_response_status_code",
- Param: strconv.Itoa(resp.StatusCode),
- },
- }
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return
- }
- err = resp.Body.Close()
- if err != nil {
- return
- }
- var textResponse TextResponse
- err = json.Unmarshal(responseBody, &textResponse)
- if err != nil {
- return
- }
- openAIErrorWithStatusCode.OpenAIError = textResponse.Error
- return
- }
- func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
- fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
- if channelType == common.ChannelTypeOpenAI {
- if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
- fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
- }
- }
- return fullRequestURL
- }