image.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. package ali
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/dto"
  11. relaycommon "one-api/relay/common"
  12. "one-api/service"
  13. "strings"
  14. "time"
  15. )
  16. func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
  17. var imageRequest AliImageRequest
  18. imageRequest.Input.Prompt = request.Prompt
  19. imageRequest.Model = request.Model
  20. imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
  21. imageRequest.Parameters.N = request.N
  22. imageRequest.ResponseFormat = request.ResponseFormat
  23. return &imageRequest
  24. }
  25. func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
  26. url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID)
  27. var aliResponse AliResponse
  28. req, err := http.NewRequest("GET", url, nil)
  29. if err != nil {
  30. return &aliResponse, err, nil
  31. }
  32. req.Header.Set("Authorization", "Bearer "+info.ApiKey)
  33. client := &http.Client{}
  34. resp, err := client.Do(req)
  35. if err != nil {
  36. common.SysError("updateTask client.Do err: " + err.Error())
  37. return &aliResponse, err, nil
  38. }
  39. defer resp.Body.Close()
  40. responseBody, err := io.ReadAll(resp.Body)
  41. var response AliResponse
  42. err = json.Unmarshal(responseBody, &response)
  43. if err != nil {
  44. common.SysError("updateTask NewDecoder err: " + err.Error())
  45. return &aliResponse, err, nil
  46. }
  47. return &response, nil, responseBody
  48. }
  49. func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
  50. waitSeconds := 3
  51. step := 0
  52. maxStep := 20
  53. var taskResponse AliResponse
  54. var responseBody []byte
  55. for {
  56. step++
  57. rsp, err, body := updateTask(info, taskID)
  58. responseBody = body
  59. if err != nil {
  60. return &taskResponse, responseBody, err
  61. }
  62. if rsp.Output.TaskStatus == "" {
  63. return &taskResponse, responseBody, nil
  64. }
  65. switch rsp.Output.TaskStatus {
  66. case "FAILED":
  67. fallthrough
  68. case "CANCELED":
  69. fallthrough
  70. case "SUCCEEDED":
  71. fallthrough
  72. case "UNKNOWN":
  73. return rsp, responseBody, nil
  74. }
  75. if step >= maxStep {
  76. break
  77. }
  78. time.Sleep(time.Duration(waitSeconds) * time.Second)
  79. }
  80. return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
  81. }
  82. func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
  83. imageResponse := dto.ImageResponse{
  84. Created: info.StartTime.Unix(),
  85. }
  86. for _, data := range response.Output.Results {
  87. var b64Json string
  88. if responseFormat == "b64_json" {
  89. _, b64, err := service.GetImageFromUrl(data.Url)
  90. if err != nil {
  91. common.LogError(c, "get_image_data_failed: "+err.Error())
  92. continue
  93. }
  94. b64Json = b64
  95. } else {
  96. b64Json = data.B64Image
  97. }
  98. imageResponse.Data = append(imageResponse.Data, dto.ImageData{
  99. Url: data.Url,
  100. B64Json: b64Json,
  101. RevisedPrompt: "",
  102. })
  103. }
  104. return &imageResponse
  105. }
  106. func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  107. responseFormat := c.GetString("response_format")
  108. var aliTaskResponse AliResponse
  109. responseBody, err := io.ReadAll(resp.Body)
  110. if err != nil {
  111. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  112. }
  113. common.CloseResponseBodyGracefully(resp)
  114. err = json.Unmarshal(responseBody, &aliTaskResponse)
  115. if err != nil {
  116. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  117. }
  118. if aliTaskResponse.Message != "" {
  119. common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
  120. return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
  121. }
  122. aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
  123. if err != nil {
  124. return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
  125. }
  126. if aliResponse.Output.TaskStatus != "SUCCEEDED" {
  127. return &dto.OpenAIErrorWithStatusCode{
  128. Error: dto.OpenAIError{
  129. Message: aliResponse.Output.Message,
  130. Type: "ali_error",
  131. Param: "",
  132. Code: aliResponse.Output.Code,
  133. },
  134. StatusCode: resp.StatusCode,
  135. }, nil
  136. }
  137. fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
  138. jsonResponse, err := json.Marshal(fullTextResponse)
  139. if err != nil {
  140. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  141. }
  142. c.Writer.Header().Set("Content-Type", "application/json")
  143. c.Writer.WriteHeader(resp.StatusCode)
  144. _, err = c.Writer.Write(jsonResponse)
  145. return nil, nil
  146. }