Просмотр исходного кода

refactor: Simplify model mapping and pricing logic across relay modules

[email protected] 10 месяцев назад
Родитель
Сommit
06da65a9d0

+ 1 - 1
controller/relay.go

@@ -24,7 +24,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 	var err *dto.OpenAIErrorWithStatusCode
 	switch relayMode {
 	case relayconstant.RelayModeImagesGenerations:
-		err = relay.ImageHelper(c, relayMode)
+		err = relay.ImageHelper(c)
 	case relayconstant.RelayModeAudioSpeech:
 		fallthrough
 	case relayconstant.RelayModeAudioTranslation:

+ 2 - 2
model/option.go

@@ -84,7 +84,7 @@ func InitOptionMap() {
 	common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
 	common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
 	common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
-	common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
+	common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
 	common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
 	common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
 	common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
@@ -306,7 +306,7 @@ func updateOptionMap(key string, value string) (err error) {
 		common.QuotaForInvitee, _ = strconv.Atoi(value)
 	case "QuotaRemindThreshold":
 		common.QuotaRemindThreshold, _ = strconv.Atoi(value)
-	case "PreConsumedQuota":
+	case "ShouldPreConsumedQuota":
 		common.PreConsumedQuota, _ = strconv.Atoi(value)
 	case "RetryTimes":
 		common.RetryTimes, _ = strconv.Atoi(value)

+ 39 - 28
relay/common/relay_info.go

@@ -13,24 +13,24 @@ import (
 )
 
 type RelayInfo struct {
-	ChannelType          int
-	ChannelId            int
-	TokenId              int
-	TokenKey             string
-	UserId               int
-	Group                string
-	TokenUnlimited       bool
-	StartTime            time.Time
-	FirstResponseTime    time.Time
-	setFirstResponse     bool
-	ApiType              int
-	IsStream             bool
-	IsPlayground         bool
-	UsePrice             bool
-	RelayMode            int
-	UpstreamModelName    string
-	OriginModelName      string
-	RecodeModelName      string
+	ChannelType       int
+	ChannelId         int
+	TokenId           int
+	TokenKey          string
+	UserId            int
+	Group             string
+	TokenUnlimited    bool
+	StartTime         time.Time
+	FirstResponseTime time.Time
+	setFirstResponse  bool
+	ApiType           int
+	IsStream          bool
+	IsPlayground      bool
+	UsePrice          bool
+	RelayMode         int
+	UpstreamModelName string
+	OriginModelName   string
+	//RecodeModelName      string
 	RequestURLPath       string
 	ApiVersion           string
 	PromptTokens         int
@@ -39,6 +39,7 @@ type RelayInfo struct {
 	BaseUrl              string
 	SupportStreamOptions bool
 	ShouldIncludeUsage   bool
+	IsModelMapped        bool
 	ClientWs             *websocket.Conn
 	TargetWs             *websocket.Conn
 	InputAudioFormat     string
@@ -50,6 +51,18 @@ type RelayInfo struct {
 	ChannelSetting       map[string]interface{}
 }
 
+// 定义支持流式选项的通道类型
+var streamSupportedChannels = map[int]bool{
+	common.ChannelTypeOpenAI:     true,
+	common.ChannelTypeAnthropic:  true,
+	common.ChannelTypeAws:        true,
+	common.ChannelTypeGemini:     true,
+	common.ChannelCloudflare:     true,
+	common.ChannelTypeAzure:      true,
+	common.ChannelTypeVolcEngine: true,
+	common.ChannelTypeOllama:     true,
+}
+
 func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
 	info := GenRelayInfo(c)
 	info.ClientWs = ws
@@ -89,12 +102,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 		FirstResponseTime: startTime.Add(-time.Second),
 		OriginModelName:   c.GetString("original_model"),
 		UpstreamModelName: c.GetString("original_model"),
-		RecodeModelName:   c.GetString("recode_model"),
-		ApiType:           apiType,
-		ApiVersion:        c.GetString("api_version"),
-		ApiKey:            strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
-		Organization:      c.GetString("channel_organization"),
-		ChannelSetting:    channelSetting,
+		//RecodeModelName:   c.GetString("original_model"),
+		IsModelMapped:  false,
+		ApiType:        apiType,
+		ApiVersion:     c.GetString("api_version"),
+		ApiKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
+		Organization:   c.GetString("channel_organization"),
+		ChannelSetting: channelSetting,
 	}
 	if strings.HasPrefix(c.Request.URL.Path, "/pg") {
 		info.IsPlayground = true
@@ -110,10 +124,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	if info.ChannelType == common.ChannelTypeVertexAi {
 		info.ApiVersion = c.GetString("region")
 	}
-	if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
-		info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
-		info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure ||
-		info.ChannelType == common.ChannelTypeVolcEngine || info.ChannelType == common.ChannelTypeOllama {
+	if streamSupportedChannels[info.ChannelType] {
 		info.SupportStreamOptions = true
 	}
 	return info

+ 25 - 0
relay/helper/model_mapped.go

@@ -0,0 +1,25 @@
+package helper
+
+import (
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"one-api/relay/common"
+)
+
+func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
+	// map model name
+	modelMapping := c.GetString("model_mapping")
+	if modelMapping != "" && modelMapping != "{}" {
+		modelMap := make(map[string]string)
+		err := json.Unmarshal([]byte(modelMapping), &modelMap)
+		if err != nil {
+			return fmt.Errorf("unmarshal_model_mapping_failed")
+		}
+		if modelMap[info.OriginModelName] != "" {
+			info.UpstreamModelName = modelMap[info.OriginModelName]
+			info.IsModelMapped = true
+		}
+	}
+	return nil
+}

+ 41 - 0
relay/helper/price.go

@@ -0,0 +1,41 @@
+package helper
+
+import (
+	"github.com/gin-gonic/gin"
+	"one-api/common"
+	relaycommon "one-api/relay/common"
+	"one-api/setting"
+)
+
+type PriceData struct {
+	ModelPrice             float64
+	ModelRatio             float64
+	GroupRatio             float64
+	UsePrice               bool
+	ShouldPreConsumedQuota int
+}
+
+func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) PriceData {
+	modelPrice, usePrice := common.GetModelPrice(info.OriginModelName, false)
+	groupRatio := setting.GetGroupRatio(info.Group)
+	var preConsumedQuota int
+	var modelRatio float64
+	if !usePrice {
+		preConsumedTokens := common.PreConsumedQuota
+		if maxTokens != 0 {
+			preConsumedTokens = promptTokens + maxTokens
+		}
+		modelRatio = common.GetModelRatio(info.OriginModelName)
+		ratio := modelRatio * groupRatio
+		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
+	} else {
+		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+	}
+	return PriceData{
+		ModelPrice:             modelPrice,
+		ModelRatio:             modelRatio,
+		GroupRatio:             groupRatio,
+		UsePrice:               usePrice,
+		ShouldPreConsumedQuota: preConsumedQuota,
+	}
+}

+ 10 - 19
relay/relay-audio.go

@@ -1,7 +1,6 @@
 package relay
 
 import (
-	"encoding/json"
 	"errors"
 	"fmt"
 	"github.com/gin-gonic/gin"
@@ -11,6 +10,7 @@ import (
 	"one-api/model"
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
+	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/setting"
 )
@@ -73,15 +73,13 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		relayInfo.PromptTokens = promptTokens
 	}
 
-	modelRatio := common.GetModelRatio(audioRequest.Model)
-	groupRatio := setting.GetGroupRatio(relayInfo.Group)
-	ratio := modelRatio * groupRatio
-	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
+	priceData := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
+
 	userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
 	}
-	preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo)
+	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
 	if openaiErr != nil {
 		return openaiErr
 	}
@@ -91,19 +89,12 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		}
 	}()
 
-	// map model name
-	modelMapping := c.GetString("model_mapping")
-	if modelMapping != "" {
-		modelMap := make(map[string]string)
-		err := json.Unmarshal([]byte(modelMapping), &modelMap)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-		}
-		if modelMap[audioRequest.Model] != "" {
-			audioRequest.Model = modelMap[audioRequest.Model]
-		}
+	err = helper.ModelMappedHelper(c, relayInfo)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
-	relayInfo.UpstreamModelName = audioRequest.Model
+
+	audioRequest.Model = relayInfo.UpstreamModelName
 
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
@@ -140,7 +131,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		return openaiErr
 	}
 
-	postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
+	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
 
 	return nil
 }

+ 16 - 15
relay/relay-image.go

@@ -12,6 +12,7 @@ import (
 	"one-api/dto"
 	"one-api/model"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/setting"
 	"strings"
@@ -68,7 +69,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 	return imageRequest, nil
 }
 
-func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
+func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	relayInfo := relaycommon.GenRelayInfo(c)
 
 	imageRequest, err := getAndValidImageRequest(c, relayInfo)
@@ -77,19 +78,12 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 		return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
 	}
 
-	// map model name
-	modelMapping := c.GetString("model_mapping")
-	if modelMapping != "" {
-		modelMap := make(map[string]string)
-		err := json.Unmarshal([]byte(modelMapping), &modelMap)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-		}
-		if modelMap[imageRequest.Model] != "" {
-			imageRequest.Model = modelMap[imageRequest.Model]
-		}
+	err = helper.ModelMappedHelper(c, relayInfo)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
-	relayInfo.UpstreamModelName = imageRequest.Model
+
+	imageRequest.Model = relayInfo.UpstreamModelName
 
 	modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
 	if !success {
@@ -183,8 +177,15 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 		quality = "hd"
 	}
 
-	logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
-	postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, imageRatio, true, logContent)
+	priceData := helper.PriceData{
+		UsePrice:               true,
+		GroupRatio:             groupRatio,
+		ModelPrice:             modelPrice,
+		ModelRatio:             0,
+		ShouldPreConsumedQuota: 0,
+	}
 
+	logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
+	postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent)
 	return nil
 }

+ 22 - 46
relay/relay-text.go

@@ -15,6 +15,7 @@ import (
 	"one-api/model"
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
+	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/setting"
 	"strings"
@@ -76,33 +77,6 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
 	}
 
-	// map model name
-	//isModelMapped := false
-	modelMapping := c.GetString("model_mapping")
-	//isModelMapped := false
-	if modelMapping != "" && modelMapping != "{}" {
-		modelMap := make(map[string]string)
-		err := json.Unmarshal([]byte(modelMapping), &modelMap)
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-		}
-		if modelMap[textRequest.Model] != "" {
-			//isModelMapped = true
-			textRequest.Model = modelMap[textRequest.Model]
-			// set upstream model name
-			//isModelMapped = true
-		}
-	}
-	relayInfo.UpstreamModelName = textRequest.Model
-	relayInfo.RecodeModelName = textRequest.Model
-	modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
-	groupRatio := setting.GetGroupRatio(relayInfo.Group)
-
-	var preConsumedQuota int
-	var ratio float64
-	var modelRatio float64
-	//err := service.SensitiveWordsCheck(textRequest)
-
 	if setting.ShouldCheckPromptSensitive() {
 		err = checkRequestSensitive(textRequest, relayInfo)
 		if err != nil {
@@ -110,6 +84,13 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		}
 	}
 
+	err = helper.ModelMappedHelper(c, relayInfo)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
+	}
+
+	textRequest.Model = relayInfo.UpstreamModelName
+
 	// 获取 promptTokens,如果上下文中已经存在,则直接使用
 	var promptTokens int
 	if value, exists := c.Get("prompt_tokens"); exists {
@@ -124,20 +105,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		c.Set("prompt_tokens", promptTokens)
 	}
 
-	if !getModelPriceSuccess {
-		preConsumedTokens := common.PreConsumedQuota
-		if textRequest.MaxTokens != 0 {
-			preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
-		}
-		modelRatio = common.GetModelRatio(textRequest.Model)
-		ratio = modelRatio * groupRatio
-		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
-	} else {
-		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
-	}
+	priceData := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
 
 	// pre-consume quota 预消耗配额
-	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
 	if openaiErr != nil {
 		return openaiErr
 	}
@@ -220,10 +191,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		return openaiErr
 	}
 
-	if strings.HasPrefix(relayInfo.RecodeModelName, "gpt-4o-audio") {
-		service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+	if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
+		service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
 	} else {
-		postConsumeQuota(c, relayInfo, relayInfo.RecodeModelName, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+		postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
 	}
 	return nil
 }
@@ -319,9 +290,8 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us
 	}
 }
 
-func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
-	usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
-	modelPrice float64, usePrice bool, extraContent string) {
+func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
+	usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
 	if usage == nil {
 		usage = &dto.Usage{
 			PromptTokens:     relayInfo.PromptTokens,
@@ -333,12 +303,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	promptTokens := usage.PromptTokens
 	completionTokens := usage.CompletionTokens
+	modelName := relayInfo.OriginModelName
 
 	tokenName := ctx.GetString("token_name")
 	completionRatio := common.GetCompletionRatio(modelName)
+	ratio := priceData.ModelRatio * priceData.GroupRatio
+	modelRatio := priceData.ModelRatio
+	groupRatio := priceData.GroupRatio
+	modelPrice := priceData.ModelPrice
+	usePrice := priceData.UsePrice
 
 	quota := 0
-	if !usePrice {
+	if !priceData.UsePrice {
 		quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
 		quota = int(math.Round(float64(quota) * ratio))
 		if ratio != 0 && quota <= 0 {

+ 9 - 32
relay/relay_embedding.go

@@ -10,8 +10,8 @@ import (
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
+	"one-api/relay/helper"
 	"one-api/service"
-	"one-api/setting"
 )
 
 func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
@@ -47,43 +47,20 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
 		return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
 	}
 
-	// map model name
-	modelMapping := c.GetString("model_mapping")
-	//isModelMapped := false
-	if modelMapping != "" && modelMapping != "{}" {
-		modelMap := make(map[string]string)
-		err := json.Unmarshal([]byte(modelMapping), &modelMap)
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-		}
-		if modelMap[embeddingRequest.Model] != "" {
-			embeddingRequest.Model = modelMap[embeddingRequest.Model]
-			// set upstream model name
-			//isModelMapped = true
-		}
+	err = helper.ModelMappedHelper(c, relayInfo)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
 
-	relayInfo.UpstreamModelName = embeddingRequest.Model
-	modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false)
-	groupRatio := setting.GetGroupRatio(relayInfo.Group)
-
-	var preConsumedQuota int
-	var ratio float64
-	var modelRatio float64
+	embeddingRequest.Model = relayInfo.UpstreamModelName
 
 	promptToken := getEmbeddingPromptToken(*embeddingRequest)
-	if !success {
-		preConsumedTokens := promptToken
-		modelRatio = common.GetModelRatio(embeddingRequest.Model)
-		ratio = modelRatio * groupRatio
-		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
-	} else {
-		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
-	}
 	relayInfo.PromptTokens = promptToken
 
+	priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
+
 	// pre-consume quota 预消耗配额
-	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
 	if openaiErr != nil {
 		return openaiErr
 	}
@@ -132,6 +109,6 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 	}
-	postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
+	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
 	return nil
 }

+ 9 - 32
relay/relay_rerank.go

@@ -9,8 +9,8 @@ import (
 	"one-api/common"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
 	"one-api/service"
-	"one-api/setting"
 )
 
 func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
@@ -40,43 +40,20 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
 	}
 
-	// map model name
-	modelMapping := c.GetString("model_mapping")
-	//isModelMapped := false
-	if modelMapping != "" && modelMapping != "{}" {
-		modelMap := make(map[string]string)
-		err := json.Unmarshal([]byte(modelMapping), &modelMap)
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-		}
-		if modelMap[rerankRequest.Model] != "" {
-			rerankRequest.Model = modelMap[rerankRequest.Model]
-			// set upstream model name
-			//isModelMapped = true
-		}
+	err = helper.ModelMappedHelper(c, relayInfo)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
 
-	relayInfo.UpstreamModelName = rerankRequest.Model
-	modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
-	groupRatio := setting.GetGroupRatio(relayInfo.Group)
-
-	var preConsumedQuota int
-	var ratio float64
-	var modelRatio float64
+	rerankRequest.Model = relayInfo.UpstreamModelName
 
 	promptToken := getRerankPromptToken(*rerankRequest)
-	if !success {
-		preConsumedTokens := promptToken
-		modelRatio = common.GetModelRatio(rerankRequest.Model)
-		ratio = modelRatio * groupRatio
-		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
-	} else {
-		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
-	}
 	relayInfo.PromptTokens = promptToken
 
+	priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
+
 	// pre-consume quota 预消耗配额
-	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
 	if openaiErr != nil {
 		return openaiErr
 	}
@@ -124,6 +101,6 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 	}
-	postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
+	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
 	return nil
 }

+ 4 - 0
service/log_info_generate.go

@@ -16,6 +16,10 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
 	if relayInfo.ReasoningEffort != "" {
 		other["reasoning_effort"] = relayInfo.ReasoningEffort
 	}
+	if relayInfo.IsModelMapped {
+		other["is_model_mapped"] = true
+		other["upstream_model_name"] = relayInfo.UpstreamModelName
+	}
 	adminInfo := make(map[string]interface{})
 	adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
 	other["admin_info"] = adminInfo

+ 15 - 10
service/quota.go

@@ -10,6 +10,7 @@ import (
 	"one-api/dto"
 	"one-api/model"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
 	"one-api/setting"
 	"strings"
 	"time"
@@ -68,7 +69,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
 		return err
 	}
 
-	modelName := relayInfo.UpstreamModelName
+	modelName := relayInfo.OriginModelName
 	textInputTokens := usage.InputTokenDetails.TextTokens
 	textOutTokens := usage.OutputTokenDetails.TextTokens
 	audioInputTokens := usage.InputTokenDetails.AudioTokens
@@ -122,7 +123,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 
 	tokenName := ctx.GetString("token_name")
 	completionRatio := common.GetCompletionRatio(modelName)
-	audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
+	audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
 	audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
 
 	quotaInfo := QuotaInfo{
@@ -173,8 +174,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 }
 
 func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
-	usage *dto.Usage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
-	modelPrice float64, usePrice bool, extraContent string) {
+	usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
 
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	textInputTokens := usage.PromptTokensDetails.TextTokens
@@ -184,9 +184,14 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	audioOutTokens := usage.CompletionTokenDetails.AudioTokens
 
 	tokenName := ctx.GetString("token_name")
-	completionRatio := common.GetCompletionRatio(relayInfo.RecodeModelName)
-	audioRatio := common.GetAudioRatio(relayInfo.RecodeModelName)
-	audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.RecodeModelName)
+	completionRatio := common.GetCompletionRatio(relayInfo.OriginModelName)
+	audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
+	audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.OriginModelName)
+
+	modelRatio := priceData.ModelRatio
+	groupRatio := priceData.GroupRatio
+	modelPrice := priceData.ModelPrice
+	usePrice := priceData.UsePrice
 
 	quotaInfo := QuotaInfo{
 		InputDetails: TokenDetails{
@@ -197,7 +202,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 			TextTokens:  textOutTokens,
 			AudioTokens: audioOutTokens,
 		},
-		ModelName:  relayInfo.RecodeModelName,
+		ModelName:  relayInfo.OriginModelName,
 		UsePrice:   usePrice,
 		ModelRatio: modelRatio,
 		GroupRatio: groupRatio,
@@ -220,7 +225,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		quota = 0
 		logContent += fmt.Sprintf("(可能是上游超时)")
 		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
-			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.RecodeModelName, preConsumedQuota))
+			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
 	} else {
 		quotaDelta := quota - preConsumedQuota
 		if quotaDelta != 0 {
@@ -233,7 +238,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 	}
 
-	logModel := relayInfo.RecodeModelName
+	logModel := relayInfo.OriginModelName
 	if extraContent != "" {
 		logContent += ", " + extraContent
 	}

+ 86 - 14
web/src/components/LogsTable.js

@@ -15,7 +15,7 @@ import {
   Button, Descriptions,
   Form,
   Layout,
-  Modal,
+  Modal, Popover,
   Select,
   Space,
   Spin,
@@ -34,6 +34,7 @@ import {
 import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
 import { getLogOther } from '../helpers/other.js';
 import { StyleContext } from '../context/Style/index.js';
+import { IconInherit, IconRefresh } from '@douyinfe/semi-icons';
 
 const { Header } = Layout;
 
@@ -141,7 +142,78 @@ const LogsTable = () => {
         </Tag>
       );
     }
-  } 
+  }
+
+  function renderModelName(record) {
+
+    let other = getLogOther(record.other);
+    let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== '';
+    if (!modelMapped) {
+      return <Tag
+        color={stringToColor(record.model_name)}
+        size='large'
+        onClick={(event) => {
+          copyText(event, record.model_name).then(r => {});
+        }}
+      >
+        {' '}{record.model_name}{' '}
+      </Tag>;
+    } else {
+      return (
+        <>
+          <Space vertical align={'start'}>
+            <Popover content={
+              <div style={{padding: 10}}> 
+                <Space vertical align={'start'}>
+                  <Tag
+                    color={stringToColor(record.model_name)}
+                    size='large'
+                    onClick={(event) => {
+                      copyText(event, record.model_name).then(r => {});
+                    }}
+                  >
+                    {t('请求并计费模型')}{' '}{record.model_name}{' '}
+                  </Tag>
+                  <Tag
+                    color={stringToColor(other.upstream_model_name)}
+                    size='large'
+                    onClick={(event) => {
+                      copyText(event, other.upstream_model_name).then(r => {});
+                    }}
+                  >
+                    {t('实际模型')}{' '}{other.upstream_model_name}{' '}
+                  </Tag>
+                </Space>
+              </div>
+            }>
+              <Tag
+                color={stringToColor(record.model_name)}
+                size='large'
+                onClick={(event) => {
+                  copyText(event, record.model_name).then(r => {});
+                }}
+                suffixIcon={<IconRefresh />}
+              >
+                {' '}{record.model_name}{' '}
+              </Tag>
+            </Popover>
+            {/*<Tooltip content={t('实际模型')}>*/}
+            {/*  <Tag*/}
+            {/*    color={stringToColor(other.upstream_model_name)}*/}
+            {/*    size='large'*/}
+            {/*    onClick={(event) => {*/}
+            {/*      copyText(event, other.upstream_model_name).then(r => {});*/}
+            {/*    }}*/}
+            {/*  >*/}
+            {/*    {' '}{other.upstream_model_name}{' '}*/}
+            {/*  </Tag>*/}
+            {/*</Tooltip>*/}
+          </Space>
+        </>
+      );
+    }
+
+  }
 
   const columns = [
     {
@@ -272,18 +344,7 @@ const LogsTable = () => {
       dataIndex: 'model_name',
       render: (text, record, index) => {
         return record.type === 0 || record.type === 2 ? (
-          <>
-            <Tag
-              color={stringToColor(text)}
-              size='large'
-              onClick={(event) => {
-                copyText(event, text);
-              }}
-            >
-              {' '}
-              {text}{' '}
-            </Tag>
-          </>
+          <>{renderModelName(record)}</>
         ) : (
           <></>
         );
@@ -580,6 +641,17 @@ const LogsTable = () => {
         value: logs[i].content,
       });
       if (logs[i].type === 2) {
+        let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== '';
+        if (modelMapped) {
+          expandDataLocal.push({
+            key: t('请求并计费模型'),
+            value: logs[i].model_name,
+          });
+          expandDataLocal.push({
+            key: t('实际模型'),
+            value: other.upstream_model_name,
+          });
+        }
         let content = '';
         if (other?.ws || other?.audio) {
           content = renderAudioModelPrice(