Просмотр исходного кода

fix(openai): account cached tokens for
zhipu_v4 usage

RedwindA 2 месяцев назад
Родитель
Коммит
f930cdbb51
2 изменённых файла: 61 добавление и 6 удалений
  1. 60 6
      relay/channel/openai/relay-openai.go
  2. 1 0
      relay/common/relay_info.go

+ 60 - 6
relay/channel/openai/relay-openai.go

@@ -163,13 +163,10 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 	if !containStreamUsage {
 		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
 		usage.CompletionTokens += toolCount * 7
-	} else {
-		if info.ChannelType == constant.ChannelTypeDeepSeek {
-			if usage.PromptCacheHitTokens != 0 {
-				usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
-			}
-		}
 	}
+
+	applyUsagePostProcessing(info, usage, nil)
+
 	HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
 
 	return usage, nil
@@ -233,6 +230,8 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 		usageModified = true
 	}
 
+	applyUsagePostProcessing(info, &simpleResponse.Usage, responseBody)
+
 	switch info.RelayFormat {
 	case types.RelayFormatOpenAI:
 		if usageModified {
@@ -631,5 +630,60 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 		usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
 		usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
 	}
+	applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
 	return &usageResp.Usage, nil
 }
+
+// applyUsagePostProcessing backfills usage.PromptTokensDetails.CachedTokens for
+// channel types that report cached prompt tokens in other fields.
+//
+// It does nothing when info or usage is nil, when the cached-token count is
+// already non-zero, or when the channel type is neither DeepSeek nor Zhipu v4.
+// responseBody may be nil (the streaming handler passes nil); it is only
+// consulted for the Zhipu v4 channel.
+func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
+	if info == nil || usage == nil {
+		return
+	}
+	// Every branch below only ever fills an empty count, so bail out once.
+	if usage.PromptTokensDetails.CachedTokens != 0 {
+		return
+	}
+
+	switch info.ChannelType {
+	case constant.ChannelTypeDeepSeek:
+		if usage.PromptCacheHitTokens != 0 {
+			usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
+		}
+
+	case constant.ChannelTypeZhipu_v4:
+		// Sources are tried in order: parsed input-token details, then the raw
+		// response body, then the prompt-cache-hit counter.
+		if input := usage.InputTokensDetails; input != nil && input.CachedTokens > 0 {
+			usage.PromptTokensDetails.CachedTokens = input.CachedTokens
+			return
+		}
+		if fromBody, found := extractCachedTokensFromBody(responseBody); found {
+			usage.PromptTokensDetails.CachedTokens = fromBody
+			return
+		}
+		if usage.PromptCacheHitTokens > 0 {
+			usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
+		}
+	}
+}
+
+// extractCachedTokensFromBody reads a cached-token count from the "usage"
+// object of a JSON response body.
+//
+// Three fields are considered, in priority order:
+//
+//  1. usage.prompt_tokens_details.cached_tokens
+//  2. usage.cached_tokens
+//  3. usage.prompt_cache_hit_tokens
+//
+// The first field holding a positive value wins, so an explicit zero in a
+// higher-priority field no longer hides a real count reported by a
+// lower-priority one. When no field is positive but at least one is present,
+// the first present value is returned with ok == true. The result is
+// (0, false) when body is empty, cannot be unmarshalled, or contains none of
+// the fields.
+func extractCachedTokensFromBody(body []byte) (int, bool) {
+	if len(body) == 0 {
+		return 0, false
+	}
+
+	// Pointer fields distinguish "absent" from "present and zero".
+	var payload struct {
+		Usage struct {
+			PromptTokensDetails struct {
+				CachedTokens *int `json:"cached_tokens"`
+			} `json:"prompt_tokens_details"`
+			CachedTokens         *int `json:"cached_tokens"`
+			PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"`
+		} `json:"usage"`
+	}
+
+	if err := json.Unmarshal(body, &payload); err != nil {
+		return 0, false
+	}
+
+	candidates := [...]*int{
+		payload.Usage.PromptTokensDetails.CachedTokens,
+		payload.Usage.CachedTokens,
+		payload.Usage.PromptCacheHitTokens,
+	}
+
+	var firstPresent *int
+	for _, candidate := range candidates {
+		if candidate == nil {
+			continue
+		}
+		if *candidate > 0 {
+			return *candidate, true
+		}
+		if firstPresent == nil {
+			firstPresent = candidate
+		}
+	}
+
+	if firstPresent != nil {
+		return *firstPresent, true
+	}
+	return 0, false
+}

+ 1 - 0
relay/common/relay_info.go

@@ -261,6 +261,7 @@ var streamSupportedChannels = map[int]bool{
 	constant.ChannelTypeXai:        true,
 	constant.ChannelTypeDeepSeek:   true,
 	constant.ChannelTypeBaiduV2:    true,
+	constant.ChannelTypeZhipu_v4:   true,
 }
 
 func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {