1
0

task_video.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. package controller
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "one-api/common"
  7. "one-api/constant"
  8. "one-api/model"
  9. "one-api/relay"
  10. "one-api/relay/channel"
  11. "time"
  12. )
  13. func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
  14. for channelId, taskIds := range taskChannelM {
  15. if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
  16. common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
  17. }
  18. }
  19. return nil
  20. }
  21. func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
  22. common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
  23. if len(taskIds) == 0 {
  24. return nil
  25. }
  26. cacheGetChannel, err := model.CacheGetChannel(channelId)
  27. if err != nil {
  28. errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
  29. "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
  30. "status": "FAILURE",
  31. "progress": "100%",
  32. })
  33. if errUpdate != nil {
  34. common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
  35. }
  36. return fmt.Errorf("CacheGetChannel failed: %w", err)
  37. }
  38. adaptor := relay.GetTaskAdaptor(platform)
  39. if adaptor == nil {
  40. return fmt.Errorf("video adaptor not found")
  41. }
  42. for _, taskId := range taskIds {
  43. if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
  44. common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
  45. }
  46. }
  47. return nil
  48. }
  49. func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
  50. baseURL := constant.ChannelBaseURLs[channel.Type]
  51. if channel.GetBaseURL() != "" {
  52. baseURL = channel.GetBaseURL()
  53. }
  54. task := taskM[taskId]
  55. if task == nil {
  56. common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
  57. return fmt.Errorf("task %s not found", taskId)
  58. }
  59. resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
  60. "task_id": taskId,
  61. "action": task.Action,
  62. })
  63. if err != nil {
  64. return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
  65. }
  66. //if resp.StatusCode != http.StatusOK {
  67. //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
  68. //}
  69. defer resp.Body.Close()
  70. responseBody, err := io.ReadAll(resp.Body)
  71. if err != nil {
  72. return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
  73. }
  74. taskResult, err := adaptor.ParseTaskResult(responseBody)
  75. if err != nil {
  76. return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
  77. }
  78. //if taskResult.Code != 0 {
  79. // return fmt.Errorf("video task fetch failed for task %s", taskId)
  80. //}
  81. now := time.Now().Unix()
  82. if taskResult.Status == "" {
  83. return fmt.Errorf("task %s status is empty", taskId)
  84. }
  85. task.Status = model.TaskStatus(taskResult.Status)
  86. switch taskResult.Status {
  87. case model.TaskStatusSubmitted:
  88. task.Progress = "10%"
  89. case model.TaskStatusQueued:
  90. task.Progress = "20%"
  91. case model.TaskStatusInProgress:
  92. task.Progress = "30%"
  93. if task.StartTime == 0 {
  94. task.StartTime = now
  95. }
  96. case model.TaskStatusSuccess:
  97. task.Progress = "100%"
  98. if task.FinishTime == 0 {
  99. task.FinishTime = now
  100. }
  101. task.FailReason = taskResult.Url
  102. case model.TaskStatusFailure:
  103. task.Status = model.TaskStatusFailure
  104. task.Progress = "100%"
  105. if task.FinishTime == 0 {
  106. task.FinishTime = now
  107. }
  108. task.FailReason = taskResult.Reason
  109. common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
  110. quota := task.Quota
  111. if quota != 0 {
  112. if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
  113. common.LogError(ctx, "Failed to increase user quota: "+err.Error())
  114. }
  115. logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
  116. model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
  117. }
  118. default:
  119. return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
  120. }
  121. if taskResult.Progress != "" {
  122. task.Progress = taskResult.Progress
  123. }
  124. task.Data = responseBody
  125. if err := task.Update(); err != nil {
  126. common.SysError("UpdateVideoTask task error: " + err.Error())
  127. }
  128. return nil
  129. }