@@ -44,7 +44,7 @@ func calculateGroupConsumeLevelRatio(usedAmount float64) float64 {
 	return groupConsumeLevelRatio
 }
 
-func getGroupPMRatio(group *model.GroupCache) (float64, float64) {
+func getGroupPMRatio(group model.GroupCache) (float64, float64) {
 	groupRPMRatio := group.RPMRatio
 	if groupRPMRatio <= 0 {
 		groupRPMRatio = 1
@@ -56,7 +56,7 @@ func getGroupPMRatio(group *model.GroupCache) (float64, float64) {
 	return groupRPMRatio, groupTPMRatio
 }
 
-func GetGroupAdjustedModelConfig(group *model.GroupCache, mc model.ModelConfig) model.ModelConfig {
+func GetGroupAdjustedModelConfig(group model.GroupCache, mc model.ModelConfig) model.ModelConfig {
 	if groupModelConfig, ok := group.ModelConfigs[mc.Model]; ok {
 		mc = mc.LoadFromGroupModelConfig(groupModelConfig)
 	}
@@ -96,7 +96,7 @@ func setTpmHeaders(c *gin.Context, tpm, remainingRequests int64) {
 	c.Header(XRateLimitResetTokens, "1m0s")
 }
 
-func UpdateGroupModelRequest(c *gin.Context, group *model.GroupCache, rpm, rps int64) {
+func UpdateGroupModelRequest(c *gin.Context, group model.GroupCache, rpm, rps int64) {
 	if group.Status == model.GroupStatusInternal {
 		return
 	}
@@ -106,7 +106,7 @@ func UpdateGroupModelRequest(c *gin.Context, group *model.GroupCache, rpm, rps i
 	log.Data["group_rps"] = strconv.FormatInt(rps, 10)
 }
 
-func UpdateGroupModelTokensRequest(c *gin.Context, group *model.GroupCache, tpm, tps int64) {
+func UpdateGroupModelTokensRequest(c *gin.Context, group model.GroupCache, tpm, tps int64) {
 	if group.Status == model.GroupStatusInternal {
 		return
 	}
@@ -134,7 +134,7 @@ func UpdateGroupModelTokennameTokensRequest(c *gin.Context, tpm, tps int64) {
 
 func checkGroupModelRPMAndTPM(
 	c *gin.Context,
-	group *model.GroupCache,
+	group model.GroupCache,
 	mc model.ModelConfig,
 	tokenName string,
 ) error {
@@ -226,7 +226,7 @@ func GetGroupBalanceConsumerFromContext(c *gin.Context) *GroupBalanceConsumer {
 
 func GetGroupBalanceConsumer(
 	c *gin.Context,
-	group *model.GroupCache,
+	group model.GroupCache,
 ) (*GroupBalanceConsumer, error) {
 	gbc := GetGroupBalanceConsumerFromContext(c)
 	if gbc != nil {
@@ -243,7 +243,7 @@ func GetGroupBalanceConsumer(
 		}
 	} else {
 		log := GetLogger(c)
-		groupBalance, consumer, err := balance.GetGroupRemainBalance(c.Request.Context(), *group)
+		groupBalance, consumer, err := balance.GetGroupRemainBalance(c.Request.Context(), group)
 		if err != nil {
 			return nil, err
 		}
@@ -267,7 +267,7 @@ const (
 	GroupBalanceNotEnough = "group_balance_not_enough"
 )
 
-func checkGroupBalance(c *gin.Context, group *model.GroupCache) bool {
+func checkGroupBalance(c *gin.Context, group model.GroupCache) bool {
 	gbc, err := GetGroupBalanceConsumer(c, group)
 	if err != nil {
 		if errors.Is(err, balance.ErrNoRealNameUsedAmountLimit) {
@@ -327,44 +327,6 @@ func NewDistribute(mode mode.Mode) gin.HandlerFunc {
 	}
 }
 
-const (
-	AIProxyChannelHeader = "Aiproxy-Channel"
-)
-
-func getChannelFromHeader(
-	header string,
-	mc *model.ModelCaches,
-	availableSet []string,
-	model string,
-) (*model.Channel, error) {
-	channelIDInt, err := strconv.ParseInt(header, 10, 64)
-	if err != nil {
-		return nil, err
-	}
-
-	for _, set := range availableSet {
-		enabledChannels := mc.EnabledModel2ChannelsBySet[set][model]
-		if len(enabledChannels) > 0 {
-			for _, channel := range enabledChannels {
-				if int64(channel.ID) == channelIDInt {
-					return channel, nil
-				}
-			}
-		}
-
-		disabledChannels := mc.DisabledModel2ChannelsBySet[set][model]
-		if len(disabledChannels) > 0 {
-			for _, channel := range disabledChannels {
-				if int64(channel.ID) == channelIDInt {
-					return channel, nil
-				}
-			}
-		}
-	}
-
-	return nil, fmt.Errorf("channel %d not found for model `%s`", channelIDInt, model)
-}
-
 func CheckRelayMode(requestMode, modelMode mode.Mode) bool {
 	if modelMode == mode.Unknown {
 		return true
@@ -377,6 +339,10 @@ func CheckRelayMode(requestMode, modelMode mode.Mode) bool {
 	case mode.ImagesGenerations, mode.ImagesEdits:
 		return modelMode == mode.ImagesGenerations ||
 			modelMode == mode.ImagesEdits
+	case mode.VideoGenerationsJobs, mode.VideoGenerationsGetJobs, mode.VideoGenerationsContent:
+		return modelMode == mode.VideoGenerationsJobs ||
+			modelMode == mode.VideoGenerationsGetJobs ||
+			modelMode == mode.VideoGenerationsContent
 	default:
 		return requestMode == modelMode
 	}
@@ -393,12 +359,13 @@ func distribute(c *gin.Context, mode mode.Mode) {
 	log := GetLogger(c)
 
 	group := GetGroup(c)
+	token := GetToken(c)
 
 	if !checkGroupBalance(c, group) {
 		return
 	}
 
-	requestModel, err := getRequestModel(c, mode)
+	requestModel, err := getRequestModel(c, mode, group.ID, token.ID)
 	if err != nil {
 		AbortLogWithMessage(
 			c,
@@ -432,29 +399,17 @@ func distribute(c *gin.Context, mode mode.Mode) {
 	}
 	c.Set(ModelConfig, mc)
 
-	if channelHeader := c.Request.Header.Get(AIProxyChannelHeader); group.Status == model.GroupStatusInternal &&
-		channelHeader != "" {
-		channel, err := getChannelFromHeader(
-			channelHeader,
-			GetModelCaches(c),
-			group.GetAvailableSets(),
-			requestModel,
+	if !token.ContainsModel(requestModel) {
+		AbortLogWithMessage(
+			c,
+			http.StatusNotFound,
+			fmt.Sprintf(
+				"The model `%s` does not exist or you do not have access to it.",
+				requestModel,
+			),
+			"model_not_found",
 		)
-		if err != nil {
-			AbortLogWithMessage(c, http.StatusBadRequest, err.Error())
-			return
-		}
-		c.Set(Channel, channel)
-	} else {
-		token := GetToken(c)
-		if !token.ContainsModel(requestModel) {
-			AbortLogWithMessage(c,
-				http.StatusNotFound,
-				fmt.Sprintf("The model `%s` does not exist or you do not have access to it.", requestModel),
-				"model_not_found",
-			)
-			return
-		}
+		return
 	}
 
 	user, err := getRequestUser(c, mode)
@@ -481,8 +436,6 @@ func distribute(c *gin.Context, mode mode.Mode) {
 	}
 	c.Set(RequestMetadata, metadata)
 
-	token := GetToken(c)
-
 	if err := checkGroupModelRPMAndTPM(c, group, mc, token.Name); err != nil {
 		errMsg := err.Error()
 		consume.AsyncConsume(
@@ -542,6 +495,18 @@ func GetRequestUser(c *gin.Context) string {
 	return c.GetString(RequestUser)
 }
 
+func GetChannelID(c *gin.Context) int {
+	return c.GetInt(ChannelID)
+}
+
+func GetJobID(c *gin.Context) string {
+	return c.GetString(JobID)
+}
+
+func GetGenerationID(c *gin.Context) string {
+	return c.GetString(GenerationID)
+}
+
 func GetRequestMetadata(c *gin.Context) map[string]string {
 	return c.GetStringMapString(RequestMetadata)
 }
@@ -565,6 +530,8 @@ func NewMetaByContext(c *gin.Context,
 	modelName := GetRequestModel(c)
 	modelConfig := GetModelConfig(c)
 	requestAt := GetRequestAt(c)
+	jobID := GetJobID(c)
+	generationID := GetGenerationID(c)
 
 	opts = append(
 		opts,
@@ -573,6 +540,8 @@ func NewMetaByContext(c *gin.Context,
 		meta.WithGroup(group),
 		meta.WithToken(token),
 		meta.WithEndpoint(c.Request.URL.Path),
+		meta.WithJobID(jobID),
+		meta.WithGenerationID(generationID),
 	)
 
 	return meta.NewMeta(
@@ -585,7 +554,7 @@ func NewMetaByContext(c *gin.Context,
 }
 
 // https://platform.openai.com/docs/api-reference/chat
-func getRequestModel(c *gin.Context, m mode.Mode) (string, error) {
+func getRequestModel(c *gin.Context, m mode.Mode, groupID string, tokenID int) (string, error) {
 	path := c.Request.URL.Path
 	switch {
 	case m == mode.ParsePdf:
@@ -605,6 +574,30 @@ func getRequestModel(c *gin.Context, m mode.Mode) (string, error) {
 		// /engines/:model/embeddings
 		return c.Param("model"), nil
 
+	case m == mode.VideoGenerationsGetJobs:
+		jobID := c.Param("id")
+		store, err := model.CacheGetStore(jobID)
+		if err != nil {
+			return "", fmt.Errorf("get request model failed: %w", err)
+		}
+		if err := validateStoreGroupAndToken(store, groupID, tokenID); err != nil {
+			return "", fmt.Errorf("validate store group and token failed: %w", err)
+		}
+		c.Set(JobID, store.ID)
+		c.Set(ChannelID, store.ChannelID)
+		return store.Model, nil
+	case m == mode.VideoGenerationsContent:
+		generationID := c.Param("id")
+		store, err := model.CacheGetStore(generationID)
+		if err != nil {
+			return "", fmt.Errorf("get request model failed: %w", err)
+		}
+		if err := validateStoreGroupAndToken(store, groupID, tokenID); err != nil {
+			return "", fmt.Errorf("validate store group and token failed: %w", err)
+		}
+		c.Set(GenerationID, store.ID)
+		c.Set(ChannelID, store.ChannelID)
+		return store.Model, nil
 	default:
 		body, err := common.GetRequestBody(c.Request)
 		if err != nil {
@@ -614,6 +607,16 @@ func getRequestModel(c *gin.Context, m mode.Mode) (string, error) {
 	}
 }
 
+func validateStoreGroupAndToken(store *model.StoreCache, groupID string, tokenID int) error {
+	if store.GroupID != groupID {
+		return fmt.Errorf("store group id mismatch: %s != %s", store.GroupID, groupID)
+	}
+	if store.TokenID != tokenID {
+		return fmt.Errorf("store token id mismatch: %d != %d", store.TokenID, tokenID)
+	}
+	return nil
+}
+
 func GetModelFromJSON(body []byte) (string, error) {
 	node, err := sonic.GetWithOptions(body, ast.SearchOptions{}, "model")
 	if err != nil {