| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441 |
- package relay
- import (
- "bytes"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "strconv"
- "strings"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/dto"
- "github.com/QuantumNous/new-api/model"
- "github.com/QuantumNous/new-api/relay/channel"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- relayconstant "github.com/QuantumNous/new-api/relay/constant"
- "github.com/QuantumNous/new-api/service"
- "github.com/QuantumNous/new-api/setting/ratio_setting"
- "github.com/gin-gonic/gin"
- )
- /*
- Task 任务通过平台、Action 区分任务
- */
- func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
- info.InitChannelMeta(c)
- // ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields
- if info.TaskRelayInfo == nil {
- info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
- }
- platform := constant.TaskPlatform(c.GetString("platform"))
- if platform == "" {
- platform = GetTaskPlatform(c)
- }
- info.InitChannelMeta(c)
- adaptor := GetTaskAdaptor(platform)
- if adaptor == nil {
- return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
- }
- adaptor.Init(info)
- // get & validate taskRequest 获取并验证文本请求
- taskErr = adaptor.ValidateRequestAndSetAction(c, info)
- if taskErr != nil {
- return
- }
- modelName := info.OriginModelName
- if modelName == "" {
- modelName = service.CoverTaskActionToModelName(platform, info.Action)
- }
- modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
- if !success {
- defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName]
- if !ok {
- modelPrice = 0.1
- } else {
- modelPrice = defaultPrice
- }
- }
- // 预扣
- groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
- var ratio float64
- userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
- if hasUserGroupRatio {
- ratio = modelPrice * userGroupRatio
- } else {
- ratio = modelPrice * groupRatio
- }
- // FIXME: 临时修补,支持任务仅按次计费
- if !common.StringsContains(constant.TaskPricePatches, modelName) {
- if len(info.PriceData.OtherRatios) > 0 {
- for _, ra := range info.PriceData.OtherRatios {
- if 1.0 != ra {
- ratio *= ra
- }
- }
- }
- }
- println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio))
- userQuota, err := model.GetUserQuota(info.UserId, false)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
- return
- }
- quota := int(ratio * common.QuotaPerUnit)
- if userQuota-quota < 0 {
- taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
- return
- }
- if info.OriginTaskID != "" {
- originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
- return
- }
- if !exist {
- taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
- return
- }
- if originTask.ChannelId != info.ChannelId {
- channel, err := model.GetChannelById(originTask.ChannelId, true)
- if err != nil {
- taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
- return
- }
- if channel.Status != common.ChannelStatusEnabled {
- return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
- }
- c.Set("base_url", channel.GetBaseURL())
- c.Set("channel_id", originTask.ChannelId)
- c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
- info.ChannelBaseUrl = channel.GetBaseURL()
- info.ChannelId = originTask.ChannelId
- }
- }
- // build body
- requestBody, err := adaptor.BuildRequestBody(c, info)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
- return
- }
- // do request
- resp, err := adaptor.DoRequest(c, info, requestBody)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
- return
- }
- // handle response
- if resp != nil && resp.StatusCode != http.StatusOK {
- responseBody, _ := io.ReadAll(resp.Body)
- taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
- return
- }
- defer func() {
- // release quota
- if info.ConsumeQuota && taskErr == nil {
- err := service.PostConsumeQuota(info, quota, 0, true)
- if err != nil {
- common.SysLog("error consuming token remain quota: " + err.Error())
- }
- if quota != 0 {
- tokenName := c.GetString("token_name")
- //gRatio := groupRatio
- //if hasUserGroupRatio {
- // gRatio = userGroupRatio
- //}
- logContent := fmt.Sprintf("操作 %s", info.Action)
- // FIXME: 临时修补,支持任务仅按次计费
- if common.StringsContains(constant.TaskPricePatches, modelName) {
- logContent = fmt.Sprintf("%s,按次计费", logContent)
- } else {
- if len(info.PriceData.OtherRatios) > 0 {
- var contents []string
- for key, ra := range info.PriceData.OtherRatios {
- if 1.0 != ra {
- contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
- }
- }
- if len(contents) > 0 {
- logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
- }
- }
- }
- other := make(map[string]interface{})
- if c != nil && c.Request != nil && c.Request.URL != nil {
- other["request_path"] = c.Request.URL.Path
- }
- other["model_price"] = modelPrice
- other["group_ratio"] = groupRatio
- if hasUserGroupRatio {
- other["user_group_ratio"] = userGroupRatio
- }
- model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
- ChannelId: info.ChannelId,
- ModelName: modelName,
- TokenName: tokenName,
- Quota: quota,
- Content: logContent,
- TokenId: info.TokenId,
- Group: info.UsingGroup,
- Other: other,
- })
- model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
- model.UpdateChannelUsedQuota(info.ChannelId, quota)
- }
- }
- }()
- taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
- if taskErr != nil {
- return
- }
- info.ConsumeQuota = true
- // insert task
- task := model.InitTask(platform, info)
- task.TaskID = taskID
- task.Quota = quota
- task.Data = taskData
- task.Action = info.Action
- err = task.Insert()
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
- return
- }
- return nil
- }
- var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
- relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
- relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
- relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
- }
- func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
- respBuilder, ok := fetchRespBuilders[relayMode]
- if !ok {
- taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
- }
- respBody, taskErr := respBuilder(c)
- if taskErr != nil {
- return taskErr
- }
- if len(respBody) == 0 {
- respBody = []byte("{\"code\":\"success\",\"data\":null}")
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
- return
- }
- return
- }
- func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
- userId := c.GetInt("id")
- var condition = struct {
- IDs []any `json:"ids"`
- Action string `json:"action"`
- }{}
- err := c.BindJSON(&condition)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
- return
- }
- var tasks []any
- if len(condition.IDs) > 0 {
- taskModels, err := model.GetByTaskIds(userId, condition.IDs)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
- return
- }
- for _, task := range taskModels {
- tasks = append(tasks, TaskModel2Dto(task))
- }
- } else {
- tasks = make([]any, 0)
- }
- respBody, err = json.Marshal(dto.TaskResponse[[]any]{
- Code: "success",
- Data: tasks,
- })
- return
- }
- func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
- taskId := c.Param("id")
- userId := c.GetInt("id")
- originTask, exist, err := model.GetByTaskId(userId, taskId)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
- return
- }
- if !exist {
- taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
- return
- }
- respBody, err = json.Marshal(dto.TaskResponse[any]{
- Code: "success",
- Data: TaskModel2Dto(originTask),
- })
- return
- }
- func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
- taskId := c.Param("task_id")
- if taskId == "" {
- taskId = c.GetString("task_id")
- }
- userId := c.GetInt("id")
- originTask, exist, err := model.GetByTaskId(userId, taskId)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
- return
- }
- if !exist {
- taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
- return
- }
- func() {
- channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
- if err2 != nil {
- return
- }
- if channelModel.Type != constant.ChannelTypeVertexAi {
- return
- }
- baseURL := constant.ChannelBaseURLs[channelModel.Type]
- if channelModel.GetBaseURL() != "" {
- baseURL = channelModel.GetBaseURL()
- }
- adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
- if adaptor == nil {
- return
- }
- resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
- "task_id": originTask.TaskID,
- "action": originTask.Action,
- })
- if err2 != nil || resp == nil {
- return
- }
- defer resp.Body.Close()
- body, err2 := io.ReadAll(resp.Body)
- if err2 != nil {
- return
- }
- ti, err2 := adaptor.ParseTaskResult(body)
- if err2 == nil && ti != nil {
- if ti.Status != "" {
- originTask.Status = model.TaskStatus(ti.Status)
- }
- if ti.Progress != "" {
- originTask.Progress = ti.Progress
- }
- if ti.Url != "" {
- originTask.FailReason = ti.Url
- }
- _ = originTask.Update()
- var raw map[string]any
- _ = json.Unmarshal(body, &raw)
- format := "mp4"
- if respObj, ok := raw["response"].(map[string]any); ok {
- if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
- if v0, ok := vids[0].(map[string]any); ok {
- if mt, ok := v0["mimeType"].(string); ok && mt != "" {
- if strings.Contains(mt, "mp4") {
- format = "mp4"
- } else {
- format = mt
- }
- }
- }
- }
- }
- status := "processing"
- switch originTask.Status {
- case model.TaskStatusSuccess:
- status = "succeeded"
- case model.TaskStatusFailure:
- status = "failed"
- case model.TaskStatusQueued, model.TaskStatusSubmitted:
- status = "queued"
- }
- out := map[string]any{
- "error": nil,
- "format": format,
- "metadata": nil,
- "status": status,
- "task_id": originTask.TaskID,
- "url": originTask.FailReason,
- }
- respBody, _ = json.Marshal(dto.TaskResponse[any]{
- Code: "success",
- Data: out,
- })
- }
- }()
- if len(respBody) != 0 {
- return
- }
- if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
- adaptor := GetTaskAdaptor(originTask.Platform)
- if adaptor == nil {
- taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
- return
- }
- if converter, ok := adaptor.(channel.OpenAIVideoConverter); ok {
- openAIVideoData, err := converter.ConvertToOpenAIVideo(originTask)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError)
- return
- }
- respBody = openAIVideoData
- return
- }
- taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented)
- return
- }
- respBody, err = json.Marshal(dto.TaskResponse[any]{
- Code: "success",
- Data: TaskModel2Dto(originTask),
- })
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
- }
- return
- }
- func TaskModel2Dto(task *model.Task) *dto.TaskDto {
- return &dto.TaskDto{
- TaskID: task.TaskID,
- Action: task.Action,
- Status: string(task.Status),
- FailReason: task.FailReason,
- SubmitTime: task.SubmitTime,
- StartTime: task.StartTime,
- FinishTime: task.FinishTime,
- Progress: task.Progress,
- Data: task.Data,
- }
- }
|