Browse Source

refactor: Improve channel testing and model price handling

[email protected] 10 months ago
parent
commit
d042a1bd55
2 changed files with 27 additions and 32 deletions
  1. 23 32
      controller/channel-test.go
  2. 4 0
      relay/helper/price.go

+ 23 - 32
controller/channel-test.go

@@ -17,8 +17,8 @@ import (
 	"one-api/relay"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
+	"one-api/relay/helper"
 	"one-api/service"
-	"one-api/setting"
 	"strconv"
 	"strings"
 	"sync"
@@ -73,18 +73,6 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 		}
 	}
 
-	modelMapping := *channel.ModelMapping
-	if modelMapping != "" && modelMapping != "{}" {
-		modelMap := make(map[string]string)
-		err := json.Unmarshal([]byte(modelMapping), &modelMap)
-		if err != nil {
-			return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-		}
-		if modelMap[testModel] != "" {
-			testModel = modelMap[testModel]
-		}
-	}
-
 	cache, err := model.GetUserCache(1)
 	if err != nil {
 		return err, nil
@@ -98,7 +86,13 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 
 	middleware.SetupContextForSelectedChannel(c, channel, testModel)
 
-	meta := relaycommon.GenRelayInfo(c)
+	info := relaycommon.GenRelayInfo(c)
+
+	err = helper.ModelMappedHelper(c, info)
+	if err != nil {
+		return err, nil
+	}
+
 	apiType, _ := constant.ChannelType2APIType(channel.Type)
 	adaptor := relay.GetAdaptor(apiType)
 	if adaptor == nil {
@@ -106,12 +100,12 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	}
 
 	request := buildTestRequest(testModel)
-	meta.UpstreamModelName = testModel
-	common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta))
+	info.OriginModelName = testModel
+	common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info))
 
-	adaptor.Init(meta)
+	adaptor.Init(info)
 
-	convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
+	convertedRequest, err := adaptor.ConvertRequest(c, info, request)
 	if err != nil {
 		return err, nil
 	}
@@ -121,7 +115,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	}
 	requestBody := bytes.NewBuffer(jsonData)
 	c.Request.Body = io.NopCloser(requestBody)
-	resp, err := adaptor.DoRequest(c, meta, requestBody)
+	resp, err := adaptor.DoRequest(c, info, requestBody)
 	if err != nil {
 		return err, nil
 	}
@@ -133,7 +127,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 			return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
 		}
 	}
-	usageA, respErr := adaptor.DoResponse(c, httpResp, meta)
+	usageA, respErr := adaptor.DoResponse(c, httpResp, info)
 	if respErr != nil {
 		return fmt.Errorf("%s", respErr.Error.Message), respErr
 	}
@@ -146,27 +140,24 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	if err != nil {
 		return err, nil
 	}
-	modelPrice, usePrice := setting.GetModelPrice(testModel, false)
-	modelRatio, success := setting.GetModelRatio(testModel)
-	if !usePrice && !success {
-		return fmt.Errorf("模型 %s 倍率和价格均未设置,请设置或者开启自用模式", testModel), nil
+	priceData, err := helper.ModelPriceHelper(c, info, usage.PromptTokens, int(request.MaxTokens))
+	if err != nil {
+		return err, nil
 	}
-	completionRatio := setting.GetCompletionRatio(testModel)
-	ratio := modelRatio
 	quota := 0
-	if !usePrice {
-		quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio))
-		quota = int(math.Round(float64(quota) * ratio))
-		if ratio != 0 && quota <= 0 {
+	if !priceData.UsePrice {
+		quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
+		quota = int(math.Round(float64(quota) * priceData.ModelRatio))
+		if priceData.ModelRatio != 0 && quota <= 0 {
 			quota = 1
 		}
 	} else {
-		quota = int(modelPrice * common.QuotaPerUnit)
+		quota = int(priceData.ModelPrice * common.QuotaPerUnit)
 	}
 	tok := time.Now()
 	milliseconds := tok.Sub(tik).Milliseconds()
 	consumedTime := float64(milliseconds) / 1000.0
-	other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
+	other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio, priceData.ModelPrice)
 	model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试",
 		quota, "模型测试", 0, quota, int(consumedTime), false, "default", other)
 	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))

+ 4 - 0
relay/helper/price.go

@@ -11,6 +11,7 @@ import (
 type PriceData struct {
 	ModelPrice             float64
 	ModelRatio             float64
+	CompletionRatio        float64
 	GroupRatio             float64
 	UsePrice               bool
 	ShouldPreConsumedQuota int
@@ -21,6 +22,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 	groupRatio := setting.GetGroupRatio(info.Group)
 	var preConsumedQuota int
 	var modelRatio float64
+	var completionRatio float64
 	if !usePrice {
 		preConsumedTokens := common.PreConsumedQuota
 		if maxTokens != 0 {
@@ -35,6 +37,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 				return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置;Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName)
 			}
 		}
+		completionRatio = setting.GetCompletionRatio(info.OriginModelName)
 		ratio := modelRatio * groupRatio
 		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
 	} else {
@@ -43,6 +46,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 	return PriceData{
 		ModelPrice:             modelPrice,
 		ModelRatio:             modelRatio,
+		CompletionRatio:        completionRatio,
 		GroupRatio:             groupRatio,
 		UsePrice:               usePrice,
 		ShouldPreConsumedQuota: preConsumedQuota,