task_video.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. package controller
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/constant"
  10. "one-api/model"
  11. "one-api/relay"
  12. "one-api/relay/channel"
  13. )
  14. func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
  15. for channelId, taskIds := range taskChannelM {
  16. if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil {
  17. common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
  18. }
  19. }
  20. return nil
  21. }
  22. func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
  23. common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
  24. if len(taskIds) == 0 {
  25. return nil
  26. }
  27. cacheGetChannel, err := model.CacheGetChannel(channelId)
  28. if err != nil {
  29. errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
  30. "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
  31. "status": "FAILURE",
  32. "progress": "100%",
  33. })
  34. if errUpdate != nil {
  35. common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
  36. }
  37. return fmt.Errorf("CacheGetChannel failed: %w", err)
  38. }
  39. adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling)
  40. if adaptor == nil {
  41. return fmt.Errorf("video adaptor not found")
  42. }
  43. for _, taskId := range taskIds {
  44. if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
  45. common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
  46. }
  47. }
  48. return nil
  49. }
  50. func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
  51. baseURL := common.ChannelBaseURLs[channel.Type]
  52. if channel.GetBaseURL() != "" {
  53. baseURL = channel.GetBaseURL()
  54. }
  55. resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
  56. "task_id": taskId,
  57. })
  58. if err != nil {
  59. return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
  60. }
  61. if resp.StatusCode != http.StatusOK {
  62. return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode)
  63. }
  64. defer resp.Body.Close()
  65. responseBody, err := io.ReadAll(resp.Body)
  66. if err != nil {
  67. return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err)
  68. }
  69. var responseItem map[string]interface{}
  70. err = json.Unmarshal(responseBody, &responseItem)
  71. if err != nil {
  72. common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody)))
  73. return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err)
  74. }
  75. code, _ := responseItem["code"].(float64)
  76. if code != 0 {
  77. return fmt.Errorf("video task fetch failed for task %s", taskId)
  78. }
  79. data, ok := responseItem["data"].(map[string]interface{})
  80. if !ok {
  81. common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody)))
  82. return fmt.Errorf("video task data format error for task %s", taskId)
  83. }
  84. task := taskM[taskId]
  85. if task == nil {
  86. common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
  87. return fmt.Errorf("task %s not found", taskId)
  88. }
  89. if status, ok := data["task_status"].(string); ok {
  90. switch status {
  91. case "submitted", "queued":
  92. task.Status = model.TaskStatusSubmitted
  93. case "processing":
  94. task.Status = model.TaskStatusInProgress
  95. case "succeed":
  96. task.Status = model.TaskStatusSuccess
  97. task.Progress = "100%"
  98. if url, err := adaptor.ParseResultUrl(responseItem); err == nil {
  99. task.FailReason = url
  100. } else {
  101. common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error()))
  102. }
  103. case "failed":
  104. task.Status = model.TaskStatusFailure
  105. task.Progress = "100%"
  106. if reason, ok := data["fail_reason"].(string); ok {
  107. task.FailReason = reason
  108. }
  109. }
  110. }
  111. // If task failed, refund quota
  112. if task.Status == model.TaskStatusFailure {
  113. common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
  114. quota := task.Quota
  115. if quota != 0 {
  116. if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
  117. common.LogError(ctx, "Failed to increase user quota: "+err.Error())
  118. }
  119. logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
  120. model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
  121. }
  122. }
  123. task.Data = responseBody
  124. if err := task.Update(); err != nil {
  125. common.SysError("UpdateVideoTask task error: " + err.Error())
  126. }
  127. return nil
  128. }