@@ -10,7 +10,6 @@ import (
 	"one-api/model"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
-	"one-api/setting"
 	"one-api/setting/ratio_setting"
 	"one-api/types"
 	"strconv"
@@ -27,14 +26,6 @@ type ModelRequest struct {
 
 func Distribute() func(c *gin.Context) {
 	return func(c *gin.Context) {
-		allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
-		if len(allowIpsMap) != 0 {
-			clientIp := c.ClientIP()
-			if _, ok := allowIpsMap[clientIp]; !ok {
-				abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
-				return
-			}
-		}
 		var channel *model.Channel
 		channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
 		modelRequest, shouldSelectChannel, err := getModelRequest(c)
@@ -42,24 +33,6 @@ func Distribute() func(c *gin.Context) {
 			abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
 			return
 		}
-		userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
-		tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
-		if tokenGroup != "" {
-			// check common.UserUsableGroups[userGroup]
-			if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
-				abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
-				return
-			}
-			// check group in common.GroupRatio
-			if !ratio_setting.ContainsGroupRatio(tokenGroup) {
-				if tokenGroup != "auto" {
-					abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
-					return
-				}
-			}
-			userGroup = tokenGroup
-		}
-		common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
 		if ok {
 			id, err := strconv.Atoi(channelId.(string))
 			if err != nil {
@@ -81,20 +54,19 @@ func Distribute() func(c *gin.Context) {
 			modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
 			if modelLimitEnable {
 				s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
+				if !ok {
+					// token model limit is empty, all models are not allowed
+					abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
+					return
+				}
 				var tokenModelLimit map[string]bool
-				if ok {
-					tokenModelLimit = s.(map[string]bool)
-				} else {
+				tokenModelLimit, ok = s.(map[string]bool)
+				if !ok {
 					tokenModelLimit = map[string]bool{}
 				}
-				if tokenModelLimit != nil {
-					if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
-						abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
-						return
-					}
-				} else {
-					// token model limit is empty, all models are not allowed
-					abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
+				matchName := ratio_setting.FormatMatchingModelName(modelRequest.Model) // match gpts & thinking-*
+				if _, ok := tokenModelLimit[matchName]; !ok {
+					abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
 					return
 				}
 			}
@@ -105,6 +77,7 @@ func Distribute() func(c *gin.Context) {
 				return
 			}
 			var selectGroup string
+			userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
 			channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
 			if err != nil {
 				showGroup := userGroup