|
|
@@ -3,7 +3,6 @@ package service
|
|
|
import (
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
- "github.com/bytedance/gopkg/util/gopool"
|
|
|
"one-api/common"
|
|
|
constant2 "one-api/constant"
|
|
|
"one-api/dto"
|
|
|
@@ -15,7 +14,10 @@ import (
|
|
|
"strings"
|
|
|
"time"
|
|
|
|
|
|
+ "github.com/bytedance/gopkg/util/gopool"
|
|
|
+
|
|
|
"github.com/gin-gonic/gin"
|
|
|
+ "github.com/shopspring/decimal"
|
|
|
)
|
|
|
|
|
|
type TokenDetails struct {
|
|
|
@@ -35,26 +37,41 @@ type QuotaInfo struct {
|
|
|
|
|
|
func calculateAudioQuota(info QuotaInfo) int {
|
|
|
if info.UsePrice {
|
|
|
- return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio)
|
|
|
+ modelPrice := decimal.NewFromFloat(info.ModelPrice)
|
|
|
+ quotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
|
|
+ groupRatio := decimal.NewFromFloat(info.GroupRatio)
|
|
|
+
|
|
|
+ quota := modelPrice.Mul(quotaPerUnit).Mul(groupRatio)
|
|
|
+ return int(quota.IntPart())
|
|
|
}
|
|
|
|
|
|
- completionRatio := operation_setting.GetCompletionRatio(info.ModelName)
|
|
|
- audioRatio := operation_setting.GetAudioRatio(info.ModelName)
|
|
|
- audioCompletionRatio := operation_setting.GetAudioCompletionRatio(info.ModelName)
|
|
|
- ratio := info.GroupRatio * info.ModelRatio
|
|
|
+ completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(info.ModelName))
|
|
|
+ audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(info.ModelName))
|
|
|
+ audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(info.ModelName))
|
|
|
+
|
|
|
+ groupRatio := decimal.NewFromFloat(info.GroupRatio)
|
|
|
+ modelRatio := decimal.NewFromFloat(info.ModelRatio)
|
|
|
+ ratio := groupRatio.Mul(modelRatio)
|
|
|
+
|
|
|
+ inputTextTokens := decimal.NewFromInt(int64(info.InputDetails.TextTokens))
|
|
|
+ outputTextTokens := decimal.NewFromInt(int64(info.OutputDetails.TextTokens))
|
|
|
+ inputAudioTokens := decimal.NewFromInt(int64(info.InputDetails.AudioTokens))
|
|
|
+ outputAudioTokens := decimal.NewFromInt(int64(info.OutputDetails.AudioTokens))
|
|
|
+
|
|
|
+ quota := decimal.Zero
|
|
|
+ quota = quota.Add(inputTextTokens)
|
|
|
+ quota = quota.Add(outputTextTokens.Mul(completionRatio))
|
|
|
+ quota = quota.Add(inputAudioTokens.Mul(audioRatio))
|
|
|
+ quota = quota.Add(outputAudioTokens.Mul(audioRatio).Mul(audioCompletionRatio))
|
|
|
|
|
|
- quota := 0.0
|
|
|
- quota += float64(info.InputDetails.TextTokens)
|
|
|
- quota += float64(info.OutputDetails.TextTokens) * completionRatio
|
|
|
- quota += float64(info.InputDetails.AudioTokens) * audioRatio
|
|
|
- quota += float64(info.OutputDetails.AudioTokens) * audioRatio * audioCompletionRatio
|
|
|
+ quota = quota.Mul(ratio)
|
|
|
|
|
|
- quota = quota * ratio
|
|
|
- if ratio != 0 && quota <= 0 {
|
|
|
- quota = 1
|
|
|
+ // If ratio is not zero and quota is less than or equal to zero, set quota to 1
|
|
|
+ if !ratio.IsZero() && quota.LessThanOrEqual(decimal.Zero) {
|
|
|
+ quota = decimal.NewFromInt(1)
|
|
|
}
|
|
|
|
|
|
- return int(quota)
|
|
|
+ return int(quota.Round(0).IntPart())
|
|
|
}
|
|
|
|
|
|
func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
|
|
|
@@ -124,9 +141,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
|
|
audioOutTokens := usage.OutputTokenDetails.AudioTokens
|
|
|
|
|
|
tokenName := ctx.GetString("token_name")
|
|
|
- completionRatio := operation_setting.GetCompletionRatio(modelName)
|
|
|
- audioRatio := operation_setting.GetAudioRatio(relayInfo.OriginModelName)
|
|
|
- audioCompletionRatio := operation_setting.GetAudioCompletionRatio(modelName)
|
|
|
+ completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(modelName))
|
|
|
+ audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
|
|
|
+ audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName))
|
|
|
|
|
|
quotaInfo := QuotaInfo{
|
|
|
InputDetails: TokenDetails{
|
|
|
@@ -148,7 +165,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
|
|
totalTokens := usage.TotalTokens
|
|
|
var logContent string
|
|
|
if !usePrice {
|
|
|
- logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
|
|
|
+ logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f",
|
|
|
+ modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio)
|
|
|
} else {
|
|
|
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
|
|
|
}
|
|
|
@@ -170,7 +188,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
|
|
if extraContent != "" {
|
|
|
logContent += ", " + extraContent
|
|
|
}
|
|
|
- other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
|
|
|
+ other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
|
|
|
+ completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice)
|
|
|
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
|
|
|
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
|
|
}
|
|
|
@@ -186,9 +205,9 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
|
|
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
|
|
|
|
|
|
tokenName := ctx.GetString("token_name")
|
|
|
- completionRatio := operation_setting.GetCompletionRatio(relayInfo.OriginModelName)
|
|
|
- audioRatio := operation_setting.GetAudioRatio(relayInfo.OriginModelName)
|
|
|
- audioCompletionRatio := operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName)
|
|
|
+ completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(relayInfo.OriginModelName))
|
|
|
+ audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
|
|
|
+ audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
|
|
|
|
|
|
modelRatio := priceData.ModelRatio
|
|
|
groupRatio := priceData.GroupRatio
|
|
|
@@ -215,7 +234,8 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
|
|
totalTokens := usage.TotalTokens
|
|
|
var logContent string
|
|
|
if !usePrice {
|
|
|
- logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
|
|
|
+ logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f",
|
|
|
+ modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio)
|
|
|
} else {
|
|
|
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
|
|
|
}
|
|
|
@@ -244,7 +264,8 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
|
|
if extraContent != "" {
|
|
|
logContent += ", " + extraContent
|
|
|
}
|
|
|
- other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
|
|
|
+ other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
|
|
|
+ completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice)
|
|
|
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
|
|
|
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
|
|
}
|