소스 검색

feat: enhance image request handling and add async support

CaIon 4 달 전
부모
커밋
7fbf9c4851
4개의 변경된 파일77개의 추가작업 그리고 25개의 파일을 삭제
  1. 3 0
      dto/openai_image.go
  2. 17 3
      relay/channel/ali/adaptor.go
  3. 16 11
      relay/channel/ali/dto.go
  4. 41 11
      relay/channel/ali/image.go

+ 3 - 0
dto/openai_image.go

@@ -25,6 +25,8 @@ type ImageRequest struct {
 	PartialImages     json.RawMessage `json:"partial_images,omitempty"`
 	// Stream            bool            `json:"stream,omitempty"`
 	Watermark *bool `json:"watermark,omitempty"`
+	// 用匿名参数接收额外参数
+	Extra map[string]json.RawMessage `json:"-"`
 }
 
 func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
@@ -72,6 +74,7 @@ func (i *ImageRequest) SetModelName(modelName string) {
 type ImageResponse struct {
 	Data    []ImageData `json:"data"`
 	Created int64       `json:"created"`
+	Extra   any         `json:"extra,omitempty"`
 }
 type ImageData struct {
 	Url           string `json:"url"`

+ 17 - 3
relay/channel/ali/adaptor.go

@@ -63,6 +63,9 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
 	if c.GetString("plugin") != "" {
 		req.Set("X-DashScope-Plugin", c.GetString("plugin"))
 	}
+	if info.RelayMode == constant.RelayModeImagesGenerations {
+		req.Set("X-DashScope-Async", "enable")
+	}
 	return nil
 }
 
@@ -90,7 +93,10 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 }
 
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
-	aliRequest := oaiImage2Ali(request)
+	aliRequest, err := oaiImage2Ali(request)
+	if err != nil {
+		return nil, fmt.Errorf("convert image request failed: %w", err)
+	}
 	return aliRequest, nil
 }
 
@@ -125,8 +131,16 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 			return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
 		}
 	default:
-		adaptor := openai.Adaptor{}
-		return adaptor.DoResponse(c, resp, info)
+		switch info.RelayMode {
+		case constant.RelayModeImagesGenerations:
+			err, usage = aliImageHandler(c, resp, info)
+		case constant.RelayModeRerank:
+			err, usage = RerankHandler(c, resp, info)
+		default:
+			adaptor := openai.Adaptor{}
+			usage, err = adaptor.DoResponse(c, resp, info)
+		}
+		return usage, err
 	}
 }
 

+ 16 - 11
relay/channel/ali/dto.go

@@ -86,20 +86,25 @@ type AliResponse struct {
 }
 
 type AliImageRequest struct {
-	Model string `json:"model"`
-	Input struct {
-		Prompt         string `json:"prompt"`
-		NegativePrompt string `json:"negative_prompt,omitempty"`
-	} `json:"input"`
-	Parameters struct {
-		Size  string `json:"size,omitempty"`
-		N     int    `json:"n,omitempty"`
-		Steps string `json:"steps,omitempty"`
-		Scale string `json:"scale,omitempty"`
-	} `json:"parameters,omitempty"`
+	Model          string `json:"model"`
+	Input          any    `json:"input"`
+	Parameters     any    `json:"parameters,omitempty"`
 	ResponseFormat string `json:"response_format,omitempty"`
 }
 
+type AliImageParameters struct {
+	Size      string `json:"size,omitempty"`
+	N         int    `json:"n,omitempty"`
+	Steps     string `json:"steps,omitempty"`
+	Scale     string `json:"scale,omitempty"`
+	Watermark *bool  `json:"watermark,omitempty"`
+}
+
+type AliImageInput struct {
+	Prompt         string `json:"prompt"`
+	NegativePrompt string `json:"negative_prompt,omitempty"`
+}
+
 type AliRerankParameters struct {
 	TopN            *int  `json:"top_n,omitempty"`
 	ReturnDocuments *bool `json:"return_documents,omitempty"`

+ 41 - 11
relay/channel/ali/image.go

@@ -18,15 +18,41 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
+func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
 	var imageRequest AliImageRequest
-	imageRequest.Input.Prompt = request.Prompt
 	imageRequest.Model = request.Model
-	imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
-	imageRequest.Parameters.N = int(request.N)
 	imageRequest.ResponseFormat = request.ResponseFormat
 
-	return &imageRequest
+	if request.Extra != nil {
+		if val, ok := request.Extra["parameters"]; ok {
+			err := common.Unmarshal(val, &imageRequest.Parameters)
+			if err != nil {
+				return nil, fmt.Errorf("invalid parameters field: %w", err)
+			}
+		}
+		if val, ok := request.Extra["input"]; ok {
+			err := common.Unmarshal(val, &imageRequest.Input)
+			if err != nil {
+				return nil, fmt.Errorf("invalid input field: %w", err)
+			}
+		}
+	}
+
+	if imageRequest.Parameters == nil {
+		imageRequest.Parameters = AliImageParameters{
+			Size:      strings.Replace(request.Size, "x", "*", -1),
+			N:         int(request.N),
+			Watermark: request.Watermark,
+		}
+	}
+
+	if imageRequest.Input == nil {
+		imageRequest.Input = AliImageInput{
+			Prompt: request.Prompt,
+		}
+	}
+
+	return &imageRequest, nil
 }
 
 func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
@@ -52,7 +78,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
 	responseBody, err := io.ReadAll(resp.Body)
 
 	var response AliResponse
-	err = json.Unmarshal(responseBody, &response)
+	err = common.Unmarshal(responseBody, &response)
 	if err != nil {
 		common.SysLog("updateTask NewDecoder err: " + err.Error())
 		return &aliResponse, err, nil
@@ -61,8 +87,8 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
 	return &response, nil, responseBody
 }
 
-func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
-	waitSeconds := 3
+func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
+	waitSeconds := 5
 	step := 0
 	maxStep := 20
 
@@ -70,11 +96,14 @@ func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []
 	var responseBody []byte
 
 	for {
+		logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
 		step++
 		rsp, err, body := updateTask(info, taskID)
 		responseBody = body
 		if err != nil {
-			return &taskResponse, responseBody, err
+			logger.LogWarn(c, "asyncTaskWait UpdateTask err: "+err.Error())
+			time.Sleep(time.Duration(waitSeconds) * time.Second)
+			continue
 		}
 
 		if rsp.Output.TaskStatus == "" {
@@ -124,6 +153,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
 			RevisedPrompt: "",
 		})
 	}
+	imageResponse.Extra = response
 	return &imageResponse
 }
 
@@ -146,7 +176,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 		return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
 	}
 
-	aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
+	aliResponse, _, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeBadResponse), nil
 	}
@@ -161,7 +191,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 	}
 
 	fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
-	jsonResponse, err := json.Marshal(fullTextResponse)
+	jsonResponse, err := common.Marshal(fullTextResponse)
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
 	}