浏览代码

feat: add temporary TASK_PRICE_PATCH configuration to environment variables

CaIon 2 月之前
父节点
当前提交
43f2a8ac06
共有 3 个文件被更改,包括 37 次插入12 次删除
  1. 14 0
      common/init.go
  2. 3 0
      constant/env.go
  3. 20 12
      relay/relay_task.go

+ 14 - 0
common/init.go

@@ -7,6 +7,7 @@ import (
 	"os"
 	"path/filepath"
 	"strconv"
+	"strings"
 	"time"
 
 	"github.com/QuantumNous/new-api/constant"
@@ -118,4 +119,17 @@ func initConstantEnv() {
 	constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
 	// 是否启用错误日志
 	constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
+
+	soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
+	if soraPatchStr != "" {
+		var taskPricePatches []string
+		soraPatches := strings.Split(soraPatchStr, ",")
+		for _, patch := range soraPatches {
+			trimmedPatch := strings.TrimSpace(patch)
+			if trimmedPatch != "" {
+				taskPricePatches = append(taskPricePatches, trimmedPatch)
+			}
+		}
+		constant.TaskPricePatches = taskPricePatches
+	}
 }

+ 3 - 0
constant/env.go

@@ -13,3 +13,6 @@ var NotifyLimitCount int
 var NotificationLimitDurationMinute int
 var GenerateDefaultToken bool
 var ErrorLogEnabled bool
+
+// temporary variable for sora patch, will be removed in future
+var TaskPricePatches []string

+ 20 - 12
relay/relay_task.go

@@ -72,10 +72,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
 	} else {
 		ratio = modelPrice * groupRatio
 	}
-	if len(info.PriceData.OtherRatios) > 0 {
-		for _, ra := range info.PriceData.OtherRatios {
-			if 1.0 != ra {
-				ratio *= ra
+	// FIXME: 临时修补,支持任务仅按次计费
+	if !common.StringsContains(constant.TaskPricePatches, modelName) {
+		if len(info.PriceData.OtherRatios) > 0 {
+			for _, ra := range info.PriceData.OtherRatios {
+				if 1.0 != ra {
+					ratio *= ra
+				}
 			}
 		}
 	}
@@ -153,15 +156,20 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
 				//	gRatio = userGroupRatio
 				//}
 				logContent := fmt.Sprintf("操作 %s", info.Action)
-				if len(info.PriceData.OtherRatios) > 0 {
-					var contents []string
-					for key, ra := range info.PriceData.OtherRatios {
-						if 1.0 != ra {
-							contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
+				// FIXME: 临时修补,支持任务仅按次计费
+				if common.StringsContains(constant.TaskPricePatches, modelName) {
+					logContent = fmt.Sprintf("%s,按次计费", logContent)
+				} else {
+					if len(info.PriceData.OtherRatios) > 0 {
+						var contents []string
+						for key, ra := range info.PriceData.OtherRatios {
+							if 1.0 != ra {
+								contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
+							}
+						}
+						if len(contents) > 0 {
+							logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
 						}
-					}
-					if len(contents) > 0 {
-						logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
 					}
 				}
 				other := make(map[string]interface{})