|
|
@@ -18,15 +18,41 @@ import (
|
|
|
"github.com/gin-gonic/gin"
|
|
|
)
|
|
|
|
|
|
-func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
|
|
|
+func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
|
|
var imageRequest AliImageRequest
|
|
|
- imageRequest.Input.Prompt = request.Prompt
|
|
|
imageRequest.Model = request.Model
|
|
|
- imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
|
|
|
- imageRequest.Parameters.N = int(request.N)
|
|
|
imageRequest.ResponseFormat = request.ResponseFormat
|
|
|
|
|
|
- return &imageRequest
|
|
|
+ if request.Extra != nil {
|
|
|
+ if val, ok := request.Extra["parameters"]; ok {
|
|
|
+ err := common.Unmarshal(val, &imageRequest.Parameters)
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("invalid parameters field: %w", err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if val, ok := request.Extra["input"]; ok {
|
|
|
+ err := common.Unmarshal(val, &imageRequest.Input)
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("invalid input field: %w", err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if imageRequest.Parameters == nil {
|
|
|
+ imageRequest.Parameters = AliImageParameters{
|
|
|
+ Size: strings.Replace(request.Size, "x", "*", -1),
|
|
|
+ N: int(request.N),
|
|
|
+ Watermark: request.Watermark,
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if imageRequest.Input == nil {
|
|
|
+ imageRequest.Input = AliImageInput{
|
|
|
+ Prompt: request.Prompt,
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return &imageRequest, nil
|
|
|
}
|
|
|
|
|
|
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
|
|
|
@@ -52,7 +78,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
|
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
|
|
|
|
var response AliResponse
|
|
|
- err = json.Unmarshal(responseBody, &response)
|
|
|
+ err = common.Unmarshal(responseBody, &response)
|
|
|
if err != nil {
|
|
|
common.SysLog("updateTask NewDecoder err: " + err.Error())
|
|
|
return &aliResponse, err, nil
|
|
|
@@ -61,8 +87,8 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
|
|
|
return &response, nil, responseBody
|
|
|
}
|
|
|
|
|
|
-func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
|
|
|
- waitSeconds := 3
|
|
|
+func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
|
|
|
+ waitSeconds := 5
|
|
|
step := 0
|
|
|
maxStep := 20
|
|
|
|
|
|
@@ -70,11 +96,14 @@ func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []
|
|
|
var responseBody []byte
|
|
|
|
|
|
for {
|
|
|
+ logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
|
|
|
step++
|
|
|
rsp, err, body := updateTask(info, taskID)
|
|
|
responseBody = body
|
|
|
if err != nil {
|
|
|
- return &taskResponse, responseBody, err
|
|
|
+ logger.LogWarn(c, "asyncTaskWait UpdateTask err: "+err.Error())
|
|
|
+ time.Sleep(time.Duration(waitSeconds) * time.Second)
|
|
|
+ continue
|
|
|
}
|
|
|
|
|
|
if rsp.Output.TaskStatus == "" {
|
|
|
@@ -124,6 +153,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
|
|
|
RevisedPrompt: "",
|
|
|
})
|
|
|
}
|
|
|
+ imageResponse.Extra = response
|
|
|
return &imageResponse
|
|
|
}
|
|
|
|
|
|
@@ -146,7 +176,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|
|
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
|
|
|
}
|
|
|
|
|
|
- aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
|
|
|
+ aliResponse, _, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
|
|
|
if err != nil {
|
|
|
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
|
|
}
|
|
|
@@ -161,7 +191,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|
|
}
|
|
|
|
|
|
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
|
|
|
- jsonResponse, err := json.Marshal(fullTextResponse)
|
|
|
+ jsonResponse, err := common.Marshal(fullTextResponse)
|
|
|
if err != nil {
|
|
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
|
|
}
|