Forráskód Böngészése

Merge pull request #2006 from xyfacai/feat/sora-price

feat: sora 增加参数校验与计费
Calcium-Ion 2 hónapja
szülő
commit
eab768b4a0

+ 29 - 0
relay/common/relay_utils.go

@@ -10,6 +10,7 @@ import (
 	"strings"
 
 	"github.com/gin-gonic/gin"
+	"github.com/samber/lo"
 )
 
 type HasPrompt interface {
@@ -156,6 +157,34 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
 		action = constant.TaskActionGenerate
 	}
 	info.Action = action
+	model := form.Value["model"][0]
+	if strings.HasPrefix(model, "sora-2") {
+		seconds := 4
+		size := "720x1280"
+		if ss, ok := form.Value["seconds"]; ok {
+			sInt := common.String2Int(ss[0])
+			if sInt > seconds {
+				seconds = common.String2Int(ss[0])
+			}
+		}
+		if s, ok := form.Value["size"]; ok {
+			size = s[0]
+		}
+
+		if model == "sora-2" && !lo.Contains([]string{"720x1280", "1280x720"}, size) {
+			return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
+		}
+		if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
+			return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
+		}
+		info.PriceData.OtherRatios = map[string]float64{
+			"seconds": float64(seconds),
+			"size":    1,
+		}
+		if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
+			info.PriceData.OtherRatios["size"] = 1.666667
+		}
+	}
 
 	return nil
 }

+ 1 - 1
relay/helper/price.go

@@ -114,7 +114,7 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.
 	modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
 	// 如果没有配置价格,则使用默认价格
 	if !success {
-		defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName]
+		defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[info.OriginModelName]
 		if !ok {
 			modelPrice = 0.1
 		} else {

+ 19 - 1
relay/relay_task.go

@@ -54,7 +54,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
 	}
 	modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
 	if !success {
-		defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
+		defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName]
 		if !ok {
 			modelPrice = 0.1
 		} else {
@@ -71,6 +71,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
 	} else {
 		ratio = modelPrice * groupRatio
 	}
+	if len(info.PriceData.OtherRatios) > 0 {
+		for _, ra := range info.PriceData.OtherRatios {
+			if 1.0 != ra {
+				ratio *= ra
+			}
+		}
+	}
 	userQuota, err := model.GetUserQuota(info.UserId, false)
 	if err != nil {
 		taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
@@ -144,6 +151,17 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
 					gRatio = userGroupRatio
 				}
 				logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, info.Action)
+				if len(info.PriceData.OtherRatios) > 0 {
+					var contents []string
+					for key, ra := range info.PriceData.OtherRatios {
+						if 1.0 != ra {
+							contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
+						}
+					}
+					if len(contents) > 0 {
+						logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
+					}
+				}
 				other := make(map[string]interface{})
 				other["model_price"] = modelPrice
 				other["group_ratio"] = groupRatio

+ 6 - 0
setting/ratio_setting/model_ratio.go

@@ -290,6 +290,8 @@ var defaultModelPrice = map[string]float64{
 	"mj_upscale":              0.05,
 	"swap_face":               0.05,
 	"mj_upload":               0.05,
+	"sora-2":                  0.3,
+	"sora-2-pro":              0.5,
 }
 
 var defaultAudioRatio = map[string]float64{
@@ -452,6 +454,10 @@ func GetDefaultModelRatioMap() map[string]float64 {
 	return defaultModelRatio
 }
 
+func GetDefaultModelPriceMap() map[string]float64 {
+	return defaultModelPrice
+}
+
 func GetDefaultImageRatioMap() map[string]float64 {
 	return defaultImageRatio
 }

+ 1 - 0
types/price_data.go

@@ -17,6 +17,7 @@ type PriceData struct {
 	ImageRatio             float64
 	AudioRatio             float64
 	AudioCompletionRatio   float64
+	OtherRatios            map[string]float64
 	UsePrice               bool
 	ShouldPreConsumedQuota int
 	GroupRatioInfo         GroupRatioInfo