Răsfoiți Sursa

feat: Improve decimal precision for quota and payment calculations

- Added github.com/shopspring/decimal for arbitrary-precision decimal arithmetic in place of float64 calculations
- Refactored quota and payment calculations in multiple files to use decimal arithmetic
- Updated go.mod and go.sum to include the decimal library (and to remove github.com/jinzhu/copier)
- Improved precision in topup, relay, and quota service calculations
- Added support for more OpenAI model variants in cache ratio settings
[email protected] 10 luni în urmă
părinte
comite
68097c132d
6 a modificat fișierele cu 111 adăugiri și 56 ștergeri
  1. 28 11
      controller/topup.go
  2. 1 1
      go.mod
  3. 2 2
      go.sum
  4. 29 17
      relay/relay-text.go
  5. 46 25
      service/quota.go
  6. 5 0
      setting/operation_setting/cache_ratio.go

+ 28 - 11
controller/topup.go

@@ -2,9 +2,6 @@ package controller
 
 import (
 	"fmt"
-	"github.com/Calcium-Ion/go-epay/epay"
-	"github.com/gin-gonic/gin"
-	"github.com/samber/lo"
 	"log"
 	"net/url"
 	"one-api/common"
@@ -14,6 +11,11 @@ import (
 	"strconv"
 	"sync"
 	"time"
+
+	"github.com/Calcium-Ion/go-epay/epay"
+	"github.com/gin-gonic/gin"
+	"github.com/samber/lo"
+	"github.com/shopspring/decimal"
 )
 
 type EpayRequest struct {
@@ -42,22 +44,32 @@ func GetEpayClient() *epay.Client {
 }
 
 func getPayMoney(amount float64, group string) float64 {
+	dAmount := decimal.NewFromFloat(amount)
+
 	if !common.DisplayInCurrencyEnabled {
-		amount = amount / common.QuotaPerUnit
+		dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
+		dAmount = dAmount.Div(dQuotaPerUnit)
 	}
-	// 别问为什么用float64,问就是这么点钱没必要
+
 	topupGroupRatio := common.GetTopupGroupRatio(group)
 	if topupGroupRatio == 0 {
 		topupGroupRatio = 1
 	}
-	payMoney := amount * setting.Price * topupGroupRatio
-	return payMoney
+
+	dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio)
+	dPrice := decimal.NewFromFloat(setting.Price)
+
+	payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio)
+
+	return payMoney.InexactFloat64()
 }
 
 func getMinTopup() int {
 	minTopup := setting.MinTopUp
 	if !common.DisplayInCurrencyEnabled {
-		minTopup = minTopup * int(common.QuotaPerUnit)
+		dMinTopup := decimal.NewFromInt(int64(minTopup))
+		dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
+		minTopup = int(dMinTopup.Mul(dQuotaPerUnit).IntPart())
 	}
 	return minTopup
 }
@@ -118,7 +130,9 @@ func RequestEpay(c *gin.Context) {
 	}
 	amount := req.Amount
 	if !common.DisplayInCurrencyEnabled {
-		amount = amount / int(common.QuotaPerUnit)
+		dAmount := decimal.NewFromInt(int64(amount))
+		dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
+		amount = int(dAmount.Div(dQuotaPerUnit).IntPart())
 	}
 	topUp := &model.TopUp{
 		UserId:     id,
@@ -210,13 +224,16 @@ func EpayNotify(c *gin.Context) {
 			}
 			//user, _ := model.GetUserById(topUp.UserId, false)
 			//user.Quota += topUp.Amount * 500000
-			err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit), true)
+			dAmount := decimal.NewFromInt(int64(topUp.Amount))
+			dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
+			quotaToAdd := int(dAmount.Mul(dQuotaPerUnit).IntPart())
+			err = model.IncreaseUserQuota(topUp.UserId, quotaToAdd, true)
 			if err != nil {
 				log.Printf("易支付回调更新用户失败: %v", topUp)
 				return
 			}
 			log.Printf("易支付回调更新用户成功 %v", topUp)
-			model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(topUp.Amount*int(common.QuotaPerUnit)), topUp.Money))
+			model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(quotaToAdd), topUp.Money))
 		}
 	} else {
 		log.Printf("易支付异常回调: %v", verifyInfo)

+ 1 - 1
go.mod

@@ -22,12 +22,12 @@ require (
 	github.com/golang-jwt/jwt v3.2.2+incompatible
 	github.com/google/uuid v1.6.0
 	github.com/gorilla/websocket v1.5.0
-	github.com/jinzhu/copier v0.4.0
 	github.com/joho/godotenv v1.5.1
 	github.com/pkg/errors v0.9.1
 	github.com/pkoukk/tiktoken-go v0.1.7
 	github.com/samber/lo v1.39.0
 	github.com/shirou/gopsutil v3.21.11+incompatible
+	github.com/shopspring/decimal v1.4.0
 	golang.org/x/crypto v0.27.0
 	golang.org/x/image v0.23.0
 	golang.org/x/net v0.28.0

+ 2 - 2
go.sum

@@ -117,8 +117,6 @@ github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs=
 github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
 github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
 github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
-github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
-github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
 github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
 github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
 github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
@@ -183,6 +181,8 @@ github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
 github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
 github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
 github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
+github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
+github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
 github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=

+ 29 - 17
relay/relay-text.go

@@ -5,7 +5,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/bytedance/gopkg/util/gopool"
 	"io"
 	"math"
 	"net/http"
@@ -21,6 +20,9 @@ import (
 	"strings"
 	"time"
 
+	"github.com/bytedance/gopkg/util/gopool"
+	"github.com/shopspring/decimal"
+
 	"github.com/gin-gonic/gin"
 )
 
@@ -315,23 +317,40 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	tokenName := ctx.GetString("token_name")
 	completionRatio := priceData.CompletionRatio
 	cacheRatio := priceData.CacheRatio
-	ratio := priceData.ModelRatio * priceData.GroupRatio
 	modelRatio := priceData.ModelRatio
 	groupRatio := priceData.GroupRatio
 	modelPrice := priceData.ModelPrice
 
-	quotaCalculate := 0.0
+	// Convert values to decimal for precise calculation
+	dPromptTokens := decimal.NewFromInt(int64(promptTokens))
+	dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
+	dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
+	dCompletionRatio := decimal.NewFromFloat(completionRatio)
+	dCacheRatio := decimal.NewFromFloat(cacheRatio)
+	dModelRatio := decimal.NewFromFloat(modelRatio)
+	dGroupRatio := decimal.NewFromFloat(groupRatio)
+	dModelPrice := decimal.NewFromFloat(modelPrice)
+	dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
+
+	ratio := dModelRatio.Mul(dGroupRatio)
+
+	var quotaCalculateDecimal decimal.Decimal
 	if !priceData.UsePrice {
-		quotaCalculate = float64(promptTokens-cacheTokens) + float64(cacheTokens)*cacheRatio
-		quotaCalculate += float64(completionTokens) * completionRatio
-		quotaCalculate = quotaCalculate * ratio
-		if ratio != 0 && quotaCalculate <= 0 {
-			quotaCalculate = 1
+		nonCachedTokens := dPromptTokens.Sub(dCacheTokens)
+		cachedTokensWithRatio := dCacheTokens.Mul(dCacheRatio)
+		promptQuota := nonCachedTokens.Add(cachedTokensWithRatio)
+		completionQuota := dCompletionTokens.Mul(dCompletionRatio)
+
+		quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)
+
+		if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) {
+			quotaCalculateDecimal = decimal.NewFromInt(1)
 		}
 	} else {
-		quotaCalculate = modelPrice * common.QuotaPerUnit * groupRatio
+		quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
 	}
-	quota := int(quotaCalculate)
+
+	quota := int(quotaCalculateDecimal.Round(0).IntPart())
 	totalTokens := promptTokens + completionTokens
 
 	var logContent string
@@ -350,9 +369,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
 			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
 	} else {
-		//if sensitiveResp != nil {
-		//	logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
-		//}
 		quotaDelta := quota - preConsumedQuota
 		if quotaDelta != 0 {
 			err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
@@ -379,8 +395,4 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice)
 	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
 		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
-
-	//if quota != 0 {
-	//
-	//}
 }

+ 46 - 25
service/quota.go

@@ -3,7 +3,6 @@ package service
 import (
 	"errors"
 	"fmt"
-	"github.com/bytedance/gopkg/util/gopool"
 	"one-api/common"
 	constant2 "one-api/constant"
 	"one-api/dto"
@@ -15,7 +14,10 @@ import (
 	"strings"
 	"time"
 
+	"github.com/bytedance/gopkg/util/gopool"
+
 	"github.com/gin-gonic/gin"
+	"github.com/shopspring/decimal"
 )
 
 type TokenDetails struct {
@@ -35,26 +37,41 @@ type QuotaInfo struct {
 
 func calculateAudioQuota(info QuotaInfo) int {
 	if info.UsePrice {
-		return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio)
+		modelPrice := decimal.NewFromFloat(info.ModelPrice)
+		quotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
+		groupRatio := decimal.NewFromFloat(info.GroupRatio)
+
+		quota := modelPrice.Mul(quotaPerUnit).Mul(groupRatio)
+		return int(quota.IntPart())
 	}
 
-	completionRatio := operation_setting.GetCompletionRatio(info.ModelName)
-	audioRatio := operation_setting.GetAudioRatio(info.ModelName)
-	audioCompletionRatio := operation_setting.GetAudioCompletionRatio(info.ModelName)
-	ratio := info.GroupRatio * info.ModelRatio
+	completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(info.ModelName))
+	audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(info.ModelName))
+	audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(info.ModelName))
+
+	groupRatio := decimal.NewFromFloat(info.GroupRatio)
+	modelRatio := decimal.NewFromFloat(info.ModelRatio)
+	ratio := groupRatio.Mul(modelRatio)
+
+	inputTextTokens := decimal.NewFromInt(int64(info.InputDetails.TextTokens))
+	outputTextTokens := decimal.NewFromInt(int64(info.OutputDetails.TextTokens))
+	inputAudioTokens := decimal.NewFromInt(int64(info.InputDetails.AudioTokens))
+	outputAudioTokens := decimal.NewFromInt(int64(info.OutputDetails.AudioTokens))
+
+	quota := decimal.Zero
+	quota = quota.Add(inputTextTokens)
+	quota = quota.Add(outputTextTokens.Mul(completionRatio))
+	quota = quota.Add(inputAudioTokens.Mul(audioRatio))
+	quota = quota.Add(outputAudioTokens.Mul(audioRatio).Mul(audioCompletionRatio))
 
-	quota := 0.0
-	quota += float64(info.InputDetails.TextTokens)
-	quota += float64(info.OutputDetails.TextTokens) * completionRatio
-	quota += float64(info.InputDetails.AudioTokens) * audioRatio
-	quota += float64(info.OutputDetails.AudioTokens) * audioRatio * audioCompletionRatio
+	quota = quota.Mul(ratio)
 
-	quota = quota * ratio
-	if ratio != 0 && quota <= 0 {
-		quota = 1
+	// If ratio is not zero and quota is less than or equal to zero, set quota to 1
+	if !ratio.IsZero() && quota.LessThanOrEqual(decimal.Zero) {
+		quota = decimal.NewFromInt(1)
 	}
 
-	return int(quota)
+	return int(quota.Round(0).IntPart())
 }
 
 func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
@@ -124,9 +141,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 	audioOutTokens := usage.OutputTokenDetails.AudioTokens
 
 	tokenName := ctx.GetString("token_name")
-	completionRatio := operation_setting.GetCompletionRatio(modelName)
-	audioRatio := operation_setting.GetAudioRatio(relayInfo.OriginModelName)
-	audioCompletionRatio := operation_setting.GetAudioCompletionRatio(modelName)
+	completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(modelName))
+	audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
+	audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName))
 
 	quotaInfo := QuotaInfo{
 		InputDetails: TokenDetails{
@@ -148,7 +165,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 	totalTokens := usage.TotalTokens
 	var logContent string
 	if !usePrice {
-		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
+		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f",
+			modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio)
 	} else {
 		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
 	}
@@ -170,7 +188,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 	if extraContent != "" {
 		logContent += ", " + extraContent
 	}
-	other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
+	other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
+		completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice)
 	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
 		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
 }
@@ -186,9 +205,9 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	audioOutTokens := usage.CompletionTokenDetails.AudioTokens
 
 	tokenName := ctx.GetString("token_name")
-	completionRatio := operation_setting.GetCompletionRatio(relayInfo.OriginModelName)
-	audioRatio := operation_setting.GetAudioRatio(relayInfo.OriginModelName)
-	audioCompletionRatio := operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName)
+	completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(relayInfo.OriginModelName))
+	audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
+	audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
 
 	modelRatio := priceData.ModelRatio
 	groupRatio := priceData.GroupRatio
@@ -215,7 +234,8 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	totalTokens := usage.TotalTokens
 	var logContent string
 	if !usePrice {
-		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
+		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f",
+			modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio)
 	} else {
 		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
 	}
@@ -244,7 +264,8 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	if extraContent != "" {
 		logContent += ", " + extraContent
 	}
-	other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
+	other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
+		completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice)
 	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
 		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
 }

+ 5 - 0
setting/operation_setting/cache_ratio.go

@@ -8,12 +8,17 @@ import (
 
 var defaultCacheRatio = map[string]float64{
 	"gpt-4":                        0.5,
+	"o1":                           0.5,
 	"o1-2024-12-17":                0.5,
 	"o1-preview-2024-09-12":        0.5,
+	"o1-preview":                   0.5,
 	"o1-mini-2024-09-12":           0.5,
+	"o1-mini":                      0.5,
 	"gpt-4o-2024-11-20":            0.5,
 	"gpt-4o-2024-08-06":            0.5,
+	"gpt-4o":                       0.5,
 	"gpt-4o-mini-2024-07-18":       0.5,
+	"gpt-4o-mini":                  0.5,
 	"gpt-4o-realtime-preview":      0.5,
 	"gpt-4o-mini-realtime-preview": 0.5,
 	"deepseek-chat":                0.1,