ソースを参照

refactor: Improve quota calculation precision using floating-point arithmetic

[email protected] 9 ヶ月 前
コミット
bb848b2fe0
2 ファイル変更、16 行追加、14 行削除
  1. relay/relay-text.go (+9 −8)
  2. service/quota.go (+7 −6)

+ 9 - 8
relay/relay-text.go

@@ -320,19 +320,20 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	groupRatio := priceData.GroupRatio
 	modelPrice := priceData.ModelPrice
 
-	quota := 0
+	quotaCalculate := 0.0
 	if !priceData.UsePrice {
-		quota = (promptTokens - cacheTokens) + int(math.Round(float64(cacheTokens)*cacheRatio))
-		quota += int(math.Round(float64(completionTokens) * completionRatio))
-		quota = int(math.Round(float64(quota) * ratio))
-		if ratio != 0 && quota <= 0 {
-			quota = 1
+		quotaCalculate = float64(promptTokens-cacheTokens) + float64(cacheTokens)*cacheRatio
+		quotaCalculate += float64(completionTokens) * completionRatio
+		quotaCalculate = quotaCalculate * ratio
+		if ratio != 0 && quotaCalculate <= 0 {
+			quotaCalculate = 1
 		}
 	} else {
-		quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+		quotaCalculate = modelPrice * common.QuotaPerUnit * groupRatio
 	}
+	quota := int(quotaCalculate)
 	totalTokens := promptTokens + completionTokens
-	
+
 	var logContent string
 	if !priceData.UsePrice {
 		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio)

+ 7 - 6
service/quota.go

@@ -4,7 +4,6 @@ import (
 	"errors"
 	"fmt"
 	"github.com/bytedance/gopkg/util/gopool"
-	"math"
 	"one-api/common"
 	constant2 "one-api/constant"
 	"one-api/dto"
@@ -44,16 +43,18 @@ func calculateAudioQuota(info QuotaInfo) int {
 	audioCompletionRatio := operation_setting.GetAudioCompletionRatio(info.ModelName)
 	ratio := info.GroupRatio * info.ModelRatio
 
-	quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio))
-	quota += int(math.Round(float64(info.InputDetails.AudioTokens)*audioRatio)) +
-		int(math.Round(float64(info.OutputDetails.AudioTokens)*audioRatio*audioCompletionRatio))
+	quota := 0.0
+	quota += float64(info.InputDetails.TextTokens)
+	quota += float64(info.OutputDetails.TextTokens) * completionRatio
+	quota += float64(info.InputDetails.AudioTokens) * audioRatio
+	quota += float64(info.OutputDetails.AudioTokens) * audioRatio * audioCompletionRatio
 
-	quota = int(math.Round(float64(quota) * ratio))
+	quota = quota * ratio
 	if ratio != 0 && quota <= 0 {
 		quota = 1
 	}
 
-	return quota
+	return int(quota)
 }
 
 func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {