task_video.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. package controller
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "one-api/common"
  8. "one-api/constant"
  9. "one-api/dto"
  10. "one-api/logger"
  11. "one-api/model"
  12. "one-api/relay"
  13. "one-api/relay/channel"
  14. relaycommon "one-api/relay/common"
  15. "one-api/setting/ratio_setting"
  16. "time"
  17. )
  18. func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
  19. for channelId, taskIds := range taskChannelM {
  20. if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
  21. logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
  22. }
  23. }
  24. return nil
  25. }
  26. func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
  27. logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
  28. if len(taskIds) == 0 {
  29. return nil
  30. }
  31. cacheGetChannel, err := model.CacheGetChannel(channelId)
  32. if err != nil {
  33. errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
  34. "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
  35. "status": "FAILURE",
  36. "progress": "100%",
  37. })
  38. if errUpdate != nil {
  39. common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
  40. }
  41. return fmt.Errorf("CacheGetChannel failed: %w", err)
  42. }
  43. adaptor := relay.GetTaskAdaptor(platform)
  44. if adaptor == nil {
  45. return fmt.Errorf("video adaptor not found")
  46. }
  47. for _, taskId := range taskIds {
  48. if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
  49. logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
  50. }
  51. }
  52. return nil
  53. }
  54. func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
  55. baseURL := constant.ChannelBaseURLs[channel.Type]
  56. if channel.GetBaseURL() != "" {
  57. baseURL = channel.GetBaseURL()
  58. }
  59. task := taskM[taskId]
  60. if task == nil {
  61. logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
  62. return fmt.Errorf("task %s not found", taskId)
  63. }
  64. resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
  65. "task_id": taskId,
  66. "action": task.Action,
  67. })
  68. if err != nil {
  69. return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
  70. }
  71. //if resp.StatusCode != http.StatusOK {
  72. //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
  73. //}
  74. defer resp.Body.Close()
  75. responseBody, err := io.ReadAll(resp.Body)
  76. if err != nil {
  77. return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
  78. }
  79. taskResult := &relaycommon.TaskInfo{}
  80. // try parse as New API response format
  81. var responseItems dto.TaskResponse[model.Task]
  82. if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
  83. t := responseItems.Data
  84. taskResult.TaskID = t.TaskID
  85. taskResult.Status = string(t.Status)
  86. taskResult.Url = t.FailReason
  87. taskResult.Progress = t.Progress
  88. taskResult.Reason = t.FailReason
  89. } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
  90. return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
  91. } else {
  92. task.Data = redactVideoResponseBody(responseBody)
  93. }
  94. now := time.Now().Unix()
  95. if taskResult.Status == "" {
  96. return fmt.Errorf("task %s status is empty", taskId)
  97. }
  98. task.Status = model.TaskStatus(taskResult.Status)
  99. switch taskResult.Status {
  100. case model.TaskStatusSubmitted:
  101. task.Progress = "10%"
  102. case model.TaskStatusQueued:
  103. task.Progress = "20%"
  104. case model.TaskStatusInProgress:
  105. task.Progress = "30%"
  106. if task.StartTime == 0 {
  107. task.StartTime = now
  108. }
  109. case model.TaskStatusSuccess:
  110. task.Progress = "100%"
  111. if task.FinishTime == 0 {
  112. task.FinishTime = now
  113. }
  114. if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
  115. task.FailReason = taskResult.Url
  116. }
  117. // 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费
  118. if taskResult.TotalTokens > 0 {
  119. // 获取模型名称
  120. var taskData map[string]interface{}
  121. if err := json.Unmarshal(task.Data, &taskData); err == nil {
  122. if modelName, ok := taskData["model"].(string); ok && modelName != "" {
  123. // 获取模型价格和倍率
  124. modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
  125. // 只有配置了倍率(非固定价格)时才按 token 重新计费
  126. if hasRatioSetting && modelRatio > 0 {
  127. // 获取用户和组的倍率信息
  128. user, err := model.GetUserById(task.UserId, false)
  129. if err == nil {
  130. groupRatio := ratio_setting.GetGroupRatio(user.Group)
  131. userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(user.Group, user.Group)
  132. var finalGroupRatio float64
  133. if hasUserGroupRatio {
  134. finalGroupRatio = userGroupRatio
  135. } else {
  136. finalGroupRatio = groupRatio
  137. }
  138. // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
  139. actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio)
  140. // 计算差额
  141. preConsumedQuota := task.Quota
  142. quotaDelta := actualQuota - preConsumedQuota
  143. if quotaDelta > 0 {
  144. // 需要补扣费
  145. logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
  146. task.TaskID,
  147. logger.LogQuota(quotaDelta),
  148. logger.LogQuota(actualQuota),
  149. logger.LogQuota(preConsumedQuota),
  150. taskResult.TotalTokens,
  151. ))
  152. if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil {
  153. logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error()))
  154. } else {
  155. model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
  156. model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
  157. task.Quota = actualQuota // 更新任务记录的实际扣费额度
  158. // 记录消费日志
  159. logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s",
  160. modelRatio, finalGroupRatio, taskResult.TotalTokens,
  161. logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta))
  162. model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
  163. }
  164. } else if quotaDelta < 0 {
  165. // 需要退还多扣的费用
  166. refundQuota := -quotaDelta
  167. logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
  168. task.TaskID,
  169. logger.LogQuota(refundQuota),
  170. logger.LogQuota(actualQuota),
  171. logger.LogQuota(preConsumedQuota),
  172. taskResult.TotalTokens,
  173. ))
  174. if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil {
  175. logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error()))
  176. } else {
  177. task.Quota = actualQuota // 更新任务记录的实际扣费额度
  178. // 记录退款日志
  179. logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s",
  180. modelRatio, finalGroupRatio, taskResult.TotalTokens,
  181. logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota))
  182. model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
  183. }
  184. } else {
  185. // quotaDelta == 0, 预扣费刚好准确
  186. logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
  187. task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens))
  188. }
  189. }
  190. }
  191. }
  192. }
  193. }
  194. case model.TaskStatusFailure:
  195. task.Status = model.TaskStatusFailure
  196. task.Progress = "100%"
  197. if task.FinishTime == 0 {
  198. task.FinishTime = now
  199. }
  200. task.FailReason = taskResult.Reason
  201. logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
  202. quota := task.Quota
  203. if quota != 0 {
  204. if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
  205. logger.LogError(ctx, "Failed to increase user quota: "+err.Error())
  206. }
  207. logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
  208. model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
  209. }
  210. default:
  211. return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
  212. }
  213. if taskResult.Progress != "" {
  214. task.Progress = taskResult.Progress
  215. }
  216. if err := task.Update(); err != nil {
  217. common.SysLog("UpdateVideoTask task error: " + err.Error())
  218. }
  219. return nil
  220. }
  221. func redactVideoResponseBody(body []byte) []byte {
  222. var m map[string]any
  223. if err := json.Unmarshal(body, &m); err != nil {
  224. return body
  225. }
  226. resp, _ := m["response"].(map[string]any)
  227. if resp != nil {
  228. delete(resp, "bytesBase64Encoded")
  229. if v, ok := resp["video"].(string); ok {
  230. resp["video"] = truncateBase64(v)
  231. }
  232. if vs, ok := resp["videos"].([]any); ok {
  233. for i := range vs {
  234. if vm, ok := vs[i].(map[string]any); ok {
  235. delete(vm, "bytesBase64Encoded")
  236. }
  237. }
  238. }
  239. }
  240. b, err := json.Marshal(m)
  241. if err != nil {
  242. return body
  243. }
  244. return b
  245. }
  246. func truncateBase64(s string) string {
  247. const maxKeep = 256
  248. if len(s) <= maxKeep {
  249. return s
  250. }
  251. return s[:maxKeep] + "..."
  252. }