|
|
@@ -22,31 +22,27 @@ import (
|
|
|
/*
|
|
|
Task 任务通过平台、Action 区分任务
|
|
|
*/
|
|
|
-func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|
|
+func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
|
|
+ info.InitChannelMeta(c)
|
|
|
platform := constant.TaskPlatform(c.GetString("platform"))
|
|
|
if platform == "" {
|
|
|
platform = GetTaskPlatform(c)
|
|
|
}
|
|
|
|
|
|
- relayInfo, err := relaycommon.GenTaskRelayInfo(c)
|
|
|
- if err != nil {
|
|
|
- return service.TaskErrorWrapper(err, "gen_relay_info_failed", http.StatusInternalServerError)
|
|
|
- }
|
|
|
-
|
|
|
adaptor := GetTaskAdaptor(platform)
|
|
|
if adaptor == nil {
|
|
|
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
|
|
}
|
|
|
- adaptor.Init(relayInfo)
|
|
|
+ adaptor.Init(info)
|
|
|
// get & validate taskRequest 获取并验证文本请求
|
|
|
- taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo)
|
|
|
+ taskErr = adaptor.ValidateRequestAndSetAction(c, info)
|
|
|
if taskErr != nil {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- modelName := relayInfo.OriginModelName
|
|
|
+ modelName := info.OriginModelName
|
|
|
if modelName == "" {
|
|
|
- modelName = service.CoverTaskActionToModelName(platform, relayInfo.Action)
|
|
|
+ modelName = service.CoverTaskActionToModelName(platform, info.Action)
|
|
|
}
|
|
|
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
|
|
if !success {
|
|
|
@@ -59,15 +55,15 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|
|
}
|
|
|
|
|
|
// 预扣
|
|
|
- groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
|
|
|
+ groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
|
|
|
var ratio float64
|
|
|
- userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
|
|
|
+ userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
|
|
|
if hasUserGroupRatio {
|
|
|
ratio = modelPrice * userGroupRatio
|
|
|
} else {
|
|
|
ratio = modelPrice * groupRatio
|
|
|
}
|
|
|
- userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
|
|
+ userQuota, err := model.GetUserQuota(info.UserId, false)
|
|
|
if err != nil {
|
|
|
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
|
|
return
|
|
|
@@ -78,8 +74,8 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- if relayInfo.OriginTaskID != "" {
|
|
|
- originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID)
|
|
|
+ if info.OriginTaskID != "" {
|
|
|
+ originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
|
|
|
if err != nil {
|
|
|
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
|
|
return
|
|
|
@@ -88,7 +84,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|
|
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
|
|
return
|
|
|
}
|
|
|
- if originTask.ChannelId != relayInfo.ChannelId {
|
|
|
+ if originTask.ChannelId != info.ChannelId {
|
|
|
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
|
|
if err != nil {
|
|
|
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
|
|
@@ -101,19 +97,19 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|
|
c.Set("channel_id", originTask.ChannelId)
|
|
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
|
|
|
|
|
- relayInfo.ChannelBaseUrl = channel.GetBaseURL()
|
|
|
- relayInfo.ChannelId = originTask.ChannelId
|
|
|
+ info.ChannelBaseUrl = channel.GetBaseURL()
|
|
|
+ info.ChannelId = originTask.ChannelId
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// build body
|
|
|
- requestBody, err := adaptor.BuildRequestBody(c, relayInfo)
|
|
|
+ requestBody, err := adaptor.BuildRequestBody(c, info)
|
|
|
if err != nil {
|
|
|
taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
|
|
|
return
|
|
|
}
|
|
|
// do request
|
|
|
- resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
|
|
+ resp, err := adaptor.DoRequest(c, info, requestBody)
|
|
|
if err != nil {
|
|
|
taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
|
|
return
|
|
|
@@ -127,9 +123,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|
|
|
|
|
defer func() {
|
|
|
// release quota
|
|
|
- if relayInfo.ConsumeQuota && taskErr == nil {
|
|
|
+ if info.ConsumeQuota && taskErr == nil {
|
|
|
|
|
|
- err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
|
|
|
+ err := service.PostConsumeQuota(info, quota, 0, true)
|
|
|
if err != nil {
|
|
|
common.SysLog("error consuming token remain quota: " + err.Error())
|
|
|
}
|
|
|
@@ -139,40 +135,40 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|
|
if hasUserGroupRatio {
|
|
|
gRatio = userGroupRatio
|
|
|
}
|
|
|
- logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, relayInfo.Action)
|
|
|
+ logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, info.Action)
|
|
|
other := make(map[string]interface{})
|
|
|
other["model_price"] = modelPrice
|
|
|
other["group_ratio"] = groupRatio
|
|
|
if hasUserGroupRatio {
|
|
|
other["user_group_ratio"] = userGroupRatio
|
|
|
}
|
|
|
- model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
|
|
|
- ChannelId: relayInfo.ChannelId,
|
|
|
+ model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
|
|
+ ChannelId: info.ChannelId,
|
|
|
ModelName: modelName,
|
|
|
TokenName: tokenName,
|
|
|
Quota: quota,
|
|
|
Content: logContent,
|
|
|
- TokenId: relayInfo.TokenId,
|
|
|
- Group: relayInfo.UsingGroup,
|
|
|
+ TokenId: info.TokenId,
|
|
|
+ Group: info.UsingGroup,
|
|
|
Other: other,
|
|
|
})
|
|
|
- model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
|
|
- model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
|
|
+ model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
|
|
|
+ model.UpdateChannelUsedQuota(info.ChannelId, quota)
|
|
|
}
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
- taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
|
|
|
+ taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
|
|
|
if taskErr != nil {
|
|
|
return
|
|
|
}
|
|
|
- relayInfo.ConsumeQuota = true
|
|
|
+ info.ConsumeQuota = true
|
|
|
// insert task
|
|
|
- task := model.InitTask(platform, relayInfo)
|
|
|
+ task := model.InitTask(platform, info)
|
|
|
task.TaskID = taskID
|
|
|
task.Quota = quota
|
|
|
task.Data = taskData
|
|
|
- task.Action = relayInfo.Action
|
|
|
+ task.Action = info.Action
|
|
|
err = task.Insert()
|
|
|
if err != nil {
|
|
|
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
|