package relay import ( "bytes" "encoding/json" "errors" "fmt" "io" "net/http" "strconv" "strings" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" relaycommon "github.com/QuantumNous/new-api/relay/common" relayconstant "github.com/QuantumNous/new-api/relay/constant" "github.com/QuantumNous/new-api/service" "github.com/QuantumNous/new-api/setting/ratio_setting" "github.com/gin-gonic/gin" ) /* Task 任务通过平台、Action 区分任务 */ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { info.InitChannelMeta(c) // ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields if info.TaskRelayInfo == nil { info.TaskRelayInfo = &relaycommon.TaskRelayInfo{} } path := c.Request.URL.Path if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") { info.Action = constant.TaskActionRemix } // 提取 remix 任务的 video_id if info.Action == constant.TaskActionRemix { videoID := c.Param("video_id") if strings.TrimSpace(videoID) == "" { return service.TaskErrorWrapperLocal(fmt.Errorf("video_id is required"), "invalid_request", http.StatusBadRequest) } info.OriginTaskID = videoID } platform := constant.TaskPlatform(c.GetString("platform")) // 获取原始任务信息 if info.OriginTaskID != "" { originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID) if err != nil { taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) return } if !exist { taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) return } if info.OriginModelName == "" { if originTask.Properties.OriginModelName != "" { info.OriginModelName = originTask.Properties.OriginModelName } else if originTask.Properties.UpstreamModelName != "" { info.OriginModelName = originTask.Properties.UpstreamModelName } else { var taskData map[string]interface{} _ = json.Unmarshal(originTask.Data, &taskData) if m, ok := taskData["model"].(string); ok && m != "" { info.OriginModelName = m platform = originTask.Platform } } } if originTask.ChannelId != info.ChannelId { channel, err := model.GetChannelById(originTask.ChannelId, true) if err != nil { taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) return } if channel.Status != common.ChannelStatusEnabled { taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) return } key, _, newAPIError := channel.GetNextEnabledKey() if newAPIError != nil { taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode) return } common.SetContextKey(c, constant.ContextKeyChannelKey, key) common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type) common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL()) common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId) info.ChannelBaseUrl = channel.GetBaseURL() info.ChannelId = originTask.ChannelId info.ChannelType = channel.Type info.ApiKey = key platform = originTask.Platform } // 使用原始任务的参数 if info.Action == constant.TaskActionRemix { var taskData map[string]interface{} _ = json.Unmarshal(originTask.Data, &taskData) secondsStr, _ := taskData["seconds"].(string) seconds, _ := strconv.Atoi(secondsStr) if seconds <= 0 { seconds = 4 } sizeStr, _ := taskData["size"].(string) if info.PriceData.OtherRatios == nil { info.PriceData.OtherRatios = map[string]float64{} } info.PriceData.OtherRatios["seconds"] = float64(seconds) info.PriceData.OtherRatios["size"] = 1 if sizeStr == "1792x1024" || sizeStr == "1024x1792" { info.PriceData.OtherRatios["size"] = 1.666667 } } } if platform == "" { platform = GetTaskPlatform(c) } info.InitChannelMeta(c) adaptor := GetTaskAdaptor(platform) if adaptor == nil { return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) } adaptor.Init(info) // get & validate taskRequest 获取并验证文本请求 taskErr = adaptor.ValidateRequestAndSetAction(c, info) if taskErr != nil { return } modelName := info.OriginModelName if modelName == "" { modelName = service.CoverTaskActionToModelName(platform, info.Action) } modelPrice, success := ratio_setting.GetModelPrice(modelName, true) if !success { defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName] if !ok { modelPrice = 0.1 } else { modelPrice = defaultPrice } } // 预扣 groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup) var ratio float64 userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup) if hasUserGroupRatio { ratio = modelPrice * userGroupRatio } else { ratio = modelPrice * groupRatio } // FIXME: 临时修补,支持任务仅按次计费 if !common.StringsContains(constant.TaskPricePatches, modelName) { if len(info.PriceData.OtherRatios) > 0 { for _, ra := range info.PriceData.OtherRatios { if 1.0 != ra { ratio *= ra } } } } println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio)) userQuota, err := model.GetUserQuota(info.UserId, false) if err != nil { taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) return } quota := int(ratio * common.QuotaPerUnit) if userQuota-quota < 0 { taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden) return } // build body requestBody, err := adaptor.BuildRequestBody(c, info) if err != nil { taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) return } // do request resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) return } // handle response if resp != nil && resp.StatusCode != http.StatusOK { responseBody, _ := io.ReadAll(resp.Body) taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode) return } defer func() { // release quota if info.ConsumeQuota && taskErr == nil { err := service.PostConsumeQuota(info, quota, 0, true) if err != nil { common.SysLog("error consuming token remain quota: " + err.Error()) } if quota != 0 { tokenName := c.GetString("token_name") //gRatio := groupRatio //if hasUserGroupRatio { // gRatio = userGroupRatio //} logContent := fmt.Sprintf("操作 %s", info.Action) // FIXME: 临时修补,支持任务仅按次计费 if common.StringsContains(constant.TaskPricePatches, modelName) { logContent = fmt.Sprintf("%s,按次计费", logContent) } else { if len(info.PriceData.OtherRatios) > 0 { var contents []string for key, ra := range info.PriceData.OtherRatios { if 1.0 != ra { contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra)) } } if len(contents) > 0 { logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", ")) } } } other := make(map[string]interface{}) if c != nil && c.Request != nil && c.Request.URL != nil { other["request_path"] = c.Request.URL.Path } other["model_price"] = modelPrice other["group_ratio"] = groupRatio if hasUserGroupRatio { other["user_group_ratio"] = userGroupRatio } model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ ChannelId: info.ChannelId, ModelName: modelName, TokenName: tokenName, Quota: quota, Content: logContent, TokenId: info.TokenId, Group: info.UsingGroup, Other: other, }) model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota) model.UpdateChannelUsedQuota(info.ChannelId, quota) } } }() taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) if taskErr != nil { return } info.ConsumeQuota = true // insert task task := model.InitTask(platform, info) task.TaskID = taskID task.Quota = quota task.Data = taskData task.Action = info.Action err = task.Insert() if err != nil { taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError) return } return nil } var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder, } func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) { respBuilder, ok := fetchRespBuilders[relayMode] if !ok { taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest) } respBody, taskErr := respBuilder(c) if taskErr != nil { return taskErr } if len(respBody) == 0 { respBody = []byte("{\"code\":\"success\",\"data\":null}") } c.Writer.Header().Set("Content-Type", "application/json") _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody)) if err != nil { taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) return } return } func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { userId := c.GetInt("id") var condition = struct { IDs []any `json:"ids"` Action string `json:"action"` }{} err := c.BindJSON(&condition) if err != nil { taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest) return } var tasks []any if len(condition.IDs) > 0 { taskModels, err := model.GetByTaskIds(userId, condition.IDs) if err != nil { taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError) return } for _, task := range taskModels { tasks = append(tasks, TaskModel2Dto(task)) } } else { tasks = make([]any, 0) } respBody, err = json.Marshal(dto.TaskResponse[[]any]{ Code: "success", Data: tasks, }) return } func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { taskId := c.Param("id") userId := c.GetInt("id") originTask, exist, err := model.GetByTaskId(userId, taskId) if err != nil { taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError) return } if !exist { taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest) return } respBody, err = json.Marshal(dto.TaskResponse[any]{ Code: "success", Data: TaskModel2Dto(originTask), }) return } func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { taskId := c.Param("task_id") if taskId == "" { taskId = c.GetString("task_id") } userId := c.GetInt("id") originTask, exist, err := model.GetByTaskId(userId, taskId) if err != nil { taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError) return } if !exist { taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest) return } func() { channelModel, err2 := model.GetChannelById(originTask.ChannelId, true) if err2 != nil { return } if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini { return } baseURL := constant.ChannelBaseURLs[channelModel.Type] if channelModel.GetBaseURL() != "" { baseURL = channelModel.GetBaseURL() } proxy := channelModel.GetSetting().Proxy adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type))) if adaptor == nil { return } resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{ "task_id": originTask.TaskID, "action": originTask.Action, }, proxy) if err2 != nil || resp == nil { return } defer resp.Body.Close() body, err2 := io.ReadAll(resp.Body) if err2 != nil { return } ti, err2 := adaptor.ParseTaskResult(body) if err2 == nil && ti != nil { if ti.Status != "" { originTask.Status = model.TaskStatus(ti.Status) } if ti.Progress != "" { originTask.Progress = ti.Progress } if ti.Url != "" { if strings.HasPrefix(ti.Url, "data:") { } else { originTask.FailReason = ti.Url } } _ = originTask.Update() var raw map[string]any _ = json.Unmarshal(body, &raw) format := "mp4" if respObj, ok := raw["response"].(map[string]any); ok { if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 { if v0, ok := vids[0].(map[string]any); ok { if mt, ok := v0["mimeType"].(string); ok && mt != "" { if strings.Contains(mt, "mp4") { format = "mp4" } else { format = mt } } } } } status := "processing" switch originTask.Status { case model.TaskStatusSuccess: status = "succeeded" case model.TaskStatusFailure: status = "failed" case model.TaskStatusQueued, model.TaskStatusSubmitted: status = "queued" } if !strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") { out := map[string]any{ "error": nil, "format": format, "metadata": nil, "status": status, "task_id": originTask.TaskID, "url": originTask.FailReason, } respBody, _ = json.Marshal(dto.TaskResponse[any]{ Code: "success", Data: out, }) } } }() if len(respBody) != 0 { return } if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") { adaptor := GetTaskAdaptor(originTask.Platform) if adaptor == nil { taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest) return } if converter, ok := adaptor.(channel.OpenAIVideoConverter); ok { openAIVideoData, err := converter.ConvertToOpenAIVideo(originTask) if err != nil { taskResp = service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError) return } respBody = openAIVideoData return } taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented) return } respBody, err = json.Marshal(dto.TaskResponse[any]{ Code: "success", Data: TaskModel2Dto(originTask), }) if err != nil { taskResp = service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError) } return } func TaskModel2Dto(task *model.Task) *dto.TaskDto { return &dto.TaskDto{ TaskID: task.TaskID, Action: task.Action, Status: string(task.Status), FailReason: task.FailReason, SubmitTime: task.SubmitTime, StartTime: task.StartTime, FinishTime: task.FinishTime, Progress: task.Progress, Data: task.Data, } }