Просмотр исходного кода

fix: prompt calculation

Users now get a correctly estimated prompt-token count when the upstream response reports zero prompt tokens or omits usage data altogether.

funnycups 4 месяцев назад
Родитель
Коммит
e3473e3c39
2 изменённых файла: 12 добавлений и 6 удалений
  1. 2 0
      controller/relay.go
  2. 10 6
      relay/channel/openai/relay-openai.go

+ 2 - 0
controller/relay.go

@@ -128,6 +128,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 		return
 	}
 
+	relayInfo.SetPromptTokens(tokens)
+
 	priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
 	if err != nil {
 		newAPIError = types.NewError(err, types.ErrorCodeModelPriceError)

+ 10 - 6
relay/channel/openai/relay-openai.go

@@ -197,22 +197,26 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 		forceFormat = true
 	}
 
-	if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
-		completionTokens := 0
-		for _, choice := range simpleResponse.Choices {
-			ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
-			completionTokens += ctkm
+	usageModified := false
+	if simpleResponse.Usage.PromptTokens == 0 {
+		completionTokens := simpleResponse.Usage.CompletionTokens
+		if completionTokens == 0 {
+			for _, choice := range simpleResponse.Choices {
+				ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
+				completionTokens += ctkm
+			}
 		}
 		simpleResponse.Usage = dto.Usage{
 			PromptTokens:     info.PromptTokens,
 			CompletionTokens: completionTokens,
 			TotalTokens:      info.PromptTokens + completionTokens,
 		}
+		usageModified = true
 	}
 
 	switch info.RelayFormat {
 	case types.RelayFormatOpenAI:
-		if forceFormat {
+		if forceFormat || usageModified {
 			responseBody, err = common.Marshal(simpleResponse)
 			if err != nil {
 				return nil, types.NewError(err, types.ErrorCodeBadResponseBody)