Explorar el Código

feat: realtime pre consume

(cherry picked from commit d87917f8f6eb9d2e144a9f840d6d91767ea2eb69)
[email protected] hace 1 año
padre
commit
24b3ed50d7
Se han modificado 4 ficheros con 180 adiciones y 94 borrados
  1. 45 4
      relay/channel/openai/relay-openai.go
  2. 1 0
      relay/common/relay_info.go
  3. 2 90
      relay/websocket.go
  4. 132 0
      service/quota.go

+ 45 - 4
relay/channel/openai/relay-openai.go

@@ -389,6 +389,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 
 	usage := &dto.RealtimeUsage{}
 	localUsage := &dto.RealtimeUsage{}
+	sumUsage := &dto.RealtimeUsage{}
 
 	go func() {
 		for {
@@ -478,6 +479,12 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 						usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
 						usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
 						usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
+						err := preConsumeUsage(c, info, usage, sumUsage)
+						if err != nil {
+							errChan <- fmt.Errorf("error consume usage: %v", err)
+							return
+						}
+						usage = &dto.RealtimeUsage{}
 					} else {
 						textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
 						if err != nil {
@@ -490,7 +497,18 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 						localUsage.InputTokens += textToken + audioToken
 						localUsage.InputTokenDetails.TextTokens += textToken
 						localUsage.InputTokenDetails.AudioTokens += audioToken
+						err = preConsumeUsage(c, info, localUsage, sumUsage)
+						if err != nil {
+							errChan <- fmt.Errorf("error consume usage: %v", err)
+							return
+						}
+						localUsage = &dto.RealtimeUsage{}
+						// print now usage
 					}
+					common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
+					common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
+					common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
+
 				} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
 					realtimeSession := realtimeEvent.Session
 					if realtimeSession != nil {
@@ -528,15 +546,38 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 	select {
 	case <-clientClosed:
 	case <-targetClosed:
-	case <-errChan:
+	case err := <-errChan:
 		//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
+		common.LogError(c, "realtime error: "+err.Error())
 	case <-c.Done():
 	}
 
+	if usage.TotalTokens != 0 {
+		_ = preConsumeUsage(c, info, usage, sumUsage)
+	}
+
+	if localUsage.TotalTokens != 0 {
+		_ = preConsumeUsage(c, info, localUsage, sumUsage)
+	}
+
 	// check usage total tokens, if 0, use local usage
 
-	if usage.TotalTokens == 0 {
-		usage = localUsage
+	return nil, sumUsage
+}
+
+func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
+	totalUsage.TotalTokens += usage.TotalTokens
+	totalUsage.InputTokens += usage.InputTokens
+	totalUsage.OutputTokens += usage.OutputTokens
+	totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
+	totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
+	totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
+	totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
+	totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
+	// clear usage
+	err := service.PreWssConsumeQuota(ctx, info, usage)
+	if err == nil {
+		common.LogInfo(ctx, "realtime streaming consume usage success")
 	}
-	return nil, usage
+	return err
 }

+ 1 - 0
relay/common/relay_info.go

@@ -23,6 +23,7 @@ type RelayInfo struct {
 	ApiType              int
 	IsStream             bool
 	IsPlayground         bool
+	UsePrice             bool
 	RelayMode            int
 	UpstreamModelName    string
 	OriginModelName      string

+ 2 - 90
relay/websocket.go

@@ -5,15 +5,11 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/gorilla/websocket"
-	"math"
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
-	"one-api/model"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
-	"strings"
-	"time"
 )
 
 //func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
@@ -91,6 +87,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCod
 		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
 	} else {
 		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+		relayInfo.UsePrice = true
 	}
 
 	// pre-consume quota 预消耗配额
@@ -126,95 +123,10 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCod
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 	}
-	postWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+	service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
 	return nil
 }
 
-func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
-	usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64,
-	groupRatio float64,
-	modelPrice float64, usePrice bool, extraContent string) {
-
-	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
-	textInputTokens := usage.InputTokenDetails.TextTokens
-	textOutTokens := usage.OutputTokenDetails.TextTokens
-
-	audioInputTokens := usage.InputTokenDetails.AudioTokens
-	audioOutTokens := usage.OutputTokenDetails.AudioTokens
-
-	tokenName := ctx.GetString("token_name")
-	completionRatio := common.GetCompletionRatio(modelName)
-	audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
-	audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
-
-	quota := 0
-	if !usePrice {
-		quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
-		quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio))
-
-		quota = int(math.Round(float64(quota) * ratio))
-		if ratio != 0 && quota <= 0 {
-			quota = 1
-		}
-	} else {
-		quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
-	}
-	totalTokens := usage.TotalTokens
-	var logContent string
-	if !usePrice {
-		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
-	} else {
-		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
-	}
-
-	// record all the consume log even if quota is 0
-	if totalTokens == 0 {
-		// in this case, must be some error happened
-		// we cannot just return, because we may have to return the pre-consumed quota
-		quota = 0
-		logContent += fmt.Sprintf("(可能是上游超时)")
-		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
-			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
-	} else {
-		//if sensitiveResp != nil {
-		//	logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
-		//}
-		quotaDelta := quota - preConsumedQuota
-		if quotaDelta != 0 {
-			err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
-			if err != nil {
-				common.LogError(ctx, "error consuming token remain quota: "+err.Error())
-			}
-		}
-		err := model.CacheUpdateUserQuota(relayInfo.UserId)
-		if err != nil {
-			common.LogError(ctx, "error update user quota cache: "+err.Error())
-		}
-		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
-		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
-	}
-
-	logModel := modelName
-	if strings.HasPrefix(logModel, "gpt-4-gizmo") {
-		logModel = "gpt-4-gizmo-*"
-		logContent += fmt.Sprintf(",模型 %s", modelName)
-	}
-	if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
-		logModel = "gpt-4o-gizmo-*"
-		logContent += fmt.Sprintf(",模型 %s", modelName)
-	}
-	if extraContent != "" {
-		logContent += ", " + extraContent
-	}
-	other := service.GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, modelPrice)
-	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
-		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
-
-	//if quota != 0 {
-	//
-	//}
-}
-
 //func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
 //	var promptTokens int
 //	var err error

+ 132 - 0
service/quota.go

@@ -0,0 +1,132 @@
+package service
+
+import (
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"math"
+	"one-api/common"
+	"one-api/dto"
+	"one-api/model"
+	relaycommon "one-api/relay/common"
+	"strings"
+	"time"
+)
+
+func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
+	if relayInfo.UsePrice {
+		return nil
+	}
+	modelName := relayInfo.UpstreamModelName
+	textInputTokens := usage.InputTokenDetails.TextTokens
+	textOutTokens := usage.OutputTokenDetails.TextTokens
+	audioInputTokens := usage.InputTokenDetails.AudioTokens
+	audioOutTokens := usage.OutputTokenDetails.AudioTokens
+
+	completionRatio := common.GetCompletionRatio(modelName)
+	audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
+	audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
+	groupRatio := common.GetGroupRatio(relayInfo.Group)
+	modelRatio := common.GetModelRatio(modelName)
+
+	ratio := groupRatio * modelRatio
+
+	quota := textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
+	quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio))
+
+	quota = int(math.Round(float64(quota) * ratio))
+	if ratio != 0 && quota <= 0 {
+		quota = 1
+	}
+
+	err := model.PostConsumeTokenQuota(relayInfo, 0, quota, 0, false)
+	if err != nil {
+		return err
+	}
+	err = model.CacheUpdateUserQuota(relayInfo.UserId)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
+	usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64,
+	groupRatio float64,
+	modelPrice float64, usePrice bool, extraContent string) {
+
+	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
+	textInputTokens := usage.InputTokenDetails.TextTokens
+	textOutTokens := usage.OutputTokenDetails.TextTokens
+
+	audioInputTokens := usage.InputTokenDetails.AudioTokens
+	audioOutTokens := usage.OutputTokenDetails.AudioTokens
+
+	tokenName := ctx.GetString("token_name")
+	completionRatio := common.GetCompletionRatio(modelName)
+	audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
+	audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
+
+	quota := 0
+	if !usePrice {
+		quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
+		quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio))
+
+		quota = int(math.Round(float64(quota) * ratio))
+		if ratio != 0 && quota <= 0 {
+			quota = 1
+		}
+	} else {
+		quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+	}
+	totalTokens := usage.TotalTokens
+	var logContent string
+	if !usePrice {
+		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
+	} else {
+		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
+	}
+
+	// record all the consume log even if quota is 0
+	if totalTokens == 0 {
+		// in this case, must be some error happened
+		// we cannot just return, because we may have to return the pre-consumed quota
+		quota = 0
+		logContent += fmt.Sprintf("(可能是上游超时)")
+		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
+	} else {
+		//if sensitiveResp != nil {
+		//	logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
+		//}
+		//quotaDelta := quota - preConsumedQuota
+		//if quotaDelta != 0 {
+		//	err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
+		//	if err != nil {
+		//		common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+		//	}
+		//}
+
+		//err := model.CacheUpdateUserQuota(relayInfo.UserId)
+		//if err != nil {
+		//	common.LogError(ctx, "error update user quota cache: "+err.Error())
+		//}
+		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
+		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
+	}
+
+	logModel := modelName
+	if strings.HasPrefix(logModel, "gpt-4-gizmo") {
+		logModel = "gpt-4-gizmo-*"
+		logContent += fmt.Sprintf(",模型 %s", modelName)
+	}
+	if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
+		logModel = "gpt-4o-gizmo-*"
+		logContent += fmt.Sprintf(",模型 %s", modelName)
+	}
+	if extraContent != "" {
+		logContent += ", " + extraContent
+	}
+	other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, modelPrice)
+	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
+		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
+}