Quellcode durchsuchen

feat: Enhance token counting and content parsing for messages

[email protected] vor 10 Monaten
Ursprung
Commit
7ff4cebdbe
3 geänderte Dateien mit 40 neuen und 26 gelöschten Zeilen
  1. 1 1
      common/constants.go
  2. 18 2
      dto/openai_request.go
  3. 21 23
      service/token_counter.go

+ 1 - 1
common/constants.go

@@ -276,7 +276,7 @@ var ChannelBaseURLs = []string{
 	"https://api.cohere.ai",                     //34
 	"https://api.minimax.chat",                  //35
 	"",                                          //36
-	"",                                          //37
+	"https://api.dify.ai",                       //37
 	"https://api.jina.ai",                       //38
 	"https://api.cloudflare.com",                //39
 	"https://api.siliconflow.cn",                //40

+ 18 - 2
dto/openai_request.go

@@ -1,6 +1,9 @@
 package dto
 
-import "encoding/json"
+import (
+	"encoding/json"
+	"strings"
+)
 
 type ResponseFormat struct {
 	Type       string            `json:"type,omitempty"`
@@ -153,11 +156,24 @@ func (m *Message) StringContent() string {
 	if m.parsedStringContent != nil {
 		return *m.parsedStringContent
 	}
+
 	var stringContent string
 	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
+		m.parsedStringContent = &stringContent
 		return stringContent
 	}
-	return string(m.Content)
+
+	contentStr := new(strings.Builder)
+	arrayContent := m.ParseContent()
+	for _, content := range arrayContent {
+		if content.Type == ContentTypeText {
+			contentStr.WriteString(content.Text)
+		}
+	}
+	stringContent = contentStr.String()
+	m.parsedStringContent = &stringContent
+
+	return stringContent
 }
 
 func (m *Message) SetStringContent(content string) {

+ 21 - 23
service/token_counter.go

@@ -78,6 +78,9 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
 }
 
 func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
+	if text == "" {
+		return 0
+	}
 	return len(tokenEncoder.Encode(text, nil, nil))
 }
 
@@ -282,30 +285,25 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
 		tokenNum += tokensPerMessage
 		tokenNum += getTokenNum(tokenEncoder, message.Role)
 		if len(message.Content) > 0 {
-			if message.IsStringContent() {
-				stringContent := message.StringContent()
-				tokenNum += getTokenNum(tokenEncoder, stringContent)
-				if message.Name != nil {
-					tokenNum += tokensPerName
-					tokenNum += getTokenNum(tokenEncoder, *message.Name)
-				}
-			} else {
-				arrayContent := message.ParseContent()
-				for _, m := range arrayContent {
-					if m.Type == dto.ContentTypeImageURL {
-						imageUrl := m.ImageUrl.(dto.MessageImageUrl)
-						imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
-						if err != nil {
-							return 0, err
-						}
-						tokenNum += imageTokenNum
-						log.Printf("image token num: %d", imageTokenNum)
-					} else if m.Type == dto.ContentTypeInputAudio {
-						// TODO: 音频token数量计算
-						tokenNum += 100
-					} else {
-						tokenNum += getTokenNum(tokenEncoder, m.Text)
+			if message.Name != nil {
+				tokenNum += tokensPerName
+				tokenNum += getTokenNum(tokenEncoder, *message.Name)
+			}
+			arrayContent := message.ParseContent()
+			for _, m := range arrayContent {
+				if m.Type == dto.ContentTypeImageURL {
+					imageUrl := m.ImageUrl.(dto.MessageImageUrl)
+					imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
+					if err != nil {
+						return 0, err
 					}
+					tokenNum += imageTokenNum
+					log.Printf("image token num: %d", imageTokenNum)
+				} else if m.Type == dto.ContentTypeInputAudio {
+					// TODO: 音频token数量计算
+					tokenNum += 100
+				} else {
+					tokenNum += getTokenNum(tokenEncoder, m.Text)
 				}
 			}
 		}