task_video.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package controller
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "one-api/common"
  8. "one-api/constant"
  9. "one-api/dto"
  10. "one-api/logger"
  11. "one-api/model"
  12. "one-api/relay"
  13. "one-api/relay/channel"
  14. relaycommon "one-api/relay/common"
  15. "time"
  16. )
  17. func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
  18. for channelId, taskIds := range taskChannelM {
  19. if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
  20. logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
  21. }
  22. }
  23. return nil
  24. }
  25. func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
  26. logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
  27. if len(taskIds) == 0 {
  28. return nil
  29. }
  30. cacheGetChannel, err := model.CacheGetChannel(channelId)
  31. if err != nil {
  32. errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
  33. "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
  34. "status": "FAILURE",
  35. "progress": "100%",
  36. })
  37. if errUpdate != nil {
  38. common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
  39. }
  40. return fmt.Errorf("CacheGetChannel failed: %w", err)
  41. }
  42. adaptor := relay.GetTaskAdaptor(platform)
  43. if adaptor == nil {
  44. return fmt.Errorf("video adaptor not found")
  45. }
  46. for _, taskId := range taskIds {
  47. if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
  48. logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
  49. }
  50. }
  51. return nil
  52. }
  53. func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
  54. baseURL := constant.ChannelBaseURLs[channel.Type]
  55. if channel.GetBaseURL() != "" {
  56. baseURL = channel.GetBaseURL()
  57. }
  58. task := taskM[taskId]
  59. if task == nil {
  60. logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
  61. return fmt.Errorf("task %s not found", taskId)
  62. }
  63. resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
  64. "task_id": taskId,
  65. "action": task.Action,
  66. })
  67. if err != nil {
  68. return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
  69. }
  70. //if resp.StatusCode != http.StatusOK {
  71. //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
  72. //}
  73. defer resp.Body.Close()
  74. responseBody, err := io.ReadAll(resp.Body)
  75. if err != nil {
  76. return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
  77. }
  78. taskResult := &relaycommon.TaskInfo{}
  79. // try parse as New API response format
  80. var responseItems dto.TaskResponse[model.Task]
  81. if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
  82. t := responseItems.Data
  83. taskResult.TaskID = t.TaskID
  84. taskResult.Status = string(t.Status)
  85. taskResult.Url = t.FailReason
  86. taskResult.Progress = t.Progress
  87. taskResult.Reason = t.FailReason
  88. } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
  89. return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
  90. } else {
  91. task.Data = responseBody
  92. }
  93. now := time.Now().Unix()
  94. if taskResult.Status == "" {
  95. return fmt.Errorf("task %s status is empty", taskId)
  96. }
  97. task.Status = model.TaskStatus(taskResult.Status)
  98. switch taskResult.Status {
  99. case model.TaskStatusSubmitted:
  100. task.Progress = "10%"
  101. case model.TaskStatusQueued:
  102. task.Progress = "20%"
  103. case model.TaskStatusInProgress:
  104. task.Progress = "30%"
  105. if task.StartTime == 0 {
  106. task.StartTime = now
  107. }
  108. case model.TaskStatusSuccess:
  109. task.Progress = "100%"
  110. if task.FinishTime == 0 {
  111. task.FinishTime = now
  112. }
  113. task.FailReason = taskResult.Url
  114. case model.TaskStatusFailure:
  115. task.Status = model.TaskStatusFailure
  116. task.Progress = "100%"
  117. if task.FinishTime == 0 {
  118. task.FinishTime = now
  119. }
  120. task.FailReason = taskResult.Reason
  121. logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
  122. quota := task.Quota
  123. if quota != 0 {
  124. if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
  125. logger.LogError(ctx, "Failed to increase user quota: "+err.Error())
  126. }
  127. logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
  128. model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
  129. }
  130. default:
  131. return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
  132. }
  133. if taskResult.Progress != "" {
  134. task.Progress = taskResult.Progress
  135. }
  136. if err := task.Update(); err != nil {
  137. common.SysLog("UpdateVideoTask task error: " + err.Error())
  138. }
  139. return nil
  140. }