task_video.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. package controller
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "one-api/common"
  8. "one-api/constant"
  9. "one-api/dto"
  10. "one-api/logger"
  11. "one-api/model"
  12. "one-api/relay"
  13. "one-api/relay/channel"
  14. relaycommon "one-api/relay/common"
  15. "time"
  16. )
  // UpdateVideoTaskAll refreshes the pending video tasks of every channel in
  // taskChannelM, which maps a channel ID to the IDs of that channel's pending
  // tasks; the task records themselves are looked up in taskM.
  //
  // Channels are handled one after another. An error from one channel is
  // logged and the remaining channels are still processed, so the function
  // always returns nil.
  func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
  	for id, ids := range taskChannelM {
  		err := updateVideoTaskAll(ctx, platform, id, ids, taskM)
  		if err == nil {
  			continue
  		}
  		logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", id, err.Error()))
  	}
  	return nil
  }
  // updateVideoTaskAll refreshes every pending video task that belongs to a
  // single channel.
  //
  // With no task IDs it only logs the (zero) count and returns nil. If the
  // channel cannot be loaded, all of taskIds are bulk-marked as FAILURE with
  // progress 100% and the lookup error is returned (wrapped). If no task
  // adaptor is registered for platform, an error is returned and no task is
  // touched. Otherwise each task is polled with updateVideoSingleTask; a
  // per-task error is logged and the loop moves on to the next task.
  func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
  	pending := len(taskIds)
  	logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, pending))
  	if pending == 0 {
  		return nil
  	}

  	ch, lookupErr := model.CacheGetChannel(channelId)
  	if lookupErr != nil {
  		// NOTE(review): tasks failed here get no quota refund, unlike the
  		// FAILURE path in updateVideoSingleTask — confirm this is intended.
  		failFields := map[string]any{
  			"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
  			"status":      "FAILURE",
  			"progress":    "100%",
  		}
  		if bulkErr := model.TaskBulkUpdate(taskIds, failFields); bulkErr != nil {
  			common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", bulkErr))
  		}
  		return fmt.Errorf("CacheGetChannel failed: %w", lookupErr)
  	}

  	taskAdaptor := relay.GetTaskAdaptor(platform)
  	if taskAdaptor == nil {
  		return fmt.Errorf("video adaptor not found")
  	}

  	for i := range taskIds {
  		id := taskIds[i]
  		if updateErr := updateVideoSingleTask(ctx, taskAdaptor, ch, id, taskM); updateErr != nil {
  			logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", id, updateErr.Error()))
  		}
  	}
  	return nil
  }
  // updateVideoSingleTask polls the upstream provider for one video task and
  // saves the outcome on the task record from taskM.
  //
  // The upstream body is first decoded as a New API dto.TaskResponse
  // (presumably for upstreams that are themselves New API deployments — TODO
  // confirm). If that fails or does not report success, the platform adaptor's
  // ParseTaskResult is used instead, and a redacted copy of the raw body is
  // stored in task.Data. The task's status, progress, start/finish time and
  // FailReason are then updated from the result, and the task is saved.
  //
  // For a FAILURE result with a non-zero quota, the quota is refunded only
  // after the task has been saved successfully. If the save fails, the task
  // stays pending and is polled again, so refunding before the save could
  // refund the same task twice.
  //
  // An error is returned when the task is missing from taskM, the upstream call
  // or body read fails, the body cannot be parsed, the status is empty or
  // unknown, or the task cannot be saved.
  func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
  	task := taskM[taskId]
  	if task == nil {
  		logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
  		return fmt.Errorf("task %s not found", taskId)
  	}

  	// A channel-specific base URL overrides the per-type default.
  	baseURL := constant.ChannelBaseURLs[channel.Type]
  	if custom := channel.GetBaseURL(); custom != "" {
  		baseURL = custom
  	}

  	resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
  		"task_id": taskId,
  		"action":  task.Action,
  	})
  	if err != nil {
  		return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
  	}
  	defer resp.Body.Close()
  	// NOTE(review): the HTTP status code is intentionally not checked (an
  	// earlier check was disabled); a body that is not a recognizable result
  	// surfaces as a parse error below.
  	responseBody, err := io.ReadAll(resp.Body)
  	if err != nil {
  		return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
  	}

  	var taskResult *relaycommon.TaskInfo
  	var newAPIResp dto.TaskResponse[model.Task]
  	if json.Unmarshal(responseBody, &newAPIResp) == nil && newAPIResp.IsSuccess() {
  		t := newAPIResp.Data
  		taskResult = &relaycommon.TaskInfo{
  			TaskID: t.TaskID,
  			Status: string(t.Status),
  			// A successful task carries its result URL in FailReason (the
  			// SUCCESS case below writes Url back into FailReason).
  			Url:      t.FailReason,
  			Progress: t.Progress,
  			Reason:   t.FailReason,
  		}
  	} else {
  		taskResult, err = adaptor.ParseTaskResult(responseBody)
  		if err != nil {
  			return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
  		}
  		if taskResult == nil {
  			// Guard against a (nil, nil) return, which would otherwise panic below.
  			return fmt.Errorf("parseTaskResult returned no result for task %s", taskId)
  		}
  		task.Data = redactVideoResponseBody(responseBody)
  	}

  	if taskResult.Status == "" {
  		return fmt.Errorf("task %s status is empty", taskId)
  	}

  	now := time.Now().Unix()
  	refund := false
  	task.Status = model.TaskStatus(taskResult.Status)
  	switch taskResult.Status {
  	case model.TaskStatusSubmitted:
  		task.Progress = "10%"
  	case model.TaskStatusQueued:
  		task.Progress = "20%"
  	case model.TaskStatusInProgress:
  		task.Progress = "30%"
  		if task.StartTime == 0 {
  			task.StartTime = now
  		}
  	case model.TaskStatusSuccess:
  		task.Progress = "100%"
  		if task.FinishTime == 0 {
  			task.FinishTime = now
  		}
  		// Inline "data:" URLs are not stored as the result URL.
  		isDataURL := len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:"
  		if !isDataURL {
  			task.FailReason = taskResult.Url
  		}
  	case model.TaskStatusFailure:
  		task.Progress = "100%"
  		if task.FinishTime == 0 {
  			task.FinishTime = now
  		}
  		task.FailReason = taskResult.Reason
  		logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
  		refund = task.Quota != 0
  	default:
  		return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
  	}

  	// An explicit progress value from upstream overrides the defaults above.
  	if taskResult.Progress != "" {
  		task.Progress = taskResult.Progress
  	}

  	// Persist first: if this fails, the task stays pending and is retried on
  	// the next poll without a refund having been issued yet.
  	if err := task.Update(); err != nil {
  		return fmt.Errorf("update task %s failed: %w", taskId, err)
  	}

  	if refund {
  		quota := task.Quota
  		if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
  			logger.LogError(ctx, "Failed to increase user quota: "+err.Error())
  		}
  		logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
  		model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
  	}
  	return nil
  }
  // redactVideoResponseBody removes inline video payloads from an upstream
  // task-status body before it is stored in task.Data.
  //
  // Only the top-level "response" object is inspected:
  //   - its "bytesBase64Encoded" field is dropped;
  //   - a string "video" field is shortened with truncateBase64;
  //   - "bytesBase64Encoded" is dropped from each object in a "videos" array.
  //
  // All other fields are carried through as json.RawMessage instead of being
  // decoded into `any`. Decoding into `any` would turn every number into a
  // float64 and silently corrupt integers above 2^53 when re-encoded.
  //
  // The body is returned unchanged when it is not a JSON object, has no
  // "response" object, or contains nothing to redact. A re-encoding error also
  // falls back to the original body (best effort).
  func redactVideoResponseBody(body []byte) []byte {
  	var top map[string]json.RawMessage
  	if err := json.Unmarshal(body, &top); err != nil {
  		return body
  	}
  	rawResp, ok := top["response"]
  	if !ok {
  		return body
  	}
  	var resp map[string]json.RawMessage
  	if err := json.Unmarshal(rawResp, &resp); err != nil || resp == nil {
  		return body
  	}

  	changed := false
  	if _, has := resp["bytesBase64Encoded"]; has {
  		delete(resp, "bytesBase64Encoded")
  		changed = true
  	}

  	// Only a JSON string is truncated; other types (or null) are left as-is.
  	if rawVideo, has := resp["video"]; has {
  		var video string
  		if json.Unmarshal(rawVideo, &video) == nil {
  			if short := truncateBase64(video); short != video {
  				enc, err := json.Marshal(short)
  				if err != nil {
  					return body
  				}
  				resp["video"] = enc
  				changed = true
  			}
  		}
  	}

  	if rawVideos, has := resp["videos"]; has {
  		var videos []json.RawMessage
  		if json.Unmarshal(rawVideos, &videos) == nil {
  			videosChanged := false
  			for i, item := range videos {
  				// Non-object entries fail to decode (null decodes to a nil
  				// map) and are kept untouched.
  				var obj map[string]json.RawMessage
  				if json.Unmarshal(item, &obj) != nil {
  					continue
  				}
  				if _, has := obj["bytesBase64Encoded"]; !has {
  					continue
  				}
  				delete(obj, "bytesBase64Encoded")
  				enc, err := json.Marshal(obj)
  				if err != nil {
  					return body
  				}
  				videos[i] = enc
  				videosChanged = true
  			}
  			if videosChanged {
  				enc, err := json.Marshal(videos)
  				if err != nil {
  					return body
  				}
  				resp["videos"] = enc
  				changed = true
  			}
  		}
  	}

  	if !changed {
  		return body
  	}
  	encResp, err := json.Marshal(resp)
  	if err != nil {
  		return body
  	}
  	top["response"] = encResp
  	out, err := json.Marshal(top)
  	if err != nil {
  		return body
  	}
  	return out
  }

  // truncateBase64 returns s unchanged when it is at most 256 bytes long.
  // Otherwise it keeps the first 256 bytes followed by "...". The cut is
  // byte-based, which is safe for ASCII payloads such as base64 but could split
  // a multi-byte UTF-8 character in other text.
  func truncateBase64(s string) string {
  	const maxKeep = 256
  	if len(s) <= maxKeep {
  		return s
  	}
  	return s[:maxKeep] + "..."
  }