Просмотр исходного кода

支持设置模型按次计费

CaIon 2 лет назад
Родитель
Сommit
3475643257
6 измененных файлов с 87 добавлено и 17 удалено
  1. 29 0
      common/model-ratio.go
  2. 33 14
      controller/relay-text.go
  3. 1 1
      model/ability.go
  4. 1 1
      model/channel.go
  5. 3 0
      model/option.go
  6. 20 1
      web/src/components/OperationSetting.js

+ 29 - 0
common/model-ratio.go

@@ -77,6 +77,35 @@ var ModelRatio = map[string]float64{
 	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
 }
 
+var ModelPrice = map[string]float64{
+	"gpt-4-gizmo-*": 0.1,
+}
+
+func ModelPrice2JSONString() string {
+	jsonBytes, err := json.Marshal(ModelPrice)
+	if err != nil {
+		SysError("error marshalling model price: " + err.Error())
+	}
+	return string(jsonBytes)
+}
+
+func UpdateModelPriceByJSONString(jsonStr string) error {
+	ModelPrice = make(map[string]float64)
+	return json.Unmarshal([]byte(jsonStr), &ModelPrice)
+}
+
+func GetModelPrice(name string) float64 {
+	if strings.HasPrefix(name, "gpt-4-gizmo") {
+		name = "gpt-4-gizmo-*"
+	}
+	price, ok := ModelPrice[name]
+	if !ok {
+		//SysError("model price not found: " + name)
+		return -1
+	}
+	return price
+}
+
 func ModelRatio2JSONString() string {
 	jsonBytes, err := json.Marshal(ModelRatio)
 	if err != nil {

+ 33 - 14
controller/relay-text.go

@@ -231,14 +231,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	case RelayModeModerations:
 		promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
 	}
-	preConsumedTokens := common.PreConsumedQuota
-	if textRequest.MaxTokens != 0 {
-		preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
-	}
-	modelRatio := common.GetModelRatio(textRequest.Model)
+	modelPrice := common.GetModelPrice(textRequest.Model)
 	groupRatio := common.GetGroupRatio(group)
-	ratio := modelRatio * groupRatio
-	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
+
+	var preConsumedQuota int
+	var ratio float64
+	var modelRatio float64
+	if modelPrice == -1 {
+		preConsumedTokens := common.PreConsumedQuota
+		if textRequest.MaxTokens != 0 {
+			preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
+		}
+		modelRatio = common.GetModelRatio(textRequest.Model)
+		ratio = modelRatio * groupRatio
+		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
+	} else {
+		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+	}
+
 	userQuota, err := model.CacheGetUserQuota(userId)
 	if err != nil {
 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
@@ -447,15 +457,19 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	defer func(ctx context.Context) {
 		// c.Writer.Flush()
 		go func() {
-			quota := 0
-			completionRatio := common.GetCompletionRatio(textRequest.Model)
 			promptTokens = textResponse.Usage.PromptTokens
 			completionTokens = textResponse.Usage.CompletionTokens
 
-			quota = promptTokens + int(float64(completionTokens)*completionRatio)
-			quota = int(float64(quota) * ratio)
-			if ratio != 0 && quota <= 0 {
-				quota = 1
+			quota := 0
+			if modelPrice == -1 {
+				completionRatio := common.GetCompletionRatio(textRequest.Model)
+				quota = promptTokens + int(float64(completionTokens)*completionRatio)
+				quota = int(float64(quota) * ratio)
+				if ratio != 0 && quota <= 0 {
+					quota = 1
+				}
+			} else {
+				quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
 			}
 			totalTokens := promptTokens + completionTokens
 			if totalTokens == 0 {
@@ -474,7 +488,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			// record all the consume log even if quota is 0
 			useTimeSeconds := time.Now().Unix() - startTime.Unix()
-			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds)
+			var logContent string
+			if modelPrice == -1 {
+				logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds)
+			} else {
+				logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f,用时 %d秒", modelPrice, groupRatio, useTimeSeconds)
+			}
 			model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId, userQuota)
 			model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 			model.UpdateChannelUsedQuota(channelId, quota)

+ 1 - 1
model/ability.go

@@ -6,7 +6,7 @@ import (
 )
 
 type Ability struct {
-	Group     string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"`
+	Group     string `json:"group" gorm:"type:varchar(255);primaryKey;autoIncrement:false"`
 	Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"`
 	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
 	Enabled   bool   `json:"enabled"`

+ 1 - 1
model/channel.go

@@ -21,7 +21,7 @@ type Channel struct {
 	Balance            float64 `json:"balance"` // in USD
 	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"`
 	Models             string  `json:"models"`
-	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"`
+	Group              string  `json:"group" gorm:"type:varchar(255);default:'default'"`
 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"`
 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"`

+ 3 - 0
model/option.go

@@ -70,6 +70,7 @@ func InitOptionMap() {
 	common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
 	common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
 	common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
+	common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
 	common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
 	common.OptionMap["TopUpLink"] = common.TopUpLink
 	common.OptionMap["ChatLink"] = common.ChatLink
@@ -220,6 +221,8 @@ func updateOptionMap(key string, value string) (err error) {
 		err = common.UpdateModelRatioByJSONString(value)
 	case "GroupRatio":
 		err = common.UpdateGroupRatioByJSONString(value)
+	case "ModelPrice":
+		err = common.UpdateModelPriceByJSONString(value)
 	case "TopUpLink":
 		common.TopUpLink = value
 	case "ChatLink":

+ 20 - 1
web/src/components/OperationSetting.js

@@ -10,6 +10,7 @@ const OperationSetting = () => {
         QuotaRemindThreshold: 0,
         PreConsumedQuota: 0,
         ModelRatio: '',
+        ModelPrice: '',
         GroupRatio: '',
         TopUpLink: '',
         ChatLink: '',
@@ -30,7 +31,7 @@ const OperationSetting = () => {
     if (success) {
       let newInputs = {};
       data.forEach((item) => {
-        if (item.key === 'ModelRatio' || item.key === 'GroupRatio') {
+        if (item.key === 'ModelRatio' || item.key === 'GroupRatio'|| item.key === 'ModelPrice') {
           item.value = JSON.stringify(JSON.parse(item.value), null, 2);
         }
         newInputs[item.key] = item.value;
@@ -97,6 +98,13 @@ const OperationSetting = () => {
           }
           await updateOption('GroupRatio', inputs.GroupRatio);
         }
+          if (originInputs['ModelPrice'] !== inputs.ModelPrice) {
+              if (!verifyJSON(inputs.ModelPrice)) {
+                  showError('模型固定价格不是合法的 JSON 字符串');
+                  return;
+              }
+              await updateOption('ModelPrice', inputs.ModelPrice);
+          }
         break;
       case 'quota':
         if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
@@ -315,6 +323,17 @@ const OperationSetting = () => {
                     <Header as='h3'>
                         倍率设置
                     </Header>
+                    <Form.Group widths='equal'>
+                        <Form.TextArea
+                            label='模型固定价格(一次调用消耗多少刀,优先级大于模型倍率)'
+                            name='ModelPrice'
+                            onChange={handleInputChange}
+                            style={{minHeight: 250, fontFamily: 'JetBrains Mono, Consolas'}}
+                            autoComplete='new-password'
+                            value={inputs.ModelPrice}
+                            placeholder='为一个 JSON 文本,键为模型名称,值为一次调用消耗多少刀,比如 "gpt-4-gizmo-*": 0.1,一次消耗0.1刀'
+                        />
+                    </Form.Group>
                     <Form.Group widths='equal'>
                         <Form.TextArea
                             label='模型倍率'