| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290 |
- package relay
- import (
- "bytes"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/model"
- relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
- "one-api/service"
- "one-api/setting/ratio_setting"
- "github.com/gin-gonic/gin"
- )
- /*
- Task 任务通过平台、Action 区分任务
- */
- func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
- platform := constant.TaskPlatform(c.GetString("platform"))
- relayInfo := relaycommon.GenTaskRelayInfo(c)
- adaptor := GetTaskAdaptor(platform)
- if adaptor == nil {
- return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
- }
- adaptor.Init(relayInfo)
- // get & validate taskRequest 获取并验证文本请求
- taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo)
- if taskErr != nil {
- return
- }
- modelName := relayInfo.OriginModelName
- if modelName == "" {
- modelName = service.CoverTaskActionToModelName(platform, relayInfo.Action)
- }
- modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
- if !success {
- defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
- if !ok {
- modelPrice = 0.1
- } else {
- modelPrice = defaultPrice
- }
- }
- // 预扣
- groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
- var ratio float64
- userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
- if hasUserGroupRatio {
- ratio = modelPrice * userGroupRatio
- } else {
- ratio = modelPrice * groupRatio
- }
- userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
- return
- }
- quota := int(ratio * common.QuotaPerUnit)
- if userQuota-quota < 0 {
- taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
- return
- }
- if relayInfo.OriginTaskID != "" {
- originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
- return
- }
- if !exist {
- taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
- return
- }
- if originTask.ChannelId != relayInfo.ChannelId {
- channel, err := model.GetChannelById(originTask.ChannelId, true)
- if err != nil {
- taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
- return
- }
- if channel.Status != common.ChannelStatusEnabled {
- return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
- }
- c.Set("base_url", channel.GetBaseURL())
- c.Set("channel_id", originTask.ChannelId)
- c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
- relayInfo.BaseUrl = channel.GetBaseURL()
- relayInfo.ChannelId = originTask.ChannelId
- }
- }
- // build body
- requestBody, err := adaptor.BuildRequestBody(c, relayInfo)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
- return
- }
- // do request
- resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
- return
- }
- // handle response
- if resp != nil && resp.StatusCode != http.StatusOK {
- responseBody, _ := io.ReadAll(resp.Body)
- taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
- return
- }
- defer func() {
- // release quota
- if relayInfo.ConsumeQuota && taskErr == nil {
- err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
- if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
- }
- if quota != 0 {
- tokenName := c.GetString("token_name")
- gRatio := groupRatio
- if hasUserGroupRatio {
- gRatio = userGroupRatio
- }
- logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, relayInfo.Action)
- other := make(map[string]interface{})
- other["model_price"] = modelPrice
- other["group_ratio"] = groupRatio
- if hasUserGroupRatio {
- other["user_group_ratio"] = userGroupRatio
- }
- model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
- ChannelId: relayInfo.ChannelId,
- ModelName: modelName,
- TokenName: tokenName,
- Quota: quota,
- Content: logContent,
- UserInput: "", // Task任务不记录用户输入
- TokenId: relayInfo.TokenId,
- UserQuota: userQuota,
- Group: relayInfo.UsingGroup,
- Other: other,
- })
- model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
- model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
- }
- }
- }()
- taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
- if taskErr != nil {
- return
- }
- relayInfo.ConsumeQuota = true
- // insert task
- task := model.InitTask(platform, relayInfo)
- task.TaskID = taskID
- task.Quota = quota
- task.Data = taskData
- task.Action = relayInfo.Action
- err = task.Insert()
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
- return
- }
- return nil
- }
- var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
- relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
- relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
- relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder,
- }
- func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
- respBuilder, ok := fetchRespBuilders[relayMode]
- if !ok {
- taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
- }
- respBody, taskErr := respBuilder(c)
- if taskErr != nil {
- return taskErr
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
- return
- }
- return
- }
- func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
- userId := c.GetInt("id")
- var condition = struct {
- IDs []any `json:"ids"`
- Action string `json:"action"`
- }{}
- err := c.BindJSON(&condition)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
- return
- }
- var tasks []any
- if len(condition.IDs) > 0 {
- taskModels, err := model.GetByTaskIds(userId, condition.IDs)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
- return
- }
- for _, task := range taskModels {
- tasks = append(tasks, TaskModel2Dto(task))
- }
- } else {
- tasks = make([]any, 0)
- }
- respBody, err = json.Marshal(dto.TaskResponse[[]any]{
- Code: "success",
- Data: tasks,
- })
- return
- }
- func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
- taskId := c.Param("id")
- userId := c.GetInt("id")
- originTask, exist, err := model.GetByTaskId(userId, taskId)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
- return
- }
- if !exist {
- taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
- return
- }
- respBody, err = json.Marshal(dto.TaskResponse[any]{
- Code: "success",
- Data: TaskModel2Dto(originTask),
- })
- return
- }
- func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
- taskId := c.Param("task_id")
- userId := c.GetInt("id")
- originTask, exist, err := model.GetByTaskId(userId, taskId)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
- return
- }
- if !exist {
- taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
- return
- }
- respBody, err = json.Marshal(dto.TaskResponse[any]{
- Code: "success",
- Data: TaskModel2Dto(originTask),
- })
- return
- }
- func TaskModel2Dto(task *model.Task) *dto.TaskDto {
- return &dto.TaskDto{
- TaskID: task.TaskID,
- Action: task.Action,
- Status: string(task.Status),
- FailReason: task.FailReason,
- SubmitTime: task.SubmitTime,
- StartTime: task.StartTime,
- FinishTime: task.FinishTime,
- Progress: task.Progress,
- Data: task.Data,
- }
- }
|