| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305 |
- package controller
- import (
- "context"
- "encoding/json"
- "fmt"
- "io"
- "time"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/dto"
- "github.com/QuantumNous/new-api/logger"
- "github.com/QuantumNous/new-api/model"
- "github.com/QuantumNous/new-api/relay"
- "github.com/QuantumNous/new-api/relay/channel"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- "github.com/QuantumNous/new-api/setting/ratio_setting"
- )
- func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
- for channelId, taskIds := range taskChannelM {
- if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
- logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
- }
- }
- return nil
- }
- func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
- logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
- if len(taskIds) == 0 {
- return nil
- }
- cacheGetChannel, err := model.CacheGetChannel(channelId)
- if err != nil {
- errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
- "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
- "status": "FAILURE",
- "progress": "100%",
- })
- if errUpdate != nil {
- common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
- }
- return fmt.Errorf("CacheGetChannel failed: %w", err)
- }
- adaptor := relay.GetTaskAdaptor(platform)
- if adaptor == nil {
- return fmt.Errorf("video adaptor not found")
- }
- info := &relaycommon.RelayInfo{}
- info.ChannelMeta = &relaycommon.ChannelMeta{
- ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
- }
- adaptor.Init(info)
- for _, taskId := range taskIds {
- if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
- logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
- }
- }
- return nil
- }
- func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
- baseURL := constant.ChannelBaseURLs[channel.Type]
- if channel.GetBaseURL() != "" {
- baseURL = channel.GetBaseURL()
- }
- task := taskM[taskId]
- if task == nil {
- logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
- return fmt.Errorf("task %s not found", taskId)
- }
- resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
- "task_id": taskId,
- "action": task.Action,
- })
- if err != nil {
- return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
- }
- //if resp.StatusCode != http.StatusOK {
- //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
- //}
- defer resp.Body.Close()
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
- }
- logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
- taskResult := &relaycommon.TaskInfo{}
- // try parse as New API response format
- var responseItems dto.TaskResponse[model.Task]
- if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
- logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
- t := responseItems.Data
- taskResult.TaskID = t.TaskID
- taskResult.Status = string(t.Status)
- taskResult.Url = t.FailReason
- taskResult.Progress = t.Progress
- taskResult.Reason = t.FailReason
- task.Data = t.Data
- } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
- return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
- } else {
- task.Data = redactVideoResponseBody(responseBody)
- }
- logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
- now := time.Now().Unix()
- if taskResult.Status == "" {
- //return fmt.Errorf("task %s status is empty", taskId)
- taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
- }
- // 记录原本的状态,防止重复退款
- shouldRefund := false
- quota := task.Quota
- preStatus := task.Status
- task.Status = model.TaskStatus(taskResult.Status)
- switch taskResult.Status {
- case model.TaskStatusSubmitted:
- task.Progress = "10%"
- case model.TaskStatusQueued:
- task.Progress = "20%"
- case model.TaskStatusInProgress:
- task.Progress = "30%"
- if task.StartTime == 0 {
- task.StartTime = now
- }
- case model.TaskStatusSuccess:
- task.Progress = "100%"
- if task.FinishTime == 0 {
- task.FinishTime = now
- }
- if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
- task.FailReason = taskResult.Url
- }
- // 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费
- if taskResult.TotalTokens > 0 {
- // 获取模型名称
- var taskData map[string]interface{}
- if err := json.Unmarshal(task.Data, &taskData); err == nil {
- if modelName, ok := taskData["model"].(string); ok && modelName != "" {
- // 获取模型价格和倍率
- modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
- // 只有配置了倍率(非固定价格)时才按 token 重新计费
- if hasRatioSetting && modelRatio > 0 {
- // 获取用户和组的倍率信息
- group := task.Group
- if group == "" {
- user, err := model.GetUserById(task.UserId, false)
- if err == nil {
- group = user.Group
- }
- }
- if group != "" {
- groupRatio := ratio_setting.GetGroupRatio(group)
- userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
- var finalGroupRatio float64
- if hasUserGroupRatio {
- finalGroupRatio = userGroupRatio
- } else {
- finalGroupRatio = groupRatio
- }
- // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
- actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio)
- // 计算差额
- preConsumedQuota := task.Quota
- quotaDelta := actualQuota - preConsumedQuota
- if quotaDelta > 0 {
- // 需要补扣费
- logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
- task.TaskID,
- logger.LogQuota(quotaDelta),
- logger.LogQuota(actualQuota),
- logger.LogQuota(preConsumedQuota),
- taskResult.TotalTokens,
- ))
- if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil {
- logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error()))
- } else {
- model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
- model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
- task.Quota = actualQuota // 更新任务记录的实际扣费额度
- // 记录消费日志
- logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s",
- modelRatio, finalGroupRatio, taskResult.TotalTokens,
- logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta))
- model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
- }
- } else if quotaDelta < 0 {
- // 需要退还多扣的费用
- refundQuota := -quotaDelta
- logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
- task.TaskID,
- logger.LogQuota(refundQuota),
- logger.LogQuota(actualQuota),
- logger.LogQuota(preConsumedQuota),
- taskResult.TotalTokens,
- ))
- if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil {
- logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error()))
- } else {
- task.Quota = actualQuota // 更新任务记录的实际扣费额度
- // 记录退款日志
- logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s",
- modelRatio, finalGroupRatio, taskResult.TotalTokens,
- logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota))
- model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
- }
- } else {
- // quotaDelta == 0, 预扣费刚好准确
- logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
- task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens))
- }
- }
- }
- }
- }
- }
- case model.TaskStatusFailure:
- logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
- task.Status = model.TaskStatusFailure
- task.Progress = "100%"
- if task.FinishTime == 0 {
- task.FinishTime = now
- }
- task.FailReason = taskResult.Reason
- logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
- taskResult.Progress = "100%"
- if quota != 0 {
- if preStatus != model.TaskStatusFailure {
- shouldRefund = true
- } else {
- logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
- }
- }
- default:
- return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
- }
- if taskResult.Progress != "" {
- task.Progress = taskResult.Progress
- }
- if err := task.Update(); err != nil {
- common.SysLog("UpdateVideoTask task error: " + err.Error())
- shouldRefund = false
- }
- if shouldRefund {
- // 任务失败且之前状态不是失败才退还额度,防止重复退还
- if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
- logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
- }
- logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
- model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
- }
- return nil
- }
- func redactVideoResponseBody(body []byte) []byte {
- var m map[string]any
- if err := json.Unmarshal(body, &m); err != nil {
- return body
- }
- resp, _ := m["response"].(map[string]any)
- if resp != nil {
- delete(resp, "bytesBase64Encoded")
- if v, ok := resp["video"].(string); ok {
- resp["video"] = truncateBase64(v)
- }
- if vs, ok := resp["videos"].([]any); ok {
- for i := range vs {
- if vm, ok := vs[i].(map[string]any); ok {
- delete(vm, "bytesBase64Encoded")
- }
- }
- }
- }
- b, err := json.Marshal(m)
- if err != nil {
- return body
- }
- return b
- }
// truncateBase64 shortens s to its first 256 bytes followed by "..." so that
// large inline payloads are not stored in full. Strings of at most 256 bytes
// are returned unchanged. The cut is byte-based, which is safe for base64
// (ASCII-only) input.
func truncateBase64(s string) string {
	const keep = 256
	if len(s) > keep {
		return s[:keep] + "..."
	}
	return s
}
|