|
|
@@ -21,17 +21,25 @@ import (
|
|
|
"github.com/gin-gonic/gin"
|
|
|
)
|
|
|
|
|
|
-func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
|
|
+func oaiImage2AliImageRequest(info *relaycommon.RelayInfo, request dto.ImageRequest, isSync bool) (*AliImageRequest, error) {
|
|
|
var imageRequest AliImageRequest
|
|
|
imageRequest.Model = request.Model
|
|
|
imageRequest.ResponseFormat = request.ResponseFormat
|
|
|
logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
|
|
|
+ logger.LogDebug(context.Background(), "oaiImage2Ali request isSync: "+fmt.Sprintf("%v", isSync))
|
|
|
if request.Extra != nil {
|
|
|
if val, ok := request.Extra["parameters"]; ok {
|
|
|
err := common.Unmarshal(val, &imageRequest.Parameters)
|
|
|
if err != nil {
|
|
|
return nil, fmt.Errorf("invalid parameters field: %w", err)
|
|
|
}
|
|
|
+ } else {
|
|
|
+ // 兼容没有parameters字段的情况,从openai标准字段中提取参数
|
|
|
+ imageRequest.Parameters = AliImageParameters{
|
|
|
+ Size: strings.Replace(request.Size, "x", "*", -1),
|
|
|
+ N: int(request.N),
|
|
|
+ Watermark: request.Watermark,
|
|
|
+ }
|
|
|
}
|
|
|
if val, ok := request.Extra["input"]; ok {
|
|
|
err := common.Unmarshal(val, &imageRequest.Input)
|
|
|
@@ -41,23 +49,44 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if imageRequest.Parameters == nil {
|
|
|
- imageRequest.Parameters = AliImageParameters{
|
|
|
- Size: strings.Replace(request.Size, "x", "*", -1),
|
|
|
- N: int(request.N),
|
|
|
- Watermark: request.Watermark,
|
|
|
+ if strings.Contains(request.Model, "z-image") {
|
|
|
+ // z-image 开启prompt_extend后,按2倍计费
|
|
|
+ if imageRequest.Parameters.PromptExtendValue() {
|
|
|
+ info.PriceData.AddOtherRatio("prompt_extend", 2)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if imageRequest.Input == nil {
|
|
|
- imageRequest.Input = AliImageInput{
|
|
|
- Prompt: request.Prompt,
|
|
|
+ // 检查n参数
|
|
|
+ if imageRequest.Parameters.N != 0 {
|
|
|
+ info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
|
|
|
+ }
|
|
|
+
|
|
|
+ // 同步图片模型和异步图片模型请求格式不一样
|
|
|
+ if isSync {
|
|
|
+ if imageRequest.Input == nil {
|
|
|
+ imageRequest.Input = AliImageInput{
|
|
|
+ Messages: []AliMessage{
|
|
|
+ {
|
|
|
+ Role: "user",
|
|
|
+ Content: []AliMediaContent{
|
|
|
+ {
|
|
|
+ Text: request.Prompt,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ if imageRequest.Input == nil {
|
|
|
+ imageRequest.Input = AliImageInput{
|
|
|
+ Prompt: request.Prompt,
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return &imageRequest, nil
|
|
|
}
|
|
|
-
|
|
|
func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
|
|
|
mf := c.Request.MultipartForm
|
|
|
if mf == nil {
|
|
|
@@ -199,6 +228,8 @@ func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (
|
|
|
var taskResponse AliResponse
|
|
|
var responseBody []byte
|
|
|
|
|
|
+ time.Sleep(time.Duration(5) * time.Second)
|
|
|
+
|
|
|
for {
|
|
|
logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
|
|
|
step++
|
|
|
@@ -238,32 +269,17 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody [
|
|
|
Created: info.StartTime.Unix(),
|
|
|
}
|
|
|
|
|
|
- for _, data := range response.Output.Results {
|
|
|
- var b64Json string
|
|
|
- if responseFormat == "b64_json" {
|
|
|
- _, b64, err := service.GetImageFromUrl(data.Url)
|
|
|
- if err != nil {
|
|
|
- logger.LogError(c, "get_image_data_failed: "+err.Error())
|
|
|
- continue
|
|
|
- }
|
|
|
- b64Json = b64
|
|
|
- } else {
|
|
|
- b64Json = data.B64Image
|
|
|
- }
|
|
|
-
|
|
|
- imageResponse.Data = append(imageResponse.Data, dto.ImageData{
|
|
|
- Url: data.Url,
|
|
|
- B64Json: b64Json,
|
|
|
- RevisedPrompt: "",
|
|
|
- })
|
|
|
+ if len(response.Output.Results) > 0 {
|
|
|
+ imageResponse.Data = response.Output.ResultToOpenAIImageDate(c, responseFormat)
|
|
|
+ } else if len(response.Output.Choices) > 0 {
|
|
|
+ imageResponse.Data = response.Output.ChoicesToOpenAIImageDate(c, responseFormat)
|
|
|
}
|
|
|
- var mapResponse map[string]any
|
|
|
- _ = common.Unmarshal(originBody, &mapResponse)
|
|
|
- imageResponse.Extra = mapResponse
|
|
|
+
|
|
|
+ imageResponse.Metadata = originBody
|
|
|
return &imageResponse
|
|
|
}
|
|
|
|
|
|
-func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
|
|
+func aliImageHandler(a *Adaptor, c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
|
|
responseFormat := c.GetString("response_format")
|
|
|
|
|
|
var aliTaskResponse AliResponse
|
|
|
@@ -282,66 +298,49 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|
|
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
|
|
|
}
|
|
|
|
|
|
- aliResponse, originRespBody, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
|
|
|
- if err != nil {
|
|
|
- return types.NewError(err, types.ErrorCodeBadResponse), nil
|
|
|
- }
|
|
|
-
|
|
|
- if aliResponse.Output.TaskStatus != "SUCCEEDED" {
|
|
|
- return types.WithOpenAIError(types.OpenAIError{
|
|
|
- Message: aliResponse.Output.Message,
|
|
|
- Type: "ali_error",
|
|
|
- Param: "",
|
|
|
- Code: aliResponse.Output.Code,
|
|
|
- }, resp.StatusCode), nil
|
|
|
- }
|
|
|
-
|
|
|
- fullTextResponse := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
|
|
|
- jsonResponse, err := common.Marshal(fullTextResponse)
|
|
|
- if err != nil {
|
|
|
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
|
|
- }
|
|
|
- service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
|
|
- return nil, &dto.Usage{}
|
|
|
-}
|
|
|
-
|
|
|
-func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
|
|
- var aliResponse AliResponse
|
|
|
- responseBody, err := io.ReadAll(resp.Body)
|
|
|
- if err != nil {
|
|
|
- return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
|
|
+ var (
|
|
|
+ aliResponse *AliResponse
|
|
|
+ originRespBody []byte
|
|
|
+ )
|
|
|
+
|
|
|
+ if a.IsSyncImageModel {
|
|
|
+ aliResponse = &aliTaskResponse
|
|
|
+ originRespBody = responseBody
|
|
|
+ } else {
|
|
|
+ // 异步图片模型需要轮询任务结果
|
|
|
+ aliResponse, originRespBody, err = asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
|
|
|
+ if err != nil {
|
|
|
+ return types.NewError(err, types.ErrorCodeBadResponse), nil
|
|
|
+ }
|
|
|
+ if aliResponse.Output.TaskStatus != "SUCCEEDED" {
|
|
|
+ return types.WithOpenAIError(types.OpenAIError{
|
|
|
+ Message: aliResponse.Output.Message,
|
|
|
+ Type: "ali_error",
|
|
|
+ Param: "",
|
|
|
+ Code: aliResponse.Output.Code,
|
|
|
+ }, resp.StatusCode), nil
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- service.CloseResponseBodyGracefully(resp)
|
|
|
- err = common.Unmarshal(responseBody, &aliResponse)
|
|
|
- if err != nil {
|
|
|
- return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
|
|
+ //logger.LogDebug(c, "ali_async_task_result: "+string(originRespBody))
|
|
|
+ if a.IsSyncImageModel {
|
|
|
+ logger.LogDebug(c, "ali_sync_image_result: "+string(originRespBody))
|
|
|
+ } else {
|
|
|
+ logger.LogDebug(c, "ali_async_image_result: "+string(originRespBody))
|
|
|
}
|
|
|
|
|
|
- if aliResponse.Message != "" {
|
|
|
- logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
|
|
|
- return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
|
|
|
+ imageResponses := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
|
|
|
+ // 可能生成多张图片,修正计费数量n
|
|
|
+ if aliResponse.Usage.ImageCount != 0 {
|
|
|
+ info.PriceData.AddOtherRatio("n", float64(aliResponse.Usage.ImageCount))
|
|
|
+ } else if len(imageResponses.Data) != 0 {
|
|
|
+ info.PriceData.AddOtherRatio("n", float64(len(imageResponses.Data)))
|
|
|
}
|
|
|
- var fullTextResponse dto.ImageResponse
|
|
|
- if len(aliResponse.Output.Choices) > 0 {
|
|
|
- fullTextResponse = dto.ImageResponse{
|
|
|
- Created: info.StartTime.Unix(),
|
|
|
- Data: []dto.ImageData{
|
|
|
- {
|
|
|
- Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
|
|
|
- B64Json: "",
|
|
|
- },
|
|
|
- },
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- var mapResponse map[string]any
|
|
|
- _ = common.Unmarshal(responseBody, &mapResponse)
|
|
|
- fullTextResponse.Extra = mapResponse
|
|
|
- jsonResponse, err := common.Marshal(fullTextResponse)
|
|
|
+ jsonResponse, err := common.Marshal(imageResponses)
|
|
|
if err != nil {
|
|
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
|
|
}
|
|
|
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
|
|
+
|
|
|
return nil, &dto.Usage{}
|
|
|
}
|