Răsfoiți Sursa

feat: enhance Ali video request processing with resolution mapping and size validation

CaIon 2 luni în urmă
părinte
comite
424325162e
1 a modificat fișierele cu 106 adăugiri și 20 ștergeri
  1. 106 20
      relay/channel/task/ali/adaptor.go

+ 106 - 20
relay/channel/task/ali/adaptor.go

@@ -15,6 +15,7 @@ import (
 	"github.com/QuantumNous/new-api/relay/channel"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
+	"github.com/samber/lo"
 
 	"github.com/gin-gonic/gin"
 	"github.com/pkg/errors"
@@ -155,8 +156,51 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	return bytes.NewReader(bodyBytes), nil
 }
 
-func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) {
-	otherRatios := map[string]map[string]float64{
+var (
+	size480p = []string{
+		"832*480",
+		"480*832",
+		"624*624",
+	}
+	size720p = []string{
+		"1280*720",
+		"720*1280",
+		"960*960",
+		"1088*832",
+		"832*1088",
+	}
+	size1080p = []string{
+		"1920*1080",
+		"1080*1920",
+		"1440*1440",
+		"1632*1248",
+		"1248*1632",
+	}
+)
+
+func sizeToResolution(size string) (string, error) {
+	if lo.Contains(size480p, size) {
+		return "480P", nil
+	} else if lo.Contains(size720p, size) {
+		return "720P", nil
+	} else if lo.Contains(size1080p, size) {
+		return "1080P", nil
+	}
+	return "", fmt.Errorf("invalid size: %s", size)
+}
+
+func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error) {
+	otherRatios := make(map[string]float64)
+	aliRatios := map[string]map[string]float64{
+		"wan2.5-t2v-preview": {
+			"480P":  1,
+			"720P":  2,
+			"1080P": 1 / 0.3,
+		},
+		"wan2.2-t2v-plus": {
+			"480P":  1,
+			"1080P": 0.7 / 0.14,
+		},
 		"wan2.5-i2v-preview": {
 			"480P":  1,
 			"720P":  2,
@@ -180,6 +224,30 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
 			"720P": 0.9 / 0.5,
 		},
 	}
+	var resolution string
+
+	// size match
+	if aliReq.Parameters.Size != "" {
+		toResolution, err := sizeToResolution(aliReq.Parameters.Size)
+		if err != nil {
+			return nil, err
+		}
+		resolution = toResolution
+	} else {
+		resolution = strings.ToUpper(aliReq.Parameters.Resolution)
+		if !strings.HasSuffix(resolution, "P") {
+			resolution = resolution + "P"
+		}
+	}
+	if otherRatio, ok := aliRatios[aliReq.Model]; ok {
+		if ratio, ok := otherRatio[resolution]; ok {
+			otherRatios[fmt.Sprintf("resolution-%s", resolution)] = ratio
+		}
+	}
+	return otherRatios, nil
+}
+
+func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) {
 	aliReq := &AliVideoRequest{
 		Model: req.Model,
 		Input: AliVideoInput{
@@ -194,22 +262,40 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
 
 	// 处理分辨率映射
 	if req.Size != "" {
-		resolution := strings.ToUpper(req.Size)
-		// 支持 480p, 720p, 1080p 或 480P, 720P, 1080P
-		if !strings.HasSuffix(resolution, "P") {
-			resolution = resolution + "P"
+		// text to video size must be contained *
+		if strings.Contains(req.Model, "t2v") && !strings.Contains(req.Size, "*") {
+			return nil, fmt.Errorf("invalid size: %s, example: %s", req.Size, "1920*1080")
+		}
+		if strings.Contains(req.Size, "*") {
+			aliReq.Parameters.Size = req.Size
+		} else {
+			resolution := strings.ToUpper(req.Size)
+			// 支持 480p, 720p, 1080p 或 480P, 720P, 1080P
+			if !strings.HasSuffix(resolution, "P") {
+				resolution = resolution + "P"
+			}
+			aliReq.Parameters.Resolution = resolution
 		}
-		aliReq.Parameters.Resolution = resolution
 	} else {
 		// 根据模型设置默认分辨率
-		if strings.HasPrefix(req.Model, "wan2.5") {
-			aliReq.Parameters.Resolution = "1080P"
-		} else if strings.HasPrefix(req.Model, "wan2.2-i2v-flash") {
-			aliReq.Parameters.Resolution = "720P"
-		} else if strings.HasPrefix(req.Model, "wan2.2-i2v-plus") {
-			aliReq.Parameters.Resolution = "1080P"
+		if strings.Contains(req.Model, "t2v") { // image to video
+			if strings.HasPrefix(req.Model, "wan2.5") {
+				aliReq.Parameters.Size = "1920*1080"
+			} else if strings.HasPrefix(req.Model, "wan2.2") {
+				aliReq.Parameters.Size = "1920*1080"
+			} else {
+				aliReq.Parameters.Size = "1280*720"
+			}
 		} else {
-			aliReq.Parameters.Resolution = "720P"
+			if strings.HasPrefix(req.Model, "wan2.5") {
+				aliReq.Parameters.Resolution = "1080P"
+			} else if strings.HasPrefix(req.Model, "wan2.2-i2v-flash") {
+				aliReq.Parameters.Resolution = "720P"
+			} else if strings.HasPrefix(req.Model, "wan2.2-i2v-plus") {
+				aliReq.Parameters.Resolution = "1080P"
+			} else {
+				aliReq.Parameters.Resolution = "720P"
+			}
 		}
 	}
 
@@ -247,13 +333,13 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
 		"seconds": float64(aliReq.Parameters.Duration),
 	}
 
-	if otherRatio, ok := otherRatios[req.Model]; ok {
-		if ratio, ok := otherRatio[aliReq.Parameters.Resolution]; ok {
-			info.PriceData.OtherRatios[fmt.Sprintf("resolution-%s", aliReq.Parameters.Resolution)] = ratio
-		}
+	ratios, err := ProcessAliOtherRatios(aliReq)
+	if err != nil {
+		return nil, err
+	}
+	for s, f := range ratios {
+		info.PriceData.OtherRatios[s] = f
 	}
-
-	// println(fmt.Sprintf("other ratios: %v", info.PriceData.OtherRatios))
 
 	return aliReq, nil
 }