|
|
@@ -10,7 +10,6 @@ import (
|
|
|
"one-api/common"
|
|
|
"one-api/constant"
|
|
|
"one-api/dto"
|
|
|
- "one-api/logger"
|
|
|
"one-api/model"
|
|
|
relaycommon "one-api/relay/common"
|
|
|
relayconstant "one-api/relay/constant"
|
|
|
@@ -171,13 +170,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|
|
- startTime := time.Now().UnixNano() / int64(time.Millisecond)
|
|
|
- tokenId := c.GetInt("token_id")
|
|
|
- userId := c.GetInt("id")
|
|
|
- //group := c.GetString("group")
|
|
|
- channelId := c.GetInt("channel_id")
|
|
|
- relayInfo := relaycommon.GenRelayInfo(c)
|
|
|
+func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyResponse {
|
|
|
var swapFaceRequest dto.SwapFaceRequest
|
|
|
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
|
|
|
if err != nil {
|
|
|
@@ -188,9 +181,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|
|
}
|
|
|
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
|
|
|
|
|
- priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
|
|
|
+ priceData := helper.ModelPriceHelperPerCall(c, info)
|
|
|
|
|
|
- userQuota, err := model.GetUserQuota(userId, false)
|
|
|
+ userQuota, err := model.GetUserQuota(info.UserId, false)
|
|
|
if err != nil {
|
|
|
return &dto.MidjourneyResponse{
|
|
|
Code: 4,
|
|
|
@@ -213,32 +206,31 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|
|
}
|
|
|
defer func() {
|
|
|
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
|
|
- err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
|
|
|
+ err := service.PostConsumeQuota(info, priceData.Quota, 0, true)
|
|
|
if err != nil {
|
|
|
- logger.SysError("error consuming token remain quota: " + err.Error())
|
|
|
+ common.SysLog("error consuming token remain quota: " + err.Error())
|
|
|
}
|
|
|
|
|
|
tokenName := c.GetString("token_name")
|
|
|
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
|
|
|
other := service.GenerateMjOtherInfo(priceData)
|
|
|
- model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
|
|
|
- ChannelId: channelId,
|
|
|
+ model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
|
|
+ ChannelId: info.ChannelId,
|
|
|
ModelName: modelName,
|
|
|
TokenName: tokenName,
|
|
|
Quota: priceData.Quota,
|
|
|
Content: logContent,
|
|
|
- TokenId: tokenId,
|
|
|
- UserQuota: userQuota,
|
|
|
- Group: relayInfo.UsingGroup,
|
|
|
+ TokenId: info.TokenId,
|
|
|
+ Group: info.UsingGroup,
|
|
|
Other: other,
|
|
|
})
|
|
|
- model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
|
|
|
- model.UpdateChannelUsedQuota(channelId, priceData.Quota)
|
|
|
+ model.UpdateUserUsedQuotaAndRequestCount(info.UserId, priceData.Quota)
|
|
|
+ model.UpdateChannelUsedQuota(info.ChannelId, priceData.Quota)
|
|
|
}
|
|
|
}()
|
|
|
midjResponse := &mjResp.Response
|
|
|
midjourneyTask := &model.Midjourney{
|
|
|
- UserId: userId,
|
|
|
+ UserId: info.UserId,
|
|
|
Code: midjResponse.Code,
|
|
|
Action: constant.MjActionSwapFace,
|
|
|
MjId: midjResponse.Result,
|
|
|
@@ -246,7 +238,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|
|
PromptEn: "",
|
|
|
Description: midjResponse.Description,
|
|
|
State: "",
|
|
|
- SubmitTime: startTime,
|
|
|
+ SubmitTime: info.StartTime.UnixNano() / int64(time.Millisecond),
|
|
|
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
|
|
|
FinishTime: 0,
|
|
|
ImageUrl: "",
|
|
|
@@ -370,14 +362,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
|
|
-
|
|
|
- //tokenId := c.GetInt("token_id")
|
|
|
- //channelType := c.GetInt("channel")
|
|
|
- userId := c.GetInt("id")
|
|
|
- group := c.GetString("group")
|
|
|
- channelId := c.GetInt("channel_id")
|
|
|
- relayInfo := relaycommon.GenRelayInfo(c)
|
|
|
+func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.MidjourneyResponse {
|
|
|
consumeQuota := true
|
|
|
var midjRequest dto.MidjourneyRequest
|
|
|
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
|
|
@@ -385,35 +370,35 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
|
|
|
}
|
|
|
|
|
|
- if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
|
|
+ if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
|
|
mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
|
|
|
if mjErr != nil {
|
|
|
return mjErr
|
|
|
}
|
|
|
- relayMode = relayconstant.RelayModeMidjourneyChange
|
|
|
+ relayInfo.RelayMode = relayconstant.RelayModeMidjourneyChange
|
|
|
}
|
|
|
- if relayMode == relayconstant.RelayModeMidjourneyVideo {
|
|
|
+ if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo {
|
|
|
midjRequest.Action = constant.MjActionVideo
|
|
|
}
|
|
|
|
|
|
- if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
|
|
+ if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
|
|
if midjRequest.Prompt == "" {
|
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
|
|
|
}
|
|
|
midjRequest.Action = constant.MjActionImagine
|
|
|
- } else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
|
|
+ } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
|
|
midjRequest.Action = constant.MjActionDescribe
|
|
|
- } else if relayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复
|
|
|
+ } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复
|
|
|
midjRequest.Action = constant.MjActionEdits
|
|
|
- } else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
|
|
|
+ } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
|
|
|
midjRequest.Action = constant.MjActionShorten
|
|
|
- } else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
|
|
+ } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
|
|
midjRequest.Action = constant.MjActionBlend
|
|
|
- } else if relayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
|
|
|
+ } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
|
|
|
midjRequest.Action = constant.MjActionUpload
|
|
|
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
|
|
mjId := ""
|
|
|
- if relayMode == relayconstant.RelayModeMidjourneyChange {
|
|
|
+ if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyChange {
|
|
|
if midjRequest.TaskId == "" {
|
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
|
|
|
} else if midjRequest.Action == "" {
|
|
|
@@ -423,7 +408,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
}
|
|
|
//action = midjRequest.Action
|
|
|
mjId = midjRequest.TaskId
|
|
|
- } else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
|
|
|
+ } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneySimpleChange {
|
|
|
if midjRequest.Content == "" {
|
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
|
|
|
}
|
|
|
@@ -433,13 +418,13 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
}
|
|
|
mjId = params.TaskId
|
|
|
midjRequest.Action = params.Action
|
|
|
- } else if relayMode == relayconstant.RelayModeMidjourneyModal {
|
|
|
+ } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyModal {
|
|
|
//if midjRequest.MaskBase64 == "" {
|
|
|
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
|
|
|
//}
|
|
|
mjId = midjRequest.TaskId
|
|
|
midjRequest.Action = constant.MjActionModal
|
|
|
- } else if relayMode == relayconstant.RelayModeMidjourneyVideo {
|
|
|
+ } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo {
|
|
|
midjRequest.Action = constant.MjActionVideo
|
|
|
if midjRequest.TaskId == "" {
|
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
|
|
|
@@ -449,12 +434,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
mjId = midjRequest.TaskId
|
|
|
}
|
|
|
|
|
|
- originTask := model.GetByMJId(userId, mjId)
|
|
|
+ originTask := model.GetByMJId(relayInfo.UserId, mjId)
|
|
|
if originTask == nil {
|
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
|
|
|
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
|
|
if setting.MjActionCheckSuccessEnabled {
|
|
|
- if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
|
|
|
+ if originTask.Status != "SUCCESS" && relayInfo.RelayMode != relayconstant.RelayModeMidjourneyModal {
|
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
|
|
}
|
|
|
}
|
|
|
@@ -497,7 +482,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
|
|
|
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
|
|
|
|
|
|
- userQuota, err := model.GetUserQuota(userId, false)
|
|
|
+ userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
|
|
if err != nil {
|
|
|
return &dto.MidjourneyResponse{
|
|
|
Code: 4,
|
|
|
@@ -522,24 +507,23 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
|
|
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
|
|
|
if err != nil {
|
|
|
- logger.SysError("error consuming token remain quota: " + err.Error())
|
|
|
+ common.SysLog("error consuming token remain quota: " + err.Error())
|
|
|
}
|
|
|
tokenName := c.GetString("token_name")
|
|
|
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
|
|
|
other := service.GenerateMjOtherInfo(priceData)
|
|
|
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
|
|
|
- ChannelId: channelId,
|
|
|
+ ChannelId: relayInfo.ChannelId,
|
|
|
ModelName: modelName,
|
|
|
TokenName: tokenName,
|
|
|
Quota: priceData.Quota,
|
|
|
Content: logContent,
|
|
|
TokenId: relayInfo.TokenId,
|
|
|
- UserQuota: userQuota,
|
|
|
- Group: group,
|
|
|
+ Group: relayInfo.UsingGroup,
|
|
|
Other: other,
|
|
|
})
|
|
|
- model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
|
|
|
- model.UpdateChannelUsedQuota(channelId, priceData.Quota)
|
|
|
+ model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, priceData.Quota)
|
|
|
+ model.UpdateChannelUsedQuota(relayInfo.ChannelId, priceData.Quota)
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
@@ -551,7 +535,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
// 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
|
|
|
// other: 提交错误,description为错误描述
|
|
|
midjourneyTask := &model.Midjourney{
|
|
|
- UserId: userId,
|
|
|
+ UserId: relayInfo.UserId,
|
|
|
Code: midjResponse.Code,
|
|
|
Action: midjRequest.Action,
|
|
|
MjId: midjResponse.Result,
|
|
|
@@ -573,7 +557,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
//无实例账号自动禁用渠道(No available account instance)
|
|
|
channel, err := model.GetChannelById(midjourneyTask.ChannelId, true)
|
|
|
if err != nil {
|
|
|
- logger.SysError("get_channel_null: " + err.Error())
|
|
|
+ common.SysLog("get_channel_null: " + err.Error())
|
|
|
}
|
|
|
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
|
|
|
model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")
|