image.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. package ali
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "mime/multipart"
  9. "net/http"
  10. "one-api/common"
  11. "one-api/dto"
  12. "one-api/logger"
  13. relaycommon "one-api/relay/common"
  14. "one-api/service"
  15. "one-api/types"
  16. "strings"
  17. "time"
  18. "github.com/gin-gonic/gin"
  19. )
  20. func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
  21. var imageRequest AliImageRequest
  22. imageRequest.Model = request.Model
  23. imageRequest.ResponseFormat = request.ResponseFormat
  24. logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
  25. if request.Extra != nil {
  26. if val, ok := request.Extra["parameters"]; ok {
  27. err := common.Unmarshal(val, &imageRequest.Parameters)
  28. if err != nil {
  29. return nil, fmt.Errorf("invalid parameters field: %w", err)
  30. }
  31. }
  32. if val, ok := request.Extra["input"]; ok {
  33. err := common.Unmarshal(val, &imageRequest.Input)
  34. if err != nil {
  35. return nil, fmt.Errorf("invalid input field: %w", err)
  36. }
  37. }
  38. }
  39. if imageRequest.Parameters == nil {
  40. imageRequest.Parameters = AliImageParameters{
  41. Size: strings.Replace(request.Size, "x", "*", -1),
  42. N: int(request.N),
  43. Watermark: request.Watermark,
  44. }
  45. }
  46. if imageRequest.Input == nil {
  47. imageRequest.Input = AliImageInput{
  48. Prompt: request.Prompt,
  49. }
  50. }
  51. return &imageRequest, nil
  52. }
  53. func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
  54. var imageRequest AliImageRequest
  55. imageRequest.Model = request.Model
  56. imageRequest.ResponseFormat = request.ResponseFormat
  57. mf := c.Request.MultipartForm
  58. if mf == nil {
  59. if _, err := c.MultipartForm(); err != nil {
  60. return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
  61. }
  62. mf = c.Request.MultipartForm
  63. }
  64. var imageFiles []*multipart.FileHeader
  65. var exists bool
  66. // First check for standard "image" field
  67. if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 {
  68. // If not found, check for "image[]" field
  69. if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 {
  70. // If still not found, iterate through all fields to find any that start with "image["
  71. foundArrayImages := false
  72. for fieldName, files := range mf.File {
  73. if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
  74. foundArrayImages = true
  75. imageFiles = append(imageFiles, files...)
  76. }
  77. }
  78. // If no image fields found at all
  79. if !foundArrayImages && (len(imageFiles) == 0) {
  80. return nil, errors.New("image is required")
  81. }
  82. }
  83. }
  84. if len(imageFiles) == 0 {
  85. return nil, errors.New("image is required")
  86. }
  87. if len(imageFiles) > 1 {
  88. return nil, errors.New("only one image is supported for qwen edit")
  89. }
  90. // 获取base64编码的图片
  91. var imageBase64s []string
  92. for _, file := range imageFiles {
  93. image, err := file.Open()
  94. if err != nil {
  95. return nil, errors.New("failed to open image file")
  96. }
  97. // 读取文件内容
  98. imageData, err := io.ReadAll(image)
  99. if err != nil {
  100. return nil, errors.New("failed to read image file")
  101. }
  102. // 获取MIME类型
  103. mimeType := http.DetectContentType(imageData)
  104. // 编码为base64
  105. base64Data := base64.StdEncoding.EncodeToString(imageData)
  106. // 构造data URL格式
  107. dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data)
  108. imageBase64s = append(imageBase64s, dataURL)
  109. image.Close()
  110. }
  111. //dto.MediaContent{}
  112. mediaContents := make([]AliMediaContent, len(imageBase64s))
  113. for i, b64 := range imageBase64s {
  114. mediaContents[i] = AliMediaContent{
  115. Image: b64,
  116. }
  117. }
  118. mediaContents = append(mediaContents, AliMediaContent{
  119. Text: request.Prompt,
  120. })
  121. imageRequest.Input = AliImageInput{
  122. Messages: []AliMessage{
  123. {
  124. Role: "user",
  125. Content: mediaContents,
  126. },
  127. },
  128. }
  129. imageRequest.Parameters = AliImageParameters{
  130. Watermark: request.Watermark,
  131. }
  132. return &imageRequest, nil
  133. }
  134. func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
  135. url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
  136. var aliResponse AliResponse
  137. req, err := http.NewRequest("GET", url, nil)
  138. if err != nil {
  139. return &aliResponse, err, nil
  140. }
  141. req.Header.Set("Authorization", "Bearer "+info.ApiKey)
  142. client := &http.Client{}
  143. resp, err := client.Do(req)
  144. if err != nil {
  145. common.SysLog("updateTask client.Do err: " + err.Error())
  146. return &aliResponse, err, nil
  147. }
  148. defer resp.Body.Close()
  149. responseBody, err := io.ReadAll(resp.Body)
  150. var response AliResponse
  151. err = common.Unmarshal(responseBody, &response)
  152. if err != nil {
  153. common.SysLog("updateTask NewDecoder err: " + err.Error())
  154. return &aliResponse, err, nil
  155. }
  156. return &response, nil, responseBody
  157. }
  158. func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
  159. waitSeconds := 10
  160. step := 0
  161. maxStep := 20
  162. var taskResponse AliResponse
  163. var responseBody []byte
  164. for {
  165. logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
  166. step++
  167. rsp, err, body := updateTask(info, taskID)
  168. responseBody = body
  169. if err != nil {
  170. logger.LogWarn(c, "asyncTaskWait UpdateTask err: "+err.Error())
  171. time.Sleep(time.Duration(waitSeconds) * time.Second)
  172. continue
  173. }
  174. if rsp.Output.TaskStatus == "" {
  175. return &taskResponse, responseBody, nil
  176. }
  177. switch rsp.Output.TaskStatus {
  178. case "FAILED":
  179. fallthrough
  180. case "CANCELED":
  181. fallthrough
  182. case "SUCCEEDED":
  183. fallthrough
  184. case "UNKNOWN":
  185. return rsp, responseBody, nil
  186. }
  187. if step >= maxStep {
  188. break
  189. }
  190. time.Sleep(time.Duration(waitSeconds) * time.Second)
  191. }
  192. return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
  193. }
  194. func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody []byte, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
  195. imageResponse := dto.ImageResponse{
  196. Created: info.StartTime.Unix(),
  197. }
  198. for _, data := range response.Output.Results {
  199. var b64Json string
  200. if responseFormat == "b64_json" {
  201. _, b64, err := service.GetImageFromUrl(data.Url)
  202. if err != nil {
  203. logger.LogError(c, "get_image_data_failed: "+err.Error())
  204. continue
  205. }
  206. b64Json = b64
  207. } else {
  208. b64Json = data.B64Image
  209. }
  210. imageResponse.Data = append(imageResponse.Data, dto.ImageData{
  211. Url: data.Url,
  212. B64Json: b64Json,
  213. RevisedPrompt: "",
  214. })
  215. }
  216. var mapResponse map[string]any
  217. _ = common.Unmarshal(originBody, &mapResponse)
  218. imageResponse.Extra = mapResponse
  219. return &imageResponse
  220. }
  221. func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
  222. responseFormat := c.GetString("response_format")
  223. var aliTaskResponse AliResponse
  224. responseBody, err := io.ReadAll(resp.Body)
  225. if err != nil {
  226. return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
  227. }
  228. service.CloseResponseBodyGracefully(resp)
  229. err = common.Unmarshal(responseBody, &aliTaskResponse)
  230. if err != nil {
  231. return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
  232. }
  233. if aliTaskResponse.Message != "" {
  234. logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
  235. return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
  236. }
  237. aliResponse, originRespBody, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
  238. if err != nil {
  239. return types.NewError(err, types.ErrorCodeBadResponse), nil
  240. }
  241. if aliResponse.Output.TaskStatus != "SUCCEEDED" {
  242. return types.WithOpenAIError(types.OpenAIError{
  243. Message: aliResponse.Output.Message,
  244. Type: "ali_error",
  245. Param: "",
  246. Code: aliResponse.Output.Code,
  247. }, resp.StatusCode), nil
  248. }
  249. fullTextResponse := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
  250. jsonResponse, err := common.Marshal(fullTextResponse)
  251. if err != nil {
  252. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  253. }
  254. service.IOCopyBytesGracefully(c, resp, jsonResponse)
  255. return nil, &dto.Usage{}
  256. }
  257. func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
  258. var aliResponse AliResponse
  259. responseBody, err := io.ReadAll(resp.Body)
  260. if err != nil {
  261. return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
  262. }
  263. service.CloseResponseBodyGracefully(resp)
  264. err = common.Unmarshal(responseBody, &aliResponse)
  265. if err != nil {
  266. return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
  267. }
  268. if aliResponse.Message != "" {
  269. logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
  270. return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
  271. }
  272. var fullTextResponse dto.ImageResponse
  273. if len(aliResponse.Output.Choices) > 0 {
  274. fullTextResponse = dto.ImageResponse{
  275. Created: info.StartTime.Unix(),
  276. Data: []dto.ImageData{
  277. {
  278. Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
  279. B64Json: "",
  280. },
  281. },
  282. }
  283. }
  284. var mapResponse map[string]any
  285. _ = common.Unmarshal(responseBody, &mapResponse)
  286. fullTextResponse.Extra = mapResponse
  287. jsonResponse, err := common.Marshal(fullTextResponse)
  288. if err != nil {
  289. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  290. }
  291. service.IOCopyBytesGracefully(c, resp, jsonResponse)
  292. return nil, &dto.Usage{}
  293. }