|
|
@@ -11,6 +11,7 @@ import (
|
|
|
"one-api/common"
|
|
|
"one-api/constant"
|
|
|
"one-api/dto"
|
|
|
+ relaycommon "one-api/relay/common"
|
|
|
"strings"
|
|
|
"unicode/utf8"
|
|
|
)
|
|
|
@@ -191,43 +192,55 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
|
|
|
return tkm, nil
|
|
|
}
|
|
|
|
|
|
-func CountTokenRealtime(request dto.RealtimeEvent, model string) (int, error) {
|
|
|
- tkm := 0
|
|
|
- ratio := 1
|
|
|
- if request.Session != nil {
|
|
|
- msgTokens, err := CountTokenText(request.Session.Instructions, model)
|
|
|
- if err != nil {
|
|
|
- return 0, err
|
|
|
- }
|
|
|
- ratio = len(request.Session.Modalities)
|
|
|
- tkm += msgTokens
|
|
|
- if request.Session.Tools != nil {
|
|
|
- toolsData, _ := json.Marshal(request.Session.Tools)
|
|
|
- var openaiTools []dto.OpenAITools
|
|
|
- err := json.Unmarshal(toolsData, &openaiTools)
|
|
|
+func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
|
|
|
+ audioToken := 0
|
|
|
+ textToken := 0
|
|
|
+ switch request.Type {
|
|
|
+ case dto.RealtimeEventTypeSessionUpdate:
|
|
|
+ if request.Session != nil {
|
|
|
+ msgTokens, err := CountTextToken(request.Session.Instructions, model)
|
|
|
if err != nil {
|
|
|
- return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
|
|
|
+ return 0, 0, err
|
|
|
}
|
|
|
- countStr := ""
|
|
|
- for _, tool := range openaiTools {
|
|
|
- countStr = tool.Function.Name
|
|
|
- if tool.Function.Description != "" {
|
|
|
- countStr += tool.Function.Description
|
|
|
- }
|
|
|
- if tool.Function.Parameters != nil {
|
|
|
- countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
|
|
+ textToken += msgTokens
|
|
|
+ }
|
|
|
+ case dto.RealtimeEventResponseAudioDelta:
|
|
|
+ // count audio token
|
|
|
+ atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
|
|
|
+ if err != nil {
|
|
|
+ return 0, 0, fmt.Errorf("error counting audio token: %v", err)
|
|
|
+ }
|
|
|
+ audioToken += atk
|
|
|
+ case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
|
|
|
+ // count text token
|
|
|
+ tkm, err := CountTextToken(request.Delta, model)
|
|
|
+ if err != nil {
|
|
|
+ return 0, 0, fmt.Errorf("error counting text token: %v", err)
|
|
|
+ }
|
|
|
+ textToken += tkm
|
|
|
+ case dto.RealtimeEventInputAudioBufferAppend:
|
|
|
+ // count audio token
|
|
|
+ atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
|
|
|
+ if err != nil {
|
|
|
+ return 0, 0, fmt.Errorf("error counting audio token: %v", err)
|
|
|
+ }
|
|
|
+ audioToken += atk
|
|
|
+ case dto.RealtimeEventTypeResponseDone:
|
|
|
+ // count tools token
|
|
|
+ if !info.IsFirstRequest {
|
|
|
+ if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
|
|
|
+ for _, tool := range info.RealtimeTools {
|
|
|
+ toolTokens, err := CountTokenInput(tool, model)
|
|
|
+ if err != nil {
|
|
|
+ return 0, 0, err
|
|
|
+ }
|
|
|
+ textToken += 8
|
|
|
+ textToken += toolTokens
|
|
|
}
|
|
|
}
|
|
|
- toolTokens, err := CountTokenInput(countStr, model)
|
|
|
- if err != nil {
|
|
|
- return 0, err
|
|
|
- }
|
|
|
- tkm += 8
|
|
|
- tkm += toolTokens
|
|
|
}
|
|
|
}
|
|
|
- tkm *= ratio
|
|
|
- return tkm, nil
|
|
|
+ return textToken, audioToken, nil
|
|
|
}
|
|
|
|
|
|
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
|
|
|
@@ -287,13 +300,13 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
|
|
|
func CountTokenInput(input any, model string) (int, error) {
|
|
|
switch v := input.(type) {
|
|
|
case string:
|
|
|
- return CountTokenText(v, model)
|
|
|
+ return CountTextToken(v, model)
|
|
|
case []string:
|
|
|
text := ""
|
|
|
for _, s := range v {
|
|
|
text += s
|
|
|
}
|
|
|
- return CountTokenText(text, model)
|
|
|
+ return CountTextToken(text, model)
|
|
|
}
|
|
|
return CountTokenInput(fmt.Sprintf("%v", input), model)
|
|
|
}
|
|
|
@@ -315,16 +328,44 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
|
|
|
return tokens
|
|
|
}
|
|
|
|
|
|
-func CountAudioToken(text string, model string) (int, error) {
|
|
|
+func CountTTSToken(text string, model string) (int, error) {
|
|
|
if strings.HasPrefix(model, "tts") {
|
|
|
return utf8.RuneCountInString(text), nil
|
|
|
} else {
|
|
|
- return CountTokenText(text, model)
|
|
|
+ return CountTextToken(text, model)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
|
|
|
+ if audioBase64 == "" {
|
|
|
+ return 0, nil
|
|
|
+ }
|
|
|
+ duration, err := parseAudio(audioBase64, audioFormat)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
}
|
|
|
+ return int(duration / 60 * 100 / 0.06), nil
|
|
|
}
|
|
|
|
|
|
-// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
|
|
-func CountTokenText(text string, model string) (int, error) {
|
|
|
+func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
|
|
|
+ if audioBase64 == "" {
|
|
|
+ return 0, nil
|
|
|
+ }
|
|
|
+ duration, err := parseAudio(audioBase64, audioFormat)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ return int(duration / 60 * 200 / 0.24), nil
|
|
|
+}
|
|
|
+
|
|
|
+//func CountAudioToken(sec float64, audioType string) {
|
|
|
+// if audioType == "input" {
|
|
|
+//
|
|
|
+// }
|
|
|
+//}
|
|
|
+
|
|
|
+// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
|
|
+func CountTextToken(text string, model string) (int, error) {
|
|
|
var err error
|
|
|
tokenEncoder := getTokenEncoder(model)
|
|
|
return getTokenNum(tokenEncoder, text), err
|