task_video.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. package controller
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "one-api/common"
  8. "one-api/constant"
  9. "one-api/dto"
  10. "one-api/model"
  11. "one-api/relay"
  12. "one-api/relay/channel"
  13. relaycommon "one-api/relay/common"
  14. "time"
  15. )
  16. func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
  17. for channelId, taskIds := range taskChannelM {
  18. if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
  19. common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
  20. }
  21. }
  22. return nil
  23. }
  24. func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
  25. common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
  26. if len(taskIds) == 0 {
  27. return nil
  28. }
  29. cacheGetChannel, err := model.CacheGetChannel(channelId)
  30. if err != nil {
  31. errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
  32. "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
  33. "status": "FAILURE",
  34. "progress": "100%",
  35. })
  36. if errUpdate != nil {
  37. common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
  38. }
  39. return fmt.Errorf("CacheGetChannel failed: %w", err)
  40. }
  41. adaptor := relay.GetTaskAdaptor(platform)
  42. if adaptor == nil {
  43. return fmt.Errorf("video adaptor not found")
  44. }
  45. for _, taskId := range taskIds {
  46. if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
  47. common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
  48. }
  49. }
  50. return nil
  51. }
  52. func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
  53. baseURL := constant.ChannelBaseURLs[channel.Type]
  54. if channel.GetBaseURL() != "" {
  55. baseURL = channel.GetBaseURL()
  56. }
  57. task := taskM[taskId]
  58. if task == nil {
  59. common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
  60. return fmt.Errorf("task %s not found", taskId)
  61. }
  62. resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
  63. "task_id": taskId,
  64. "action": task.Action,
  65. })
  66. if err != nil {
  67. return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
  68. }
  69. //if resp.StatusCode != http.StatusOK {
  70. //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
  71. //}
  72. defer resp.Body.Close()
  73. responseBody, err := io.ReadAll(resp.Body)
  74. if err != nil {
  75. return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
  76. }
  77. taskResult := &relaycommon.TaskInfo{}
  78. // try parse as New API response format
  79. var responseItems dto.TaskResponse[model.Task]
  80. if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
  81. t := responseItems.Data
  82. taskResult.TaskID = t.TaskID
  83. taskResult.Status = string(t.Status)
  84. taskResult.Url = t.FailReason
  85. taskResult.Progress = t.Progress
  86. taskResult.Reason = t.FailReason
  87. } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
  88. return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
  89. } else {
  90. task.Data = responseBody
  91. }
  92. now := time.Now().Unix()
  93. if taskResult.Status == "" {
  94. return fmt.Errorf("task %s status is empty", taskId)
  95. }
  96. task.Status = model.TaskStatus(taskResult.Status)
  97. switch taskResult.Status {
  98. case model.TaskStatusSubmitted:
  99. task.Progress = "10%"
  100. case model.TaskStatusQueued:
  101. task.Progress = "20%"
  102. case model.TaskStatusInProgress:
  103. task.Progress = "30%"
  104. if task.StartTime == 0 {
  105. task.StartTime = now
  106. }
  107. case model.TaskStatusSuccess:
  108. task.Progress = "100%"
  109. if task.FinishTime == 0 {
  110. task.FinishTime = now
  111. }
  112. task.FailReason = taskResult.Url
  113. case model.TaskStatusFailure:
  114. task.Status = model.TaskStatusFailure
  115. task.Progress = "100%"
  116. if task.FinishTime == 0 {
  117. task.FinishTime = now
  118. }
  119. task.FailReason = taskResult.Reason
  120. common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
  121. quota := task.Quota
  122. if quota != 0 {
  123. if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
  124. common.LogError(ctx, "Failed to increase user quota: "+err.Error())
  125. }
  126. logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
  127. model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
  128. }
  129. default:
  130. return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
  131. }
  132. if taskResult.Progress != "" {
  133. task.Progress = taskResult.Progress
  134. }
  135. if err := task.Update(); err != nil {
  136. common.SysError("UpdateVideoTask task error: " + err.Error())
  137. }
  138. return nil
  139. }