midjourney.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. package controller
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "github.com/gin-gonic/gin"
  8. "io"
  9. "log"
  10. "net/http"
  11. "one-api/common"
  12. "one-api/model"
  13. "strconv"
  14. "strings"
  15. "time"
  16. )
  17. func UpdateMidjourneyTask() {
  18. //revocer
  19. imageModel := "midjourney"
  20. for {
  21. defer func() {
  22. if err := recover(); err != nil {
  23. log.Printf("UpdateMidjourneyTask panic: %v", err)
  24. }
  25. }()
  26. time.Sleep(time.Duration(15) * time.Second)
  27. tasks := model.GetAllUnFinishTasks()
  28. if len(tasks) != 0 {
  29. log.Printf("检测到未完成的任务数有: %v", len(tasks))
  30. for _, task := range tasks {
  31. log.Printf("未完成的任务信息: %v", task)
  32. midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
  33. if err != nil {
  34. log.Printf("UpdateMidjourneyTask: %v", err)
  35. task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId)
  36. task.Status = "FAILURE"
  37. task.Progress = "100%"
  38. err := task.Update()
  39. if err != nil {
  40. log.Printf("UpdateMidjourneyTask error: %v", err)
  41. }
  42. continue
  43. }
  44. requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
  45. log.Printf("requestUrl: %s", requestUrl)
  46. req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
  47. if err != nil {
  48. log.Printf("UpdateMidjourneyTask error: %v", err)
  49. continue
  50. }
  51. // 设置超时时间
  52. timeout := time.Second * 5
  53. ctx, cancel := context.WithTimeout(context.Background(), timeout)
  54. defer cancel()
  55. // 使用带有超时的 context 创建新的请求
  56. req = req.WithContext(ctx)
  57. req.Header.Set("Content-Type", "application/json")
  58. req.Header.Set("Authorization", "Bearer midjourney-proxy")
  59. req.Header.Set("mj-api-secret", midjourneyChannel.Key)
  60. resp, err := httpClient.Do(req)
  61. if err != nil {
  62. log.Printf("UpdateMidjourneyTask error: %v", err)
  63. continue
  64. }
  65. defer resp.Body.Close()
  66. responseBody, err := io.ReadAll(resp.Body)
  67. log.Printf("responseBody: %s", string(responseBody))
  68. var responseItem Midjourney
  69. // err = json.NewDecoder(resp.Body).Decode(&responseItem)
  70. err = json.Unmarshal(responseBody, &responseItem)
  71. if err != nil {
  72. if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") {
  73. var responseWithoutStatus MidjourneyWithoutStatus
  74. var responseStatus MidjourneyStatus
  75. err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
  76. err2 := json.Unmarshal(responseBody, &responseStatus)
  77. if err1 == nil && err2 == nil {
  78. jsonData, err3 := json.Marshal(responseWithoutStatus)
  79. if err3 != nil {
  80. log.Fatalf("UpdateMidjourneyTask error1: %v", err3)
  81. continue
  82. }
  83. err4 := json.Unmarshal(jsonData, &responseStatus)
  84. if err4 != nil {
  85. log.Fatalf("UpdateMidjourneyTask error2: %v", err4)
  86. continue
  87. }
  88. responseItem.Status = strconv.Itoa(responseStatus.Status)
  89. } else {
  90. log.Printf("UpdateMidjourneyTask error3: %v", err)
  91. continue
  92. }
  93. } else {
  94. log.Printf("UpdateMidjourneyTask error4: %v", err)
  95. continue
  96. }
  97. }
  98. task.Code = 1
  99. task.Progress = responseItem.Progress
  100. task.PromptEn = responseItem.PromptEn
  101. task.State = responseItem.State
  102. task.SubmitTime = responseItem.SubmitTime
  103. task.StartTime = responseItem.StartTime
  104. task.FinishTime = responseItem.FinishTime
  105. task.ImageUrl = responseItem.ImageUrl
  106. task.Status = responseItem.Status
  107. task.FailReason = responseItem.FailReason
  108. if task.Progress != "100%" && responseItem.FailReason != "" {
  109. log.Println(task.MjId + " 构建失败," + task.FailReason)
  110. task.Progress = "100%"
  111. err = model.CacheUpdateUserQuota(task.UserId)
  112. if err != nil {
  113. log.Println("error update user quota cache: " + err.Error())
  114. } else {
  115. modelRatio := common.GetModelRatio(imageModel)
  116. groupRatio := common.GetGroupRatio("default")
  117. ratio := modelRatio * groupRatio
  118. quota := int(ratio * 1 * 1000)
  119. if quota != 0 {
  120. err := model.IncreaseUserQuota(task.UserId, quota)
  121. if err != nil {
  122. log.Println("fail to increase user quota")
  123. }
  124. logContent := fmt.Sprintf("%s 构图失败,补偿 %s", task.MjId, common.LogQuota(quota))
  125. model.RecordLog(task.UserId, 1, logContent)
  126. }
  127. }
  128. }
  129. err = task.Update()
  130. if err != nil {
  131. log.Printf("UpdateMidjourneyTask error5: %v", err)
  132. }
  133. log.Printf("UpdateMidjourneyTask success: %v", task)
  134. }
  135. }
  136. }
  137. }
  138. func GetAllMidjourney(c *gin.Context) {
  139. p, _ := strconv.Atoi(c.Query("p"))
  140. if p < 0 {
  141. p = 0
  142. }
  143. logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage)
  144. if logs == nil {
  145. logs = make([]*model.Midjourney, 0)
  146. }
  147. c.JSON(200, gin.H{
  148. "success": true,
  149. "message": "",
  150. "data": logs,
  151. })
  152. }
  153. func GetUserMidjourney(c *gin.Context) {
  154. p, _ := strconv.Atoi(c.Query("p"))
  155. if p < 0 {
  156. p = 0
  157. }
  158. userId := c.GetInt("id")
  159. log.Printf("userId = %d \n", userId)
  160. logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage)
  161. if logs == nil {
  162. logs = make([]*model.Midjourney, 0)
  163. }
  164. c.JSON(200, gin.H{
  165. "success": true,
  166. "message": "",
  167. "data": logs,
  168. })
  169. }