@@ -15,6 +15,7 @@ import (
 	"one-api/model"
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
+	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/setting"
 	"strings"
@@ -76,33 +77,6 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
 	}

-	// map model name
-	//isModelMapped := false
-	modelMapping := c.GetString("model_mapping")
-	//isModelMapped := false
-	if modelMapping != "" && modelMapping != "{}" {
-		modelMap := make(map[string]string)
-		err := json.Unmarshal([]byte(modelMapping), &modelMap)
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-		}
-		if modelMap[textRequest.Model] != "" {
-			//isModelMapped = true
-			textRequest.Model = modelMap[textRequest.Model]
-			// set upstream model name
-			//isModelMapped = true
-		}
-	}
-	relayInfo.UpstreamModelName = textRequest.Model
-	relayInfo.RecodeModelName = textRequest.Model
-	modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
-	groupRatio := setting.GetGroupRatio(relayInfo.Group)
-
-	var preConsumedQuota int
-	var ratio float64
-	var modelRatio float64
-	//err := service.SensitiveWordsCheck(textRequest)
-
 	if setting.ShouldCheckPromptSensitive() {
 		err = checkRequestSensitive(textRequest, relayInfo)
 		if err != nil {
@@ -110,6 +84,13 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		}
 	}

+	err = helper.ModelMappedHelper(c, relayInfo)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
+	}
+
+	textRequest.Model = relayInfo.UpstreamModelName
+
 	// get promptTokens; if it already exists in the context, use it directly
 	var promptTokens int
 	if value, exists := c.Get("prompt_tokens"); exists {
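
The two hunks above replace the inline model-mapping block with a shared helper. Below is a minimal sketch of what helper.ModelMappedHelper plausibly does, reconstructed from the code deleted in the -76,33 hunk; the helper's actual body is not part of this diff, so the details (including exactly how OriginModelName and UpstreamModelName are read and written) are assumptions.

package helper

import (
	"encoding/json"

	"github.com/gin-gonic/gin"
	relaycommon "one-api/relay/common"
)

// Sketch only: centralizes the channel's model_mapping lookup that used to
// live inline in TextHelper.
func ModelMappedHelper(c *gin.Context, info *relaycommon.RelayInfo) error {
	modelMapping := c.GetString("model_mapping")
	if modelMapping == "" || modelMapping == "{}" {
		return nil // no mapping configured for this channel
	}
	modelMap := make(map[string]string)
	if err := json.Unmarshal([]byte(modelMapping), &modelMap); err != nil {
		return err // the caller wraps this as "model_mapped_error"
	}
	if mapped := modelMap[info.OriginModelName]; mapped != "" {
		// point the upstream request at the mapped model name
		info.UpstreamModelName = mapped
	}
	return nil
}

Centralizing this also removes the RecodeModelName bookkeeping: call sites later in this diff read relayInfo.OriginModelName instead.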
@@ -124,20 +105,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		c.Set("prompt_tokens", promptTokens)
 	}

-	if !getModelPriceSuccess {
-		preConsumedTokens := common.PreConsumedQuota
-		if textRequest.MaxTokens != 0 {
-			preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
-		}
-		modelRatio = common.GetModelRatio(textRequest.Model)
-		ratio = modelRatio * groupRatio
-		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
-	} else {
-		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
-	}
+	priceData := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))

 	// pre-consume quota
-	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
 	if openaiErr != nil {
 		return openaiErr
 	}
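
The deleted pricing branch moves into helper.ModelPriceHelper. The real helper is not shown in this diff; the sketch below reconstructs the pricing decision from the deleted lines and infers the PriceData fields from how priceData is consumed later in this file (ShouldPreConsumedQuota, ModelRatio, GroupRatio, ModelPrice, UsePrice), so the field names and the exact signature are assumptions.

package helper

import (
	"one-api/common"
	relaycommon "one-api/relay/common"
	"one-api/setting"

	"github.com/gin-gonic/gin"
)

// PriceData bundles the per-request pricing inputs that were previously
// passed around as loose parameters.
type PriceData struct {
	ModelRatio             float64
	GroupRatio             float64
	ModelPrice             float64
	UsePrice               bool
	ShouldPreConsumedQuota int
}

// Sketch only: per-call pricing when a fixed price is configured for the
// model, otherwise token-based ratio pricing, as in the deleted inline block.
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) PriceData {
	modelPrice, usePrice := common.GetModelPrice(info.UpstreamModelName, false)
	groupRatio := setting.GetGroupRatio(info.Group)
	data := PriceData{ModelPrice: modelPrice, GroupRatio: groupRatio, UsePrice: usePrice}
	if usePrice {
		// fixed price per call, scaled by the group's ratio
		data.ShouldPreConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
	} else {
		// token-based: reserve prompt + max_tokens, or a default when max_tokens is unset
		preConsumedTokens := common.PreConsumedQuota
		if maxTokens != 0 {
			preConsumedTokens = promptTokens + maxTokens
		}
		data.ModelRatio = common.GetModelRatio(info.UpstreamModelName)
		data.ShouldPreConsumedQuota = int(float64(preConsumedTokens) * data.ModelRatio * groupRatio)
	}
	return data
}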
@@ -220,10 +191,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		return openaiErr
 	}

-	if strings.HasPrefix(relayInfo.RecodeModelName, "gpt-4o-audio") {
-		service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+	if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
+		service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
 	} else {
-		postConsumeQuota(c, relayInfo, relayInfo.RecodeModelName, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+		postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
 	}
 	return nil
 }
@@ -319,9 +290,8 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us
 	}
 }

-func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
-	usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
-	modelPrice float64, usePrice bool, extraContent string) {
+func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
+	usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
 	if usage == nil {
 		usage = &dto.Usage{
 			PromptTokens: relayInfo.PromptTokens,
@@ -333,12 +303,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	promptTokens := usage.PromptTokens
 	completionTokens := usage.CompletionTokens
+	modelName := relayInfo.OriginModelName
 	tokenName := ctx.GetString("token_name")
 	completionRatio := common.GetCompletionRatio(modelName)
+	ratio := priceData.ModelRatio * priceData.GroupRatio
+	modelRatio := priceData.ModelRatio
+	groupRatio := priceData.GroupRatio
+	modelPrice := priceData.ModelPrice
+	usePrice := priceData.UsePrice
 	quota := 0
-	if !usePrice {
+	if !priceData.UsePrice {
 		quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
 		quota = int(math.Round(float64(quota) * ratio))
 		if ratio != 0 && quota <= 0 {
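
For reference, the token-based branch kept by this hunk charges prompt tokens at the model ratio and completion tokens at an additional completion ratio, with rounding applied at each step. A tiny runnable example with assumed numbers (none of these values come from the diff):

package main

import (
	"fmt"
	"math"
)

func main() {
	promptTokens, completionTokens := 1000, 500
	completionRatio := 3.0 // assumed: completion tokens billed at 3x prompt tokens
	ratio := 0.25 * 1.0    // assumed modelRatio * groupRatio

	// same two-step rounding as postConsumeQuota above
	quota := promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
	quota = int(math.Round(float64(quota) * ratio))
	fmt.Println(quota) // 625 quota units charged
}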