Jelajahi Sumber

feat: add wan2.5-i2i-preview support

feitianbubu 1 bulan lalu
induk
melakukan
344a799fcf

+ 16 - 2
relay/channel/ali/adaptor.go

@@ -47,7 +47,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		case constant.RelayModeImagesGenerations:
 			fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
 		case constant.RelayModeImagesEdits:
-			fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
+			if isWanModel(info.OriginModelName) {
+				fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image2image/image-synthesis", info.ChannelBaseUrl)
+			} else {
+				fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
+			}
 		case constant.RelayModeCompletions:
 			fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl)
 		default:
@@ -71,6 +75,9 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
 		req.Set("X-DashScope-Async", "enable")
 	}
 	if info.RelayMode == constant.RelayModeImagesEdits {
+		if isWanModel(info.OriginModelName) {
+			req.Set("X-DashScope-Async", "enable")
+		}
 		req.Set("Content-Type", "application/json")
 	}
 	return nil
@@ -107,6 +114,9 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 		}
 		return aliRequest, nil
 	} else if info.RelayMode == constant.RelayModeImagesEdits {
+		if isWanModel(info.OriginModelName) {
+			return oaiFormEdit2WanxImageEdit(c, info, request)
+		}
 		// ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416
 		// 如果用户使用表单,则需要解析表单数据
 		if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
@@ -161,7 +171,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		case constant.RelayModeImagesGenerations:
 			err, usage = aliImageHandler(c, resp, info)
 		case constant.RelayModeImagesEdits:
-			err, usage = aliImageEditHandler(c, resp, info)
+			if isWanModel(info.OriginModelName) {
+				err, usage = aliImageHandler(c, resp, info)
+			} else {
+				err, usage = aliImageEditHandler(c, resp, info)
+			}
 		case constant.RelayModeRerank:
 			err, usage = RerankHandler(c, resp, info)
 		default:

+ 13 - 0
relay/channel/ali/dto.go

@@ -112,6 +112,19 @@ type AliImageInput struct {
 	Messages       []AliMessage `json:"messages,omitempty"`
 }
 
+type WanImageInput struct {
+	Prompt         string   `json:"prompt"`                    // 必需:文本提示词,描述生成图像中期望包含的元素和视觉特点
+	Images         []string `json:"images"`                    // 必需:图像URL数组,长度不超过2,支持HTTP/HTTPS URL或Base64编码
+	NegativePrompt string   `json:"negative_prompt,omitempty"` // 可选:反向提示词,描述不希望在画面中看到的内容
+}
+
+type WanImageParameters struct {
+	N         int     `json:"n,omitempty"`         // 生成图片数量,取值范围1-4,默认4
+	Watermark *bool   `json:"watermark,omitempty"` // 是否添加水印标识,默认false
+	Seed      int     `json:"seed,omitempty"`      // 随机数种子,取值范围[0, 2147483647]
+	Strength  float64 `json:"strength,omitempty"`  // 修改幅度 0.0-1.0,默认0.5(部分模型支持)
+}
+
 type AliRerankParameters struct {
 	TopN            *int  `json:"top_n,omitempty"`
 	ReturnDocuments *bool `json:"return_documents,omitempty"`

+ 12 - 5
relay/channel/ali/image.go

@@ -58,11 +58,7 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
 	return &imageRequest, nil
 }
 
-func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
-	var imageRequest AliImageRequest
-	imageRequest.Model = request.Model
-	imageRequest.ResponseFormat = request.ResponseFormat
-
+func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
 	mf := c.Request.MultipartForm
 	if mf == nil {
 		if _, err := c.MultipartForm(); err != nil {
@@ -127,7 +123,18 @@ func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, reque
 		imageBase64s = append(imageBase64s, dataURL)
 		image.Close()
 	}
+	return imageBase64s, nil
+}
 
+func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
+	var imageRequest AliImageRequest
+	imageRequest.Model = request.Model
+	imageRequest.ResponseFormat = request.ResponseFormat
+
+	imageBase64s, err := getImageBase64sFromForm(c, "image")
+	if err != nil {
+		return nil, fmt.Errorf("get image base64s from form failed: %w", err)
+	}
 	//dto.MediaContent{}
 	mediaContents := make([]AliMediaContent, len(imageBase64s))
 	for i, b64 := range imageBase64s {

+ 39 - 0
relay/channel/ali/image_wan.go

@@ -0,0 +1,39 @@
+package ali
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/dto"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+
+	"github.com/gin-gonic/gin"
+)
+
+func oaiFormEdit2WanxImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
+	var err error
+	var imageRequest AliImageRequest
+	imageRequest.Model = request.Model
+	imageRequest.ResponseFormat = request.ResponseFormat
+	wanInput := WanImageInput{
+		Prompt: request.Prompt,
+	}
+
+	if err := common.UnmarshalBodyReusable(c, &wanInput); err != nil {
+		return nil, err
+	}
+	if wanInput.Images, err = getImageBase64sFromForm(c, "image"); err != nil {
+		return nil, fmt.Errorf("get image base64s from form failed: %w", err)
+	}
+	wanParams := WanImageParameters{
+		N: int(request.N),
+	}
+	imageRequest.Input = wanInput
+	imageRequest.Parameters = wanParams
+	return &imageRequest, nil
+}
+
+func isWanModel(modelName string) bool {
+	return strings.Contains(modelName, "wan")
+}