|
|
@@ -1,9 +1,12 @@
|
|
|
package ali
|
|
|
|
|
|
import (
|
|
|
+ "context"
|
|
|
+ "encoding/base64"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
+ "mime/multipart"
|
|
|
"net/http"
|
|
|
"one-api/common"
|
|
|
"one-api/dto"
|
|
|
@@ -21,7 +24,7 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
|
|
var imageRequest AliImageRequest
|
|
|
imageRequest.Model = request.Model
|
|
|
imageRequest.ResponseFormat = request.ResponseFormat
|
|
|
-
|
|
|
+ logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
|
|
|
if request.Extra != nil {
|
|
|
if val, ok := request.Extra["parameters"]; ok {
|
|
|
err := common.Unmarshal(val, &imageRequest.Parameters)
|
|
|
@@ -54,6 +57,100 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
|
|
return &imageRequest, nil
|
|
|
}
|
|
|
|
|
|
+func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
|
|
|
+ var imageRequest AliImageRequest
|
|
|
+ imageRequest.Model = request.Model
|
|
|
+ imageRequest.ResponseFormat = request.ResponseFormat
|
|
|
+
|
|
|
+ mf := c.Request.MultipartForm
|
|
|
+ if mf == nil {
|
|
|
+ if _, err := c.MultipartForm(); err != nil {
|
|
|
+ return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
|
|
|
+ }
|
|
|
+ mf = c.Request.MultipartForm
|
|
|
+ }
|
|
|
+
|
|
|
+ var imageFiles []*multipart.FileHeader
|
|
|
+ var exists bool
|
|
|
+
|
|
|
+ // First check for standard "image" field
|
|
|
+ if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 {
|
|
|
+ // If not found, check for "image[]" field
|
|
|
+ if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 {
|
|
|
+ // If still not found, iterate through all fields to find any that start with "image["
|
|
|
+ foundArrayImages := false
|
|
|
+ for fieldName, files := range mf.File {
|
|
|
+ if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
|
|
|
+ foundArrayImages = true
|
|
|
+ imageFiles = append(imageFiles, files...)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // If no image fields found at all
|
|
|
+ if !foundArrayImages && (len(imageFiles) == 0) {
|
|
|
+ return nil, errors.New("image is required")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(imageFiles) == 0 {
|
|
|
+ return nil, errors.New("image is required")
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(imageFiles) > 1 {
|
|
|
+ return nil, errors.New("only one image is supported for qwen edit")
|
|
|
+ }
|
|
|
+
|
|
|
+ // 获取base64编码的图片
|
|
|
+ var imageBase64s []string
|
|
|
+ for _, file := range imageFiles {
|
|
|
+ image, err := file.Open()
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.New("failed to open image file")
|
|
|
+ }
|
|
|
+
|
|
|
+ // 读取文件内容
|
|
|
+ imageData, err := io.ReadAll(image)
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.New("failed to read image file")
|
|
|
+ }
|
|
|
+
|
|
|
+ // 获取MIME类型
|
|
|
+ mimeType := http.DetectContentType(imageData)
|
|
|
+
|
|
|
+ // 编码为base64
|
|
|
+ base64Data := base64.StdEncoding.EncodeToString(imageData)
|
|
|
+
|
|
|
+ // 构造data URL格式
|
|
|
+ dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data)
|
|
|
+ imageBase64s = append(imageBase64s, dataURL)
|
|
|
+ image.Close()
|
|
|
+ }
|
|
|
+
|
|
|
+ //dto.MediaContent{}
|
|
|
+ mediaContents := make([]AliMediaContent, len(imageBase64s))
|
|
|
+ for i, b64 := range imageBase64s {
|
|
|
+ mediaContents[i] = AliMediaContent{
|
|
|
+ Image: b64,
|
|
|
+ }
|
|
|
+ }
|
|
|
+ mediaContents = append(mediaContents, AliMediaContent{
|
|
|
+ Text: request.Prompt,
|
|
|
+ })
|
|
|
+ imageRequest.Input = AliImageInput{
|
|
|
+ Messages: []AliMessage{
|
|
|
+ {
|
|
|
+ Role: "user",
|
|
|
+ Content: mediaContents,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+ imageRequest.Parameters = AliImageParameters{
|
|
|
+ Watermark: request.Watermark,
|
|
|
+ }
|
|
|
+ return &imageRequest, nil
|
|
|
+}
|
|
|
+
|
|
|
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
|
|
|
url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
|
|
|
|
|
|
@@ -196,8 +293,47 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|
|
if err != nil {
|
|
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
|
|
}
|
|
|
- c.Writer.Header().Set("Content-Type", "application/json")
|
|
|
- c.Writer.WriteHeader(resp.StatusCode)
|
|
|
- c.Writer.Write(jsonResponse)
|
|
|
+ service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
|
|
+ return nil, &dto.Usage{}
|
|
|
+}
|
|
|
+
|
|
|
+func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
|
|
+ var aliResponse AliResponse
|
|
|
+ responseBody, err := io.ReadAll(resp.Body)
|
|
|
+ if err != nil {
|
|
|
+ return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
|
|
+ }
|
|
|
+
|
|
|
+ service.CloseResponseBodyGracefully(resp)
|
|
|
+ err = common.Unmarshal(responseBody, &aliResponse)
|
|
|
+ if err != nil {
|
|
|
+ return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
|
|
+ }
|
|
|
+
|
|
|
+ if aliResponse.Message != "" {
|
|
|
+ logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
|
|
|
+ return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
|
|
|
+ }
|
|
|
+ var fullTextResponse dto.ImageResponse
|
|
|
+ if len(aliResponse.Output.Choices) > 0 {
|
|
|
+ fullTextResponse = dto.ImageResponse{
|
|
|
+ Created: info.StartTime.Unix(),
|
|
|
+ Data: []dto.ImageData{
|
|
|
+ {
|
|
|
+ Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
|
|
|
+ B64Json: "",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ var mapResponse map[string]any
|
|
|
+ _ = common.Unmarshal(responseBody, &mapResponse)
|
|
|
+ fullTextResponse.Extra = mapResponse
|
|
|
+ jsonResponse, err := common.Marshal(fullTextResponse)
|
|
|
+ if err != nil {
|
|
|
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
|
|
+ }
|
|
|
+ service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
|
|
return nil, &dto.Usage{}
|
|
|
}
|