|
|
@@ -80,7 +80,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
|
|
return len(tokenEncoder.Encode(text, nil, nil))
|
|
|
}
|
|
|
|
|
|
-func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
|
|
|
+func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
|
|
|
baseTokens := 85
|
|
|
if model == "glm-4v" {
|
|
|
return 1047, nil
|
|
|
@@ -96,6 +96,9 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
|
|
|
if !constant.GetMediaToken {
|
|
|
return 256, nil
|
|
|
}
|
|
|
+ if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic {
|
|
|
+ return 256, nil
|
|
|
+ }
|
|
|
// 同步One API的图片计费逻辑
|
|
|
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
|
|
|
imageUrl.Detail = "high"
|
|
|
@@ -155,9 +158,9 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
|
|
|
return tiles*tileTokens + baseTokens, nil
|
|
|
}
|
|
|
|
|
|
-func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) {
|
|
|
+func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
|
|
|
tkm := 0
|
|
|
- msgTokens, err := CountTokenMessages(request.Messages, model, request.Stream)
|
|
|
+ msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
@@ -179,7 +182,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
|
|
|
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
|
|
}
|
|
|
}
|
|
|
- toolTokens, err := CountTokenInput(countStr, model)
|
|
|
+ toolTokens, err := CountTokenInput(countStr, request.Model)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
@@ -256,7 +259,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
|
|
return textToken, audioToken, nil
|
|
|
}
|
|
|
|
|
|
-func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
|
|
|
+func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
|
|
|
//recover when panic
|
|
|
tokenEncoder := getTokenEncoder(model)
|
|
|
// Reference:
|
|
|
@@ -290,7 +293,7 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
|
|
|
for _, m := range arrayContent {
|
|
|
if m.Type == dto.ContentTypeImageURL {
|
|
|
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
|
|
|
- imageTokenNum, err := getImageToken(&imageUrl, model, stream)
|
|
|
+ imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|