|
|
@@ -13,9 +13,9 @@ import (
|
|
|
"one-api/model"
|
|
|
relaycommon "one-api/relay/common"
|
|
|
relayconstant "one-api/relay/constant"
|
|
|
+ "one-api/relay/helper"
|
|
|
"one-api/service"
|
|
|
"one-api/setting"
|
|
|
- "one-api/setting/ratio_setting"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
"time"
|
|
|
@@ -174,24 +174,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
|
|
|
}
|
|
|
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
|
|
- modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
|
|
- // 如果没有配置价格,则使用默认价格
|
|
|
- if !success {
|
|
|
- defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
|
|
|
- if !ok {
|
|
|
- modelPrice = 0.1
|
|
|
- } else {
|
|
|
- modelPrice = defaultPrice
|
|
|
- }
|
|
|
- }
|
|
|
- groupRatio := ratio_setting.GetGroupRatio(group)
|
|
|
- var ratio float64
|
|
|
- userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, group)
|
|
|
- if hasUserGroupRatio {
|
|
|
- ratio = modelPrice * userGroupRatio
|
|
|
- } else {
|
|
|
- ratio = modelPrice * groupRatio
|
|
|
- }
|
|
|
+
|
|
|
+ priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
|
|
|
+
|
|
|
userQuota, err := model.GetUserQuota(userId, false)
|
|
|
if err != nil {
|
|
|
return &dto.MidjourneyResponse{
|
|
|
@@ -199,9 +184,8 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|
|
Description: err.Error(),
|
|
|
}
|
|
|
}
|
|
|
- quota := int(ratio * common.QuotaPerUnit)
|
|
|
|
|
|
- if userQuota-quota < 0 {
|
|
|
+ if userQuota-priceData.Quota < 0 {
|
|
|
return &dto.MidjourneyResponse{
|
|
|
Code: 4,
|
|
|
Description: "quota_not_enough",
|
|
|
@@ -216,27 +200,18 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|
|
}
|
|
|
defer func() {
|
|
|
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
|
|
- err := service.PostConsumeQuota(relayInfo, quota, 0, true)
|
|
|
+ err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
|
|
|
if err != nil {
|
|
|
common.SysError("error consuming token remain quota: " + err.Error())
|
|
|
}
|
|
|
- //err = model.CacheUpdateUserQuota(userId)
|
|
|
- // if err != nil {
|
|
|
- // common.SysError("error update user quota cache: " + err.Error())
|
|
|
- // }
|
|
|
- if quota != 0 {
|
|
|
- tokenName := c.GetString("token_name")
|
|
|
- gRatio := groupRatio
|
|
|
- if hasUserGroupRatio {
|
|
|
- gRatio = userGroupRatio
|
|
|
- }
|
|
|
- logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, constant.MjActionSwapFace)
|
|
|
- other := genMjOtherInfo(modelPrice, groupRatio, userGroupRatio, hasUserGroupRatio)
|
|
|
- model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
|
|
- quota, logContent, tokenId, userQuota, 0, false, group, other)
|
|
|
- model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
|
|
- model.UpdateChannelUsedQuota(channelId, quota)
|
|
|
- }
|
|
|
+
|
|
|
+ tokenName := c.GetString("token_name")
|
|
|
+ logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
|
|
|
+ other := service.GenerateMjOtherInfo(priceData)
|
|
|
+ model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
|
|
+ priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other)
|
|
|
+ model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
|
|
|
+ model.UpdateChannelUsedQuota(channelId, priceData.Quota)
|
|
|
}
|
|
|
}()
|
|
|
midjResponse := &mjResp.Response
|
|
|
@@ -257,7 +232,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|
|
Progress: "0%",
|
|
|
FailReason: "",
|
|
|
ChannelId: c.GetInt("channel_id"),
|
|
|
- Quota: quota,
|
|
|
+ Quota: priceData.Quota,
|
|
|
}
|
|
|
err = midjourneyTask.Insert()
|
|
|
if err != nil {
|
|
|
@@ -487,24 +462,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
|
|
|
|
|
modelName := service.CoverActionToModelName(midjRequest.Action)
|
|
|
- modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
|
|
- // 如果没有配置价格,则使用默认价格
|
|
|
- if !success {
|
|
|
- defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
|
|
|
- if !ok {
|
|
|
- modelPrice = 0.1
|
|
|
- } else {
|
|
|
- modelPrice = defaultPrice
|
|
|
- }
|
|
|
- }
|
|
|
- groupRatio := ratio_setting.GetGroupRatio(group)
|
|
|
- var ratio float64
|
|
|
- userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, group)
|
|
|
- if hasUserGroupRatio {
|
|
|
- ratio = modelPrice * userGroupRatio
|
|
|
- } else {
|
|
|
- ratio = modelPrice * groupRatio
|
|
|
- }
|
|
|
+
|
|
|
+ priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
|
|
|
+
|
|
|
userQuota, err := model.GetUserQuota(userId, false)
|
|
|
if err != nil {
|
|
|
return &dto.MidjourneyResponse{
|
|
|
@@ -512,9 +472,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
Description: err.Error(),
|
|
|
}
|
|
|
}
|
|
|
- quota := int(ratio * common.QuotaPerUnit)
|
|
|
|
|
|
- if consumeQuota && userQuota-quota < 0 {
|
|
|
+ if consumeQuota && userQuota-priceData.Quota < 0 {
|
|
|
return &dto.MidjourneyResponse{
|
|
|
Code: 4,
|
|
|
Description: "quota_not_enough",
|
|
|
@@ -529,23 +488,17 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
|
|
|
defer func() {
|
|
|
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
|
|
- err := service.PostConsumeQuota(relayInfo, quota, 0, true)
|
|
|
+ err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
|
|
|
if err != nil {
|
|
|
common.SysError("error consuming token remain quota: " + err.Error())
|
|
|
}
|
|
|
- if quota != 0 {
|
|
|
- tokenName := c.GetString("token_name")
|
|
|
- gRatio := groupRatio
|
|
|
- if hasUserGroupRatio {
|
|
|
- gRatio = userGroupRatio
|
|
|
- }
|
|
|
- logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", modelPrice, gRatio, midjRequest.Action, midjResponse.Result)
|
|
|
- other := genMjOtherInfo(modelPrice, groupRatio, userGroupRatio, hasUserGroupRatio)
|
|
|
- model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
|
|
- quota, logContent, tokenId, userQuota, 0, false, group, other)
|
|
|
- model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
|
|
- model.UpdateChannelUsedQuota(channelId, quota)
|
|
|
- }
|
|
|
+ tokenName := c.GetString("token_name")
|
|
|
+ logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
|
|
|
+ other := service.GenerateMjOtherInfo(priceData)
|
|
|
+ model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
|
|
+ priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other)
|
|
|
+ model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
|
|
|
+ model.UpdateChannelUsedQuota(channelId, priceData.Quota)
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
@@ -573,7 +526,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
Progress: "0%",
|
|
|
FailReason: "",
|
|
|
ChannelId: c.GetInt("channel_id"),
|
|
|
- Quota: quota,
|
|
|
+ Quota: priceData.Quota,
|
|
|
}
|
|
|
if midjResponse.Code == 3 {
|
|
|
//无实例账号自动禁用渠道(No available account instance)
|
|
|
@@ -673,13 +626,3 @@ func getMjRequestPath(path string) string {
|
|
|
}
|
|
|
return requestURL
|
|
|
}
|
|
|
-
|
|
|
-func genMjOtherInfo(modelPrice, groupRatio, userGroupRatio float64, hasUserGroupRatio bool) map[string]interface{} {
|
|
|
- other := make(map[string]interface{})
|
|
|
- other["model_price"] = modelPrice
|
|
|
- other["group_ratio"] = groupRatio
|
|
|
- if hasUserGroupRatio && userGroupRatio > 0 {
|
|
|
- other["user_group_ratio"] = userGroupRatio
|
|
|
- }
|
|
|
- return other
|
|
|
-}
|