Selaa lähdekoodia

perf: lazy initialization for token encoders (close #566)

JustSong 2 vuotta sitten
vanhempi
commit
594f06e7b0
1 muutettu tiedosto, jossa 25 lisäystä ja 16 poistoa
  1. 25 16
      controller/relay-utils.go

+ 25 - 16
controller/relay-utils.go

@@ -9,44 +9,53 @@ import (
 	"net/http"
 	"one-api/common"
 	"strconv"
+	"strings"
 )
 
 var stopFinishReason = "stop"
 
+// tokenEncoderMap won't grow after initialization
 var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
+var defaultTokenEncoder *tiktoken.Tiktoken
 
 func InitTokenEncoders() {
 	common.SysLog("initializing token encoders")
-	fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
+	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
 	if err != nil {
-		common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error()))
+		common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
+	}
+	defaultTokenEncoder = gpt35TokenEncoder
+	gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
+	if err != nil {
+		common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
 	}
 	for model, _ := range common.ModelRatio {
-		tokenEncoder, err := tiktoken.EncodingForModel(model)
-		if err != nil {
-			common.SysError(fmt.Sprintf("using fallback encoder for model %s", model))
-			tokenEncoderMap[model] = fallbackTokenEncoder
-			continue
+		if strings.HasPrefix(model, "gpt-3.5") {
+			tokenEncoderMap[model] = gpt35TokenEncoder
+		} else if strings.HasPrefix(model, "gpt-4") {
+			tokenEncoderMap[model] = gpt4TokenEncoder
+		} else {
+			tokenEncoderMap[model] = nil
 		}
-		tokenEncoderMap[model] = tokenEncoder
 	}
 	common.SysLog("token encoders initialized")
 }
 
 func getTokenEncoder(model string) *tiktoken.Tiktoken {
-	if tokenEncoder, ok := tokenEncoderMap[model]; ok {
+	tokenEncoder, ok := tokenEncoderMap[model]
+	if ok && tokenEncoder != nil {
 		return tokenEncoder
 	}
-	tokenEncoder, err := tiktoken.EncodingForModel(model)
-	if err != nil {
-		common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
-		tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
+	if ok {
+		tokenEncoder, err := tiktoken.EncodingForModel(model)
 		if err != nil {
-			common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
+			common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
+			tokenEncoder = defaultTokenEncoder
 		}
+		tokenEncoderMap[model] = tokenEncoder
+		return tokenEncoder
 	}
-	tokenEncoderMap[model] = tokenEncoder
-	return tokenEncoder
+	return defaultTokenEncoder
 }
 
 func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {