@@ -163,13 +163,10 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 	if !containStreamUsage {
 		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
 		usage.CompletionTokens += toolCount * 7
-	} else {
-		if info.ChannelType == constant.ChannelTypeDeepSeek {
-			if usage.PromptCacheHitTokens != 0 {
-				usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
-			}
-		}
 	}
+
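+	// No aggregated response body is available on the streaming path, so pass nil.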
+	applyUsagePostProcessing(info, usage, nil)
+
 	HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
 
 	return usage, nil
@@ -233,6 +230,8 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 		usageModified = true
 	}
 
+	applyUsagePostProcessing(info, &simpleResponse.Usage, responseBody)
+
 	switch info.RelayFormat {
 	case types.RelayFormatOpenAI:
 		if usageModified {
@@ -631,5 +630,60 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 		usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
 		usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
 	}
+	applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
 	return &usageResp.Usage, nil
 }
+
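+// applyUsagePostProcessing copies channel-specific cache-hit counters (DeepSeek,
+// Zhipu v4) into usage.PromptTokensDetails.CachedTokens when it is still zero.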
+func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
+	if info == nil || usage == nil {
+		return
+	}
+
+	switch info.ChannelType {
+	case constant.ChannelTypeDeepSeek:
+		if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
+			usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
+		}
+	case constant.ChannelTypeZhipu_v4:
+		if usage.PromptTokensDetails.CachedTokens == 0 {
+			if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
+				usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
+			} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
+				usage.PromptTokensDetails.CachedTokens = cachedTokens
+			} else if usage.PromptCacheHitTokens > 0 {
+				usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
+			}
+		}
+	}
+}
+
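+// extractCachedTokensFromBody parses the raw response body and returns the first of
+// usage.prompt_tokens_details.cached_tokens, usage.cached_tokens, or
+// usage.prompt_cache_hit_tokens that is present.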
+func extractCachedTokensFromBody(body []byte) (int, bool) {
+	if len(body) == 0 {
+		return 0, false
+	}
+
+	var payload struct {
+		Usage struct {
+			PromptTokensDetails struct {
+				CachedTokens *int `json:"cached_tokens"`
+			} `json:"prompt_tokens_details"`
+			CachedTokens         *int `json:"cached_tokens"`
+			PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"`
+		} `json:"usage"`
+	}
+
+	if err := json.Unmarshal(body, &payload); err != nil {
+		return 0, false
+	}
+
+	if payload.Usage.PromptTokensDetails.CachedTokens != nil {
+		return *payload.Usage.PromptTokensDetails.CachedTokens, true
+	}
+	if payload.Usage.CachedTokens != nil {
+		return *payload.Usage.CachedTokens, true
+	}
+	if payload.Usage.PromptCacheHitTokens != nil {
+		return *payload.Usage.PromptCacheHitTokens, true
+	}
+	return 0, false
+}