Просмотр исходного кода

Merge pull request #2090 from feitianbubu/pr/doubao-image-edit

修复豆包图像编辑(图生图)功能
IcedTangerine 2 месяца назад
Родитель
Сommit
032f159509
4 изменённых файла: 172 добавлено и 107 удалено
  1. 60 0
      common/gin.go
  2. 2 1
      dto/openai_image.go
  3. 3 1
      middleware/distributor.go
  4. 107 105
      relay/channel/volcengine/adaptor.go

+ 60 - 0
common/gin.go

@@ -2,9 +2,11 @@ package common
 
 import (
 	"bytes"
+	"encoding/json"
 	"io"
 	"mime/multipart"
 	"net/http"
+	"net/url"
 	"strings"
 	"time"
 
@@ -40,6 +42,10 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
 	contentType := c.Request.Header.Get("Content-Type")
 	if strings.HasPrefix(contentType, "application/json") {
 		err = Unmarshal(requestBody, &v)
+	} else if strings.Contains(contentType, gin.MIMEPOSTForm) {
+		err = parseFormData(requestBody, &v)
+	} else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) {
+		err = parseMultipartFormData(c, requestBody, &v)
 	} else {
 		// skip for now
 		// TODO: someday non json request have variant model, we will need to implementation this
@@ -138,3 +144,57 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 	return form, nil
 }
+
+// parseFormData decodes a URL-encoded form body into v by way of JSON.
+// Single-valued keys become JSON strings; repeated keys become arrays.
+func parseFormData(data []byte, v any) error {
+	parsed, err := url.ParseQuery(string(data))
+	if err != nil {
+		return err
+	}
+	fields := make(map[string]any, len(parsed))
+	for name, items := range parsed {
+		fields[name] = items
+		if len(items) == 1 {
+			fields[name] = items[0]
+		}
+	}
+	encoded, err := json.Marshal(fields)
+	if err != nil {
+		return err
+	}
+	return json.Unmarshal(encoded, v)
+}
+
+// parseMultipartFormData decodes the text fields of a multipart/form-data
+// body into v via JSON. The boundary is read from the Content-Type header;
+// when none is found, data is decoded as plain JSON instead.
+func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
+	boundary := ""
+	if _, rest, ok := strings.Cut(c.Request.Header.Get("Content-Type"), "boundary="); ok {
+		// Stop at the next parameter and strip the optional RFC 2046 quotes.
+		rest, _, _ = strings.Cut(rest, ";")
+		boundary = strings.Trim(strings.TrimSpace(rest), `"`)
+	}
+	if boundary == "" {
+		return json.Unmarshal(data, v) // Fallback to JSON
+	}
+	reader := multipart.NewReader(bytes.NewReader(data), boundary)
+	form, err := reader.ReadForm(32 << 20) // 32 MB max memory
+	if err != nil {
+		return err
+	}
+	defer form.RemoveAll() // ReadForm may spill large parts to temp files
+	formMap := make(map[string]any, len(form.Value))
+	for key, vals := range form.Value {
+		formMap[key] = vals
+		if len(vals) == 1 {
+			formMap[key] = vals[0] // a lone value is stored as a plain string
+		}
+	}
+	jsonData, err := json.Marshal(formMap)
+	if err != nil {
+		return err
+	}
+	return json.Unmarshal(jsonData, v)
+}

+ 2 - 1
dto/openai_image.go

@@ -27,7 +27,8 @@ type ImageRequest struct {
 	OutputCompression json.RawMessage `json:"output_compression,omitempty"`
 	PartialImages     json.RawMessage `json:"partial_images,omitempty"`
 	// Stream            bool            `json:"stream,omitempty"`
-	Watermark *bool `json:"watermark,omitempty"`
+	Watermark *bool           `json:"watermark,omitempty"`
+	Image     json.RawMessage `json:"image,omitempty"`
 	// 用匿名参数接收额外参数
 	Extra map[string]json.RawMessage `json:"-"`
 }

+ 3 - 1
middleware/distributor.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"fmt"
 	"net/http"
+	"slices"
 	"strconv"
 	"strings"
 	"time"
@@ -245,7 +246,8 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
 		//modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
-		if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
+		contentType := c.ContentType()
+		if slices.Contains([]string{gin.MIMEPOSTForm, gin.MIMEMultipartPOSTForm}, contentType) {
 			modelRequest.Model = c.PostForm("model")
 		}
 	}

+ 107 - 105
relay/channel/volcengine/adaptor.go

@@ -6,9 +6,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"mime/multipart"
 	"net/http"
-	"net/textproto"
 	"path/filepath"
 	"strings"
 
@@ -104,106 +102,107 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 	switch info.RelayMode {
 	case constant.RelayModeImagesGenerations:
 		return request, nil
-	case constant.RelayModeImagesEdits:
-
-		var requestBody bytes.Buffer
-		writer := multipart.NewWriter(&requestBody)
-
-		writer.WriteField("model", request.Model)
-
-		formData := c.Request.PostForm
-		for key, values := range formData {
-			if key == "model" {
-				continue
-			}
-			for _, value := range values {
-				writer.WriteField(key, value)
-			}
-		}
-
-		if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
-			return nil, errors.New("failed to parse multipart form")
-		}
-
-		if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
-			var imageFiles []*multipart.FileHeader
-			var exists bool
-
-			if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
-				if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
-					foundArrayImages := false
-					for fieldName, files := range c.Request.MultipartForm.File {
-						if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
-							foundArrayImages = true
-							for _, file := range files {
-								imageFiles = append(imageFiles, file)
-							}
-						}
-					}
-
-					if !foundArrayImages && (len(imageFiles) == 0) {
-						return nil, errors.New("image is required")
-					}
-				}
-			}
-
-			for i, fileHeader := range imageFiles {
-				file, err := fileHeader.Open()
-				if err != nil {
-					return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
-				}
-				defer file.Close()
-
-				fieldName := "image"
-				if len(imageFiles) > 1 {
-					fieldName = "image[]"
-				}
-
-				mimeType := detectImageMimeType(fileHeader.Filename)
-
-				h := make(textproto.MIMEHeader)
-				h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
-				h.Set("Content-Type", mimeType)
-
-				part, err := writer.CreatePart(h)
-				if err != nil {
-					return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
-				}
-
-				if _, err := io.Copy(part, file); err != nil {
-					return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
-				}
-			}
-
-			if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
-				maskFile, err := maskFiles[0].Open()
-				if err != nil {
-					return nil, errors.New("failed to open mask file")
-				}
-				defer maskFile.Close()
-
-				mimeType := detectImageMimeType(maskFiles[0].Filename)
-
-				h := make(textproto.MIMEHeader)
-				h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
-				h.Set("Content-Type", mimeType)
-
-				maskPart, err := writer.CreatePart(h)
-				if err != nil {
-					return nil, errors.New("create form file failed for mask")
-				}
-
-				if _, err := io.Copy(maskPart, maskFile); err != nil {
-					return nil, errors.New("copy mask file failed")
-				}
-			}
-		} else {
-			return nil, errors.New("no multipart form data found")
-		}
-
-		writer.Close()
-		c.Request.Header.Set("Content-Type", writer.FormDataContentType())
-		return bytes.NewReader(requestBody.Bytes()), nil
+	// 根据官方文档,并没有发现豆包生图支持表单请求:https://www.volcengine.com/docs/82379/1824121
+	//case constant.RelayModeImagesEdits:
+	//
+	//	var requestBody bytes.Buffer
+	//	writer := multipart.NewWriter(&requestBody)
+	//
+	//	writer.WriteField("model", request.Model)
+	//
+	//	formData := c.Request.PostForm
+	//	for key, values := range formData {
+	//		if key == "model" {
+	//			continue
+	//		}
+	//		for _, value := range values {
+	//			writer.WriteField(key, value)
+	//		}
+	//	}
+	//
+	//	if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
+	//		return nil, errors.New("failed to parse multipart form")
+	//	}
+	//
+	//	if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
+	//		var imageFiles []*multipart.FileHeader
+	//		var exists bool
+	//
+	//		if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
+	//			if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
+	//				foundArrayImages := false
+	//				for fieldName, files := range c.Request.MultipartForm.File {
+	//					if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
+	//						foundArrayImages = true
+	//						for _, file := range files {
+	//							imageFiles = append(imageFiles, file)
+	//						}
+	//					}
+	//				}
+	//
+	//				if !foundArrayImages && (len(imageFiles) == 0) {
+	//					return nil, errors.New("image is required")
+	//				}
+	//			}
+	//		}
+	//
+	//		for i, fileHeader := range imageFiles {
+	//			file, err := fileHeader.Open()
+	//			if err != nil {
+	//				return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
+	//			}
+	//			defer file.Close()
+	//
+	//			fieldName := "image"
+	//			if len(imageFiles) > 1 {
+	//				fieldName = "image[]"
+	//			}
+	//
+	//			mimeType := detectImageMimeType(fileHeader.Filename)
+	//
+	//			h := make(textproto.MIMEHeader)
+	//			h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
+	//			h.Set("Content-Type", mimeType)
+	//
+	//			part, err := writer.CreatePart(h)
+	//			if err != nil {
+	//				return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
+	//			}
+	//
+	//			if _, err := io.Copy(part, file); err != nil {
+	//				return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
+	//			}
+	//		}
+	//
+	//		if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
+	//			maskFile, err := maskFiles[0].Open()
+	//			if err != nil {
+	//				return nil, errors.New("failed to open mask file")
+	//			}
+	//			defer maskFile.Close()
+	//
+	//			mimeType := detectImageMimeType(maskFiles[0].Filename)
+	//
+	//			h := make(textproto.MIMEHeader)
+	//			h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
+	//			h.Set("Content-Type", mimeType)
+	//
+	//			maskPart, err := writer.CreatePart(h)
+	//			if err != nil {
+	//				return nil, errors.New("create form file failed for mask")
+	//			}
+	//
+	//			if _, err := io.Copy(maskPart, maskFile); err != nil {
+	//				return nil, errors.New("copy mask file failed")
+	//			}
+	//		}
+	//	} else {
+	//		return nil, errors.New("no multipart form data found")
+	//	}
+	//
+	//	writer.Close()
+	//	c.Request.Header.Set("Content-Type", writer.FormDataContentType())
+	//	return bytes.NewReader(requestBody.Bytes()), nil
 
 	default:
 		return request, nil
@@ -251,10 +250,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 			return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
 		case constant.RelayModeEmbeddings:
 			return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil
-		case constant.RelayModeImagesGenerations:
+		//豆包的图生图也走generations接口: https://www.volcengine.com/docs/82379/1824121
+		case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
 			return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil
-		case constant.RelayModeImagesEdits:
-			return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
+		//case constant.RelayModeImagesEdits:
+		//	return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
 		case constant.RelayModeRerank:
 			return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
 		case constant.RelayModeAudioSpeech:
@@ -278,6 +278,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
 		}
 		req.Set("Content-Type", "application/json")
 		return nil
+	} else if info.RelayMode == constant.RelayModeImagesEdits {
+		req.Set("Content-Type", gin.MIMEJSON)
 	}
 
 	req.Set("Authorization", "Bearer "+info.ApiKey)