Browse Source

Merge branch 'feat_images' of github.com:bddiudiu/new-api into bddiudiu-feat_images

creamlike1024 6 months ago
parent
commit
7577ec1ac4
3 changed files with 156 additions and 2 deletions
  1. 1 0
      dto/dalle.go
  2. 148 2
      relay/channel/volcengine/adaptor.go
  3. 7 0
      relay/image_handler.go

+ 1 - 0
dto/dalle.go

@@ -15,6 +15,7 @@ type ImageRequest struct {
 	Background     string          `json:"background,omitempty"`
 	Moderation     string          `json:"moderation,omitempty"`
 	OutputFormat   string          `json:"output_format,omitempty"`
+	Watermark      *bool           `json:"watermark,omitempty"`
 }
 
 type ImageResponse struct {

+ 148 - 2
relay/channel/volcengine/adaptor.go

@@ -1,15 +1,19 @@
 package volcengine
 
 import (
+	"bytes"
 	"errors"
 	"fmt"
 	"io"
+	"mime/multipart"
 	"net/http"
+	"net/textproto"
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
+	"path/filepath"
 	"strings"
 
 	"github.com/gin-gonic/gin"
@@ -30,8 +34,146 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 }
 
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	switch info.RelayMode {
+	case constant.RelayModeImagesEdits:
+
+		var requestBody bytes.Buffer
+		writer := multipart.NewWriter(&requestBody)
+
+		writer.WriteField("model", request.Model)
+		// 获取所有表单字段
+		formData := c.Request.PostForm
+		// 遍历表单字段并打印输出
+		for key, values := range formData {
+			if key == "model" {
+				continue
+			}
+			for _, value := range values {
+				writer.WriteField(key, value)
+			}
+		}
+
+		// Parse the multipart form to handle both single image and multiple images
+		if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
+			return nil, errors.New("failed to parse multipart form")
+		}
+
+		if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
+			// Check if "image" field exists in any form, including array notation
+			var imageFiles []*multipart.FileHeader
+			var exists bool
+
+			// First check for standard "image" field
+			if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
+				// If not found, check for "image[]" field
+				if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
+					// If still not found, iterate through all fields to find any that start with "image["
+					foundArrayImages := false
+					for fieldName, files := range c.Request.MultipartForm.File {
+						if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
+							foundArrayImages = true
+							for _, file := range files {
+								imageFiles = append(imageFiles, file)
+							}
+						}
+					}
+
+					// If no image fields found at all
+					if !foundArrayImages && (len(imageFiles) == 0) {
+						return nil, errors.New("image is required")
+					}
+				}
+			}
+
+			// Process all image files
+			for i, fileHeader := range imageFiles {
+				file, err := fileHeader.Open()
+				if err != nil {
+					return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
+				}
+				defer file.Close()
+
+				// If multiple images, use image[] as the field name
+				fieldName := "image"
+				if len(imageFiles) > 1 {
+					fieldName = "image[]"
+				}
+
+				// Determine MIME type based on file extension
+				mimeType := detectImageMimeType(fileHeader.Filename)
+
+				// Create a form file with the appropriate content type
+				h := make(textproto.MIMEHeader)
+				h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
+				h.Set("Content-Type", mimeType)
+
+				part, err := writer.CreatePart(h)
+				if err != nil {
+					return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
+				}
+
+				if _, err := io.Copy(part, file); err != nil {
+					return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
+				}
+			}
+
+			// Handle mask file if present
+			if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
+				maskFile, err := maskFiles[0].Open()
+				if err != nil {
+					return nil, errors.New("failed to open mask file")
+				}
+				defer maskFile.Close()
+
+				// Determine MIME type for mask file
+				mimeType := detectImageMimeType(maskFiles[0].Filename)
+
+				// Create a form file with the appropriate content type
+				h := make(textproto.MIMEHeader)
+				h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
+				h.Set("Content-Type", mimeType)
+
+				maskPart, err := writer.CreatePart(h)
+				if err != nil {
+					return nil, errors.New("create form file failed for mask")
+				}
+
+				if _, err := io.Copy(maskPart, maskFile); err != nil {
+					return nil, errors.New("copy mask file failed")
+				}
+			}
+		} else {
+			return nil, errors.New("no multipart form data found")
+		}
+
+		// 关闭 multipart 编写器以设置分界线
+		writer.Close()
+		c.Request.Header.Set("Content-Type", writer.FormDataContentType())
+		return bytes.NewReader(requestBody.Bytes()), nil
+
+	default:
+		return request, nil
+	}
+}
+
+// detectImageMimeType determines the MIME type based on the file extension
+func detectImageMimeType(filename string) string {
+	ext := strings.ToLower(filepath.Ext(filename))
+	switch ext {
+	case ".jpg", ".jpeg":
+		return "image/jpeg"
+	case ".png":
+		return "image/png"
+	case ".webp":
+		return "image/webp"
+	default:
+		// Try to detect from extension if possible
+		if strings.HasPrefix(ext, ".jp") {
+			return "image/jpeg"
+		}
+		// Default to png as a fallback
+		return "image/png"
+	}
 }
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -46,6 +188,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
 	case constant.RelayModeEmbeddings:
 		return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
+	case constant.RelayModeImagesGenerations:
+		return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil
 	default:
 	}
 	return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
@@ -91,6 +235,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		}
 	case constant.RelayModeEmbeddings:
 		err, usage = openai.OpenaiHandler(c, resp, info)
+	case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
+		err, usage = openai.OpenaiHandlerWithUsage(c, resp, info)
 	}
 	return
 }

+ 7 - 0
relay/image_handler.go

@@ -17,6 +17,8 @@ import (
 	"one-api/setting"
 	"strings"
 
+	"one-api/relay/constant"
+
 	"github.com/gin-gonic/gin"
 )
 
@@ -44,6 +46,11 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 		if imageRequest.N == 0 {
 			imageRequest.N = 1
 		}
+
+		if info.ApiType == constant.APITypeVolcEngine {
+			watermark := formData.Has("watermark")
+			imageRequest.Watermark = &watermark
+		}
 	default:
 		err := common.UnmarshalBodyReusable(c, imageRequest)
 		if err != nil {