Browse Source

refactor: realtime quota

[email protected] 1 year ago
parent
commit
99245e4c1f
3 changed files with 91 additions and 98 deletions
  1. 1 1
      relay/relay-text.go
  2. 2 45
      relay/websocket.go
  3. 88 52
      service/quota.go

+ 1 - 1
relay/relay-text.go

@@ -219,7 +219,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	}
 
 	if strings.HasPrefix(relayInfo.UpstreamModelName, "gpt-4o-audio") {
-		service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+		service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
 	} else {
 		postConsumeQuota(c, relayInfo, textRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
 	}

+ 2 - 45
relay/websocket.go

@@ -13,24 +13,6 @@ import (
 	"one-api/setting"
 )
 
-//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
-//	_, p, err := ws.ReadMessage()
-//	if err != nil {
-//		return nil, err
-//	}
-//	realtimeEvent := &dto.RealtimeEvent{}
-//	err = json.Unmarshal(p, realtimeEvent)
-//	if err != nil {
-//		return nil, err
-//	}
-//	// save the original request
-//	if realtimeEvent.Session == nil {
-//		return nil, errors.New("session object is nil")
-//	}
-//	c.Set("first_wss_request", p)
-//	return realtimeEvent, nil
-//}
-
 func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	relayInfo := relaycommon.GenRelayInfoWs(c, ws)
 
@@ -129,32 +111,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 	}
-	service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+	service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
+		userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
 	return nil
 }
-
-//func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
-//	var promptTokens int
-//	var err error
-//	switch info.RelayMode {
-//	default:
-//		promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName)
-//	}
-//	info.PromptTokens = promptTokens
-//	return promptTokens, err
-//}
-
-//func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
-//	var err error
-//	switch info.RelayMode {
-//	case relayconstant.RelayModeChatCompletions:
-//		err = service.CheckSensitiveMessages(textRequest.Messages)
-//	case relayconstant.RelayModeCompletions:
-//		err = service.CheckSensitiveInput(textRequest.Prompt)
-//	case relayconstant.RelayModeModerations:
-//		err = service.CheckSensitiveInput(textRequest.Input)
-//	case relayconstant.RelayModeEmbeddings:
-//		err = service.CheckSensitiveInput(textRequest.Input)
-//	}
-//	return err
-//}

+ 88 - 52
service/quota.go

@@ -3,7 +3,6 @@ package service
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"math"
 	"one-api/common"
 	"one-api/dto"
@@ -12,8 +11,47 @@ import (
 	"one-api/setting"
 	"strings"
 	"time"
+
+	"github.com/gin-gonic/gin"
 )
 
+type TokenDetails struct {
+	TextTokens  int
+	AudioTokens int
+}
+
+type QuotaInfo struct {
+	InputDetails  TokenDetails
+	OutputDetails TokenDetails
+	ModelName     string
+	UsePrice      bool
+	ModelPrice    float64
+	ModelRatio    float64
+	GroupRatio    float64
+}
+
+func calculateAudioQuota(info QuotaInfo) int {
+	if info.UsePrice {
+		return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio)
+	}
+
+	completionRatio := common.GetCompletionRatio(info.ModelName)
+	audioRatio := common.GetAudioRatio(info.ModelName)
+	audioCompletionRatio := common.GetAudioCompletionRatio(info.ModelName)
+	ratio := info.GroupRatio * info.ModelRatio
+
+	quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio))
+	quota += int(math.Round(float64(info.InputDetails.AudioTokens)*audioRatio)) +
+		int(math.Round(float64(info.OutputDetails.AudioTokens)*audioRatio*audioCompletionRatio))
+
+	quota = int(math.Round(float64(quota) * ratio))
+	if ratio != 0 && quota <= 0 {
+		quota = 1
+	}
+
+	return quota
+}
+
 func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
 	if relayInfo.UsePrice {
 		return nil
@@ -33,23 +71,26 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
 	textOutTokens := usage.OutputTokenDetails.TextTokens
 	audioInputTokens := usage.InputTokenDetails.AudioTokens
 	audioOutTokens := usage.OutputTokenDetails.AudioTokens
-
-	completionRatio := common.GetCompletionRatio(modelName)
-	audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
-	audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
 	groupRatio := setting.GetGroupRatio(relayInfo.Group)
 	modelRatio := common.GetModelRatio(modelName)
 
-	ratio := groupRatio * modelRatio
-
-	quota := textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
-	quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio))
-
-	quota = int(math.Round(float64(quota) * ratio))
-	if ratio != 0 && quota <= 0 {
-		quota = 1
+	quotaInfo := QuotaInfo{
+		InputDetails: TokenDetails{
+			TextTokens:  textInputTokens,
+			AudioTokens: audioInputTokens,
+		},
+		OutputDetails: TokenDetails{
+			TextTokens:  textOutTokens,
+			AudioTokens: audioOutTokens,
+		},
+		ModelName:  modelName,
+		UsePrice:   relayInfo.UsePrice,
+		ModelRatio: modelRatio,
+		GroupRatio: groupRatio,
 	}
 
+	quota := calculateAudioQuota(quotaInfo)
+
 	if userQuota < quota {
 		return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
 	}
@@ -67,8 +108,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
 }
 
 func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
-	usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64,
-	groupRatio float64,
+	usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
 	modelPrice float64, usePrice bool, extraContent string) {
 
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
@@ -83,17 +123,23 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 	audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
 	audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
 
-	quota := 0
-	if !usePrice {
-		quota = int(math.Round(float64(textInputTokens) + float64(textOutTokens)*completionRatio))
-		quota += int(math.Round(float64(audioInputTokens)*audioRatio + float64(audioOutTokens)*audioRatio*audioCompletionRatio))
-		quota = int(math.Round(float64(quota) * ratio))
-		if ratio != 0 && quota <= 0 {
-			quota = 1
-		}
-	} else {
-		quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+	quotaInfo := QuotaInfo{
+		InputDetails: TokenDetails{
+			TextTokens:  textInputTokens,
+			AudioTokens: audioInputTokens,
+		},
+		OutputDetails: TokenDetails{
+			TextTokens:  textOutTokens,
+			AudioTokens: audioOutTokens,
+		},
+		ModelName:  modelName,
+		UsePrice:   usePrice,
+		ModelRatio: modelRatio,
+		GroupRatio: groupRatio,
 	}
+
+	quota := calculateAudioQuota(quotaInfo)
+
 	totalTokens := usage.TotalTokens
 	var logContent string
 	if !usePrice {
@@ -111,21 +157,6 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
 			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
 	} else {
-		//if sensitiveResp != nil {
-		//	logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
-		//}
-		//quotaDelta := quota - preConsumedQuota
-		//if quotaDelta != 0 {
-		//	err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
-		//	if err != nil {
-		//		common.LogError(ctx, "error consuming token remain quota: "+err.Error())
-		//	}
-		//}
-
-		//err := model.CacheUpdateUserQuota(relayInfo.UserId)
-		//if err != nil {
-		//	common.LogError(ctx, "error update user quota cache: "+err.Error())
-		//}
 		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 	}
@@ -140,8 +171,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 }
 
 func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
-	usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64,
-	groupRatio float64,
+	usage *dto.Usage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
 	modelPrice float64, usePrice bool, extraContent string) {
 
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
@@ -156,17 +186,23 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
 	audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.UpstreamModelName)
 
-	quota := 0
-	if !usePrice {
-		quota = int(math.Round(float64(textInputTokens) + float64(textOutTokens)*completionRatio))
-		quota += int(math.Round(float64(audioInputTokens)*audioRatio + float64(audioOutTokens)*audioRatio*audioCompletionRatio))
-		quota = int(math.Round(float64(quota) * ratio))
-		if ratio != 0 && quota <= 0 {
-			quota = 1
-		}
-	} else {
-		quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+	quotaInfo := QuotaInfo{
+		InputDetails: TokenDetails{
+			TextTokens:  textInputTokens,
+			AudioTokens: audioInputTokens,
+		},
+		OutputDetails: TokenDetails{
+			TextTokens:  textOutTokens,
+			AudioTokens: audioOutTokens,
+		},
+		ModelName:  relayInfo.UpstreamModelName,
+		UsePrice:   usePrice,
+		ModelRatio: modelRatio,
+		GroupRatio: groupRatio,
 	}
+
+	quota := calculateAudioQuota(quotaInfo)
+
 	totalTokens := usage.TotalTokens
 	var logContent string
 	if !usePrice {