image.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. package ali
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/dto"
  10. relaycommon "one-api/relay/common"
  11. "one-api/service"
  12. "one-api/types"
  13. "strings"
  14. "time"
  15. "github.com/gin-gonic/gin"
  16. )
  17. func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
  18. var imageRequest AliImageRequest
  19. imageRequest.Input.Prompt = request.Prompt
  20. imageRequest.Model = request.Model
  21. imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
  22. imageRequest.Parameters.N = request.N
  23. imageRequest.ResponseFormat = request.ResponseFormat
  24. return &imageRequest
  25. }
  26. func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
  27. url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID)
  28. var aliResponse AliResponse
  29. req, err := http.NewRequest("GET", url, nil)
  30. if err != nil {
  31. return &aliResponse, err, nil
  32. }
  33. req.Header.Set("Authorization", "Bearer "+info.ApiKey)
  34. client := &http.Client{}
  35. resp, err := client.Do(req)
  36. if err != nil {
  37. common.SysError("updateTask client.Do err: " + err.Error())
  38. return &aliResponse, err, nil
  39. }
  40. defer resp.Body.Close()
  41. responseBody, err := io.ReadAll(resp.Body)
  42. var response AliResponse
  43. err = json.Unmarshal(responseBody, &response)
  44. if err != nil {
  45. common.SysError("updateTask NewDecoder err: " + err.Error())
  46. return &aliResponse, err, nil
  47. }
  48. return &response, nil, responseBody
  49. }
  50. func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
  51. waitSeconds := 3
  52. step := 0
  53. maxStep := 20
  54. var taskResponse AliResponse
  55. var responseBody []byte
  56. for {
  57. step++
  58. rsp, err, body := updateTask(info, taskID)
  59. responseBody = body
  60. if err != nil {
  61. return &taskResponse, responseBody, err
  62. }
  63. if rsp.Output.TaskStatus == "" {
  64. return &taskResponse, responseBody, nil
  65. }
  66. switch rsp.Output.TaskStatus {
  67. case "FAILED":
  68. fallthrough
  69. case "CANCELED":
  70. fallthrough
  71. case "SUCCEEDED":
  72. fallthrough
  73. case "UNKNOWN":
  74. return rsp, responseBody, nil
  75. }
  76. if step >= maxStep {
  77. break
  78. }
  79. time.Sleep(time.Duration(waitSeconds) * time.Second)
  80. }
  81. return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
  82. }
  83. func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
  84. imageResponse := dto.ImageResponse{
  85. Created: info.StartTime.Unix(),
  86. }
  87. for _, data := range response.Output.Results {
  88. var b64Json string
  89. if responseFormat == "b64_json" {
  90. _, b64, err := service.GetImageFromUrl(data.Url)
  91. if err != nil {
  92. common.LogError(c, "get_image_data_failed: "+err.Error())
  93. continue
  94. }
  95. b64Json = b64
  96. } else {
  97. b64Json = data.B64Image
  98. }
  99. imageResponse.Data = append(imageResponse.Data, dto.ImageData{
  100. Url: data.Url,
  101. B64Json: b64Json,
  102. RevisedPrompt: "",
  103. })
  104. }
  105. return &imageResponse
  106. }
  107. func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
  108. responseFormat := c.GetString("response_format")
  109. var aliTaskResponse AliResponse
  110. responseBody, err := io.ReadAll(resp.Body)
  111. if err != nil {
  112. return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
  113. }
  114. common.CloseResponseBodyGracefully(resp)
  115. err = json.Unmarshal(responseBody, &aliTaskResponse)
  116. if err != nil {
  117. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  118. }
  119. if aliTaskResponse.Message != "" {
  120. common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
  121. return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
  122. }
  123. aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
  124. if err != nil {
  125. return types.NewError(err, types.ErrorCodeBadResponse), nil
  126. }
  127. if aliResponse.Output.TaskStatus != "SUCCEEDED" {
  128. return types.WithOpenAIError(types.OpenAIError{
  129. Message: aliResponse.Output.Message,
  130. Type: "ali_error",
  131. Param: "",
  132. Code: aliResponse.Output.Code,
  133. }, resp.StatusCode), nil
  134. }
  135. fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
  136. jsonResponse, err := json.Marshal(fullTextResponse)
  137. if err != nil {
  138. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  139. }
  140. c.Writer.Header().Set("Content-Type", "application/json")
  141. c.Writer.WriteHeader(resp.StatusCode)
  142. c.Writer.Write(jsonResponse)
  143. return nil, &dto.Usage{}
  144. }