task_video.go 11 KB


  1. package controller
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/logger"
	"github.com/QuantumNous/new-api/model"
	"github.com/QuantumNous/new-api/relay"
	"github.com/QuantumNous/new-api/relay/channel"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/setting/ratio_setting"
)
  18. func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
  19. for channelId, taskIds := range taskChannelM {
  20. if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
  21. logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
  22. }
  23. }
  24. return nil
  25. }
  26. func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
  27. logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
  28. if len(taskIds) == 0 {
  29. return nil
  30. }
  31. cacheGetChannel, err := model.CacheGetChannel(channelId)
  32. if err != nil {
  33. errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
  34. "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
  35. "status": "FAILURE",
  36. "progress": "100%",
  37. })
  38. if errUpdate != nil {
  39. common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
  40. }
  41. return fmt.Errorf("CacheGetChannel failed: %w", err)
  42. }
  43. adaptor := relay.GetTaskAdaptor(platform)
  44. if adaptor == nil {
  45. return fmt.Errorf("video adaptor not found")
  46. }
  47. info := &relaycommon.RelayInfo{}
  48. info.ChannelMeta = &relaycommon.ChannelMeta{
  49. ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
  50. }
  51. info.ApiKey = cacheGetChannel.Key
  52. adaptor.Init(info)
  53. for _, taskId := range taskIds {
  54. if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
  55. logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
  56. }
  57. }
  58. return nil
  59. }
  60. func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
  61. baseURL := constant.ChannelBaseURLs[channel.Type]
  62. if channel.GetBaseURL() != "" {
  63. baseURL = channel.GetBaseURL()
  64. }
  65. proxy := channel.GetSetting().Proxy
  66. task := taskM[taskId]
  67. if task == nil {
  68. logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
  69. return fmt.Errorf("task %s not found", taskId)
  70. }
  71. key := channel.Key
  72. privateData := task.PrivateData
  73. if privateData.Key != "" {
  74. key = privateData.Key
  75. }
  76. resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
  77. "task_id": taskId,
  78. "action": task.Action,
  79. }, proxy)
  80. if err != nil {
  81. return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
  82. }
  83. //if resp.StatusCode != http.StatusOK {
  84. //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
  85. //}
  86. defer resp.Body.Close()
  87. responseBody, err := io.ReadAll(resp.Body)
  88. if err != nil {
  89. return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
  90. }
  91. logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
  92. taskResult := &relaycommon.TaskInfo{}
  93. // try parse as New API response format
  94. var responseItems dto.TaskResponse[model.Task]
  95. if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
  96. logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
  97. t := responseItems.Data
  98. taskResult.TaskID = t.TaskID
  99. taskResult.Status = string(t.Status)
  100. taskResult.Url = t.FailReason
  101. taskResult.Progress = t.Progress
  102. taskResult.Reason = t.FailReason
  103. task.Data = t.Data
  104. } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
  105. return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
  106. } else {
  107. task.Data = redactVideoResponseBody(responseBody)
  108. }
  109. logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
  110. now := time.Now().Unix()
  111. if taskResult.Status == "" {
  112. //return fmt.Errorf("task %s status is empty", taskId)
  113. taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
  114. }
  115. // 记录原本的状态,防止重复退款
  116. shouldRefund := false
  117. quota := task.Quota
  118. preStatus := task.Status
  119. task.Status = model.TaskStatus(taskResult.Status)
  120. switch taskResult.Status {
  121. case model.TaskStatusSubmitted:
  122. task.Progress = "10%"
  123. case model.TaskStatusQueued:
  124. task.Progress = "20%"
  125. case model.TaskStatusInProgress:
  126. task.Progress = "30%"
  127. if task.StartTime == 0 {
  128. task.StartTime = now
  129. }
  130. case model.TaskStatusSuccess:
  131. task.Progress = "100%"
  132. if task.FinishTime == 0 {
  133. task.FinishTime = now
  134. }
  135. if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
  136. task.FailReason = taskResult.Url
  137. }
  138. // 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费
  139. if taskResult.TotalTokens > 0 {
  140. // 获取模型名称
  141. var taskData map[string]interface{}
  142. if err := json.Unmarshal(task.Data, &taskData); err == nil {
  143. if modelName, ok := taskData["model"].(string); ok && modelName != "" {
  144. // 获取模型价格和倍率
  145. modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
  146. // 只有配置了倍率(非固定价格)时才按 token 重新计费
  147. if hasRatioSetting && modelRatio > 0 {
  148. // 获取用户和组的倍率信息
  149. group := task.Group
  150. if group == "" {
  151. user, err := model.GetUserById(task.UserId, false)
  152. if err == nil {
  153. group = user.Group
  154. }
  155. }
  156. if group != "" {
  157. groupRatio := ratio_setting.GetGroupRatio(group)
  158. userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
  159. var finalGroupRatio float64
  160. if hasUserGroupRatio {
  161. finalGroupRatio = userGroupRatio
  162. } else {
  163. finalGroupRatio = groupRatio
  164. }
  165. // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
  166. actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio)
  167. // 计算差额
  168. preConsumedQuota := task.Quota
  169. quotaDelta := actualQuota - preConsumedQuota
  170. if quotaDelta > 0 {
  171. // 需要补扣费
  172. logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
  173. task.TaskID,
  174. logger.LogQuota(quotaDelta),
  175. logger.LogQuota(actualQuota),
  176. logger.LogQuota(preConsumedQuota),
  177. taskResult.TotalTokens,
  178. ))
  179. if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil {
  180. logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error()))
  181. } else {
  182. model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
  183. model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
  184. task.Quota = actualQuota // 更新任务记录的实际扣费额度
  185. // 记录消费日志
  186. logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s",
  187. modelRatio, finalGroupRatio, taskResult.TotalTokens,
  188. logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta))
  189. model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
  190. }
  191. } else if quotaDelta < 0 {
  192. // 需要退还多扣的费用
  193. refundQuota := -quotaDelta
  194. logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
  195. task.TaskID,
  196. logger.LogQuota(refundQuota),
  197. logger.LogQuota(actualQuota),
  198. logger.LogQuota(preConsumedQuota),
  199. taskResult.TotalTokens,
  200. ))
  201. if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil {
  202. logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error()))
  203. } else {
  204. task.Quota = actualQuota // 更新任务记录的实际扣费额度
  205. // 记录退款日志
  206. logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s",
  207. modelRatio, finalGroupRatio, taskResult.TotalTokens,
  208. logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota))
  209. model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
  210. }
  211. } else {
  212. // quotaDelta == 0, 预扣费刚好准确
  213. logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
  214. task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens))
  215. }
  216. }
  217. }
  218. }
  219. }
  220. }
  221. case model.TaskStatusFailure:
  222. logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
  223. task.Status = model.TaskStatusFailure
  224. task.Progress = "100%"
  225. if task.FinishTime == 0 {
  226. task.FinishTime = now
  227. }
  228. task.FailReason = taskResult.Reason
  229. logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
  230. taskResult.Progress = "100%"
  231. if quota != 0 {
  232. if preStatus != model.TaskStatusFailure {
  233. shouldRefund = true
  234. } else {
  235. logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
  236. }
  237. }
  238. default:
  239. return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
  240. }
  241. if taskResult.Progress != "" {
  242. task.Progress = taskResult.Progress
  243. }
  244. if err := task.Update(); err != nil {
  245. common.SysLog("UpdateVideoTask task error: " + err.Error())
  246. shouldRefund = false
  247. }
  248. if shouldRefund {
  249. // 任务失败且之前状态不是失败才退还额度,防止重复退还
  250. if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
  251. logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
  252. }
  253. logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
  254. model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
  255. }
  256. return nil
  257. }
  258. func redactVideoResponseBody(body []byte) []byte {
  259. var m map[string]any
  260. if err := json.Unmarshal(body, &m); err != nil {
  261. return body
  262. }
  263. resp, _ := m["response"].(map[string]any)
  264. if resp != nil {
  265. delete(resp, "bytesBase64Encoded")
  266. if v, ok := resp["video"].(string); ok {
  267. resp["video"] = truncateBase64(v)
  268. }
  269. if vs, ok := resp["videos"].([]any); ok {
  270. for i := range vs {
  271. if vm, ok := vs[i].(map[string]any); ok {
  272. delete(vm, "bytesBase64Encoded")
  273. }
  274. }
  275. }
  276. }
  277. b, err := json.Marshal(m)
  278. if err != nil {
  279. return body
  280. }
  281. return b
  282. }
  283. func truncateBase64(s string) string {
  284. const maxKeep = 256
  285. if len(s) <= maxKeep {
  286. return s
  287. }
  288. return s[:maxKeep] + "..."
  289. }