|
|
@@ -78,6 +78,9 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
|
|
}
|
|
|
|
|
|
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
|
|
+ if text == "" {
|
|
|
+ return 0
|
|
|
+ }
|
|
|
return len(tokenEncoder.Encode(text, nil, nil))
|
|
|
}
|
|
|
|
|
|
@@ -282,30 +285,25 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
|
|
|
tokenNum += tokensPerMessage
|
|
|
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
|
|
if len(message.Content) > 0 {
|
|
|
- if message.IsStringContent() {
|
|
|
- stringContent := message.StringContent()
|
|
|
- tokenNum += getTokenNum(tokenEncoder, stringContent)
|
|
|
- if message.Name != nil {
|
|
|
- tokenNum += tokensPerName
|
|
|
- tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
|
|
- }
|
|
|
- } else {
|
|
|
- arrayContent := message.ParseContent()
|
|
|
- for _, m := range arrayContent {
|
|
|
- if m.Type == dto.ContentTypeImageURL {
|
|
|
- imageUrl := m.ImageUrl.(dto.MessageImageUrl)
|
|
|
- imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
|
|
|
- if err != nil {
|
|
|
- return 0, err
|
|
|
- }
|
|
|
- tokenNum += imageTokenNum
|
|
|
- log.Printf("image token num: %d", imageTokenNum)
|
|
|
- } else if m.Type == dto.ContentTypeInputAudio {
|
|
|
- // TODO: 音频token数量计算
|
|
|
- tokenNum += 100
|
|
|
- } else {
|
|
|
- tokenNum += getTokenNum(tokenEncoder, m.Text)
|
|
|
+ if message.Name != nil {
|
|
|
+ tokenNum += tokensPerName
|
|
|
+ tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
|
|
+ }
|
|
|
+ arrayContent := message.ParseContent()
|
|
|
+ for _, m := range arrayContent {
|
|
|
+ if m.Type == dto.ContentTypeImageURL {
|
|
|
+ imageUrl := m.ImageUrl.(dto.MessageImageUrl)
|
|
|
+ imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
}
|
|
|
+ tokenNum += imageTokenNum
|
|
|
+ log.Printf("image token num: %d", imageTokenNum)
|
|
|
+ } else if m.Type == dto.ContentTypeInputAudio {
|
|
|
+ // TODO: 音频token数量计算
|
|
|
+ tokenNum += 100
|
|
|
+ } else {
|
|
|
+ tokenNum += getTokenNum(tokenEncoder, m.Text)
|
|
|
}
|
|
|
}
|
|
|
}
|