Browse Source

fix: openai 格式请求 claude 没计费 create cache token

Xyfacai 3 tháng trước cách đây
mục cha
commit
fcdfd027cd

+ 1 - 1
controller/channel-test.go

@@ -235,7 +235,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 	if resp != nil {
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
-			err := service.RelayErrorHandler(httpResp, true)
+			err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
 			return testResult{
 			return testResult{
 				context:     c,
 				context:     c,
 				localErr:    err,
 				localErr:    err,

+ 1 - 1
relay/audio_handler.go

@@ -53,7 +53,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 	if resp != nil {
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
-			newAPIError = service.RelayErrorHandler(httpResp, false)
+			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			return newAPIError
 			return newAPIError

+ 1 - 1
relay/claude_handler.go

@@ -111,7 +111,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
 		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
-			newAPIError = service.RelayErrorHandler(httpResp, false)
+			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			return newAPIError
 			return newAPIError

+ 18 - 2
relay/compatible_handler.go

@@ -158,7 +158,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
 		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
-			newApiErr := service.RelayErrorHandler(httpResp, false)
+			newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
 			service.ResetStatusCode(newApiErr, statusCodeMappingStr)
 			service.ResetStatusCode(newApiErr, statusCodeMappingStr)
 			return newApiErr
 			return newApiErr
@@ -195,6 +195,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 	imageTokens := usage.PromptTokensDetails.ImageTokens
 	imageTokens := usage.PromptTokensDetails.ImageTokens
 	audioTokens := usage.PromptTokensDetails.AudioTokens
 	audioTokens := usage.PromptTokensDetails.AudioTokens
 	completionTokens := usage.CompletionTokens
 	completionTokens := usage.CompletionTokens
+	cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
+
 	modelName := relayInfo.OriginModelName
 	modelName := relayInfo.OriginModelName
 
 
 	tokenName := ctx.GetString("token_name")
 	tokenName := ctx.GetString("token_name")
@@ -204,6 +206,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 	modelRatio := relayInfo.PriceData.ModelRatio
 	modelRatio := relayInfo.PriceData.ModelRatio
 	groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
 	groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
 	modelPrice := relayInfo.PriceData.ModelPrice
 	modelPrice := relayInfo.PriceData.ModelPrice
+	cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio
 
 
 	// Convert values to decimal for precise calculation
 	// Convert values to decimal for precise calculation
 	dPromptTokens := decimal.NewFromInt(int64(promptTokens))
 	dPromptTokens := decimal.NewFromInt(int64(promptTokens))
@@ -211,12 +214,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 	dImageTokens := decimal.NewFromInt(int64(imageTokens))
 	dImageTokens := decimal.NewFromInt(int64(imageTokens))
 	dAudioTokens := decimal.NewFromInt(int64(audioTokens))
 	dAudioTokens := decimal.NewFromInt(int64(audioTokens))
 	dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
 	dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
+	dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens))
 	dCompletionRatio := decimal.NewFromFloat(completionRatio)
 	dCompletionRatio := decimal.NewFromFloat(completionRatio)
 	dCacheRatio := decimal.NewFromFloat(cacheRatio)
 	dCacheRatio := decimal.NewFromFloat(cacheRatio)
 	dImageRatio := decimal.NewFromFloat(imageRatio)
 	dImageRatio := decimal.NewFromFloat(imageRatio)
 	dModelRatio := decimal.NewFromFloat(modelRatio)
 	dModelRatio := decimal.NewFromFloat(modelRatio)
 	dGroupRatio := decimal.NewFromFloat(groupRatio)
 	dGroupRatio := decimal.NewFromFloat(groupRatio)
 	dModelPrice := decimal.NewFromFloat(modelPrice)
 	dModelPrice := decimal.NewFromFloat(modelPrice)
+	dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio)
 	dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
 	dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
 
 
 	ratio := dModelRatio.Mul(dGroupRatio)
 	ratio := dModelRatio.Mul(dGroupRatio)
@@ -284,6 +289,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 			baseTokens = baseTokens.Sub(dCacheTokens)
 			baseTokens = baseTokens.Sub(dCacheTokens)
 			cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
 			cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
 		}
 		}
+		var dCachedCreationTokensWithRatio decimal.Decimal
+		if !dCachedCreationTokens.IsZero() {
+			baseTokens = baseTokens.Sub(dCachedCreationTokens)
+			dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
+		}
 
 
 		// 减去 image tokens
 		// 减去 image tokens
 		var imageTokensWithRatio decimal.Decimal
 		var imageTokensWithRatio decimal.Decimal
@@ -302,7 +312,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 				extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
 				extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
 			}
 			}
 		}
 		}
-		promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio)
+		promptQuota := baseTokens.Add(cachedTokensWithRatio).
+			Add(imageTokensWithRatio).
+			Add(dCachedCreationTokensWithRatio)
 
 
 		completionQuota := dCompletionTokens.Mul(dCompletionRatio)
 		completionQuota := dCompletionTokens.Mul(dCompletionRatio)
 
 
@@ -395,6 +407,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 		other["image_ratio"] = imageRatio
 		other["image_ratio"] = imageRatio
 		other["image_output"] = imageTokens
 		other["image_output"] = imageTokens
 	}
 	}
+	if cachedCreationTokens != 0 {
+		other["cache_creation_tokens"] = cachedCreationTokens
+		other["cache_creation_ratio"] = cachedCreationRatio
+	}
 	if !dWebSearchQuota.IsZero() {
 	if !dWebSearchQuota.IsZero() {
 		if relayInfo.ResponsesUsageInfo != nil {
 		if relayInfo.ResponsesUsageInfo != nil {
 			if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
 			if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {

+ 1 - 1
relay/embedding_handler.go

@@ -58,7 +58,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 	if resp != nil {
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
-			newAPIError = service.RelayErrorHandler(httpResp, false)
+			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			return newAPIError
 			return newAPIError

+ 2 - 2
relay/gemini_handler.go

@@ -152,7 +152,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
 		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
-			newAPIError = service.RelayErrorHandler(httpResp, false)
+			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			return newAPIError
 			return newAPIError
@@ -249,7 +249,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
 	if resp != nil {
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
-			newAPIError = service.RelayErrorHandler(httpResp, false)
+			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			return newAPIError
 			return newAPIError
 		}
 		}

+ 1 - 1
relay/image_handler.go

@@ -91,7 +91,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
 		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
-			newAPIError = service.RelayErrorHandler(httpResp, false)
+			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			return newAPIError
 			return newAPIError

+ 1 - 1
relay/rerank_handler.go

@@ -81,7 +81,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 	if resp != nil {
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
-			newAPIError = service.RelayErrorHandler(httpResp, false)
+			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			return newAPIError
 			return newAPIError

+ 1 - 1
relay/responses_handler.go

@@ -82,7 +82,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
 
 
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
-			newAPIError = service.RelayErrorHandler(httpResp, false)
+			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			return newAPIError
 			return newAPIError

+ 4 - 2
service/error.go

@@ -1,12 +1,14 @@
 package service
 package service
 
 
 import (
 import (
+	"context"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/types"
 	"one-api/types"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
@@ -78,7 +80,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
 	return claudeErr
 	return claudeErr
 }
 }
 
 
-func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
+func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
 	newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
 	newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
 
 
 	responseBody, err := io.ReadAll(resp.Body)
 	responseBody, err := io.ReadAll(resp.Body)
@@ -94,7 +96,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
 			newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
 			newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
 		} else {
 		} else {
 			if common.DebugEnabled {
 			if common.DebugEnabled {
-				println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
+				logger.LogInfo(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
 			}
 			}
 			newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
 			newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
 		}
 		}