Explorar el Código

refactor: 调整模型匹配

Xyfacai hace 5 meses
padre
commit
0c0caad827

+ 0 - 1
constant/context_key.go

@@ -11,7 +11,6 @@ const (
 	ContextKeyTokenKey               ContextKey = "token_key"
 	ContextKeyTokenId                ContextKey = "token_id"
 	ContextKeyTokenGroup             ContextKey = "token_group"
-	ContextKeyTokenAllowIps          ContextKey = "allow_ips"
 	ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
 	ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
 	ContextKeyTokenModelLimit        ContextKey = "token_model_limit"

+ 32 - 1
middleware/auth.go

@@ -4,7 +4,10 @@ import (
 	"fmt"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/model"
+	"one-api/setting"
+	"one-api/setting/ratio_setting"
 	"strconv"
 	"strings"
 
@@ -234,6 +237,16 @@ func TokenAuth() func(c *gin.Context) {
 			abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
 			return
 		}
+
+		allowIpsMap := token.GetIpLimitsMap()
+		if len(allowIpsMap) != 0 {
+			clientIp := c.ClientIP()
+			if _, ok := allowIpsMap[clientIp]; !ok {
+				abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
+				return
+			}
+		}
+
 		userCache, err := model.GetUserCache(token.UserId)
 		if err != nil {
 			abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
@@ -247,6 +260,25 @@ func TokenAuth() func(c *gin.Context) {
 
 		userCache.WriteContext(c)
 
+		userGroup := userCache.Group
+		tokenGroup := token.Group
+		if tokenGroup != "" {
+			// check common.UserUsableGroups[userGroup]
+			if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
+				abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
+				return
+			}
+			// check group in common.GroupRatio
+			if !ratio_setting.ContainsGroupRatio(tokenGroup) {
+				if tokenGroup != "auto" {
+					abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
+					return
+				}
+			}
+			userGroup = tokenGroup
+		}
+		common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
+
 		err = SetupContextForToken(c, token, parts...)
 		if err != nil {
 			return
@@ -273,7 +305,6 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e
 	} else {
 		c.Set("token_model_limit_enabled", false)
 	}
-	c.Set("allow_ips", token.GetIpLimitsMap())
 	c.Set("token_group", token.Group)
 	if len(parts) > 1 {
 		if model.IsAdmin(token.UserId) {

+ 11 - 38
middleware/distributor.go

@@ -10,7 +10,6 @@ import (
 	"one-api/model"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
-	"one-api/setting"
 	"one-api/setting/ratio_setting"
 	"one-api/types"
 	"strconv"
@@ -27,14 +26,6 @@ type ModelRequest struct {
 
 func Distribute() func(c *gin.Context) {
 	return func(c *gin.Context) {
-		allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
-		if len(allowIpsMap) != 0 {
-			clientIp := c.ClientIP()
-			if _, ok := allowIpsMap[clientIp]; !ok {
-				abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
-				return
-			}
-		}
 		var channel *model.Channel
 		channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
 		modelRequest, shouldSelectChannel, err := getModelRequest(c)
@@ -42,24 +33,6 @@ func Distribute() func(c *gin.Context) {
 			abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
 			return
 		}
-		userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
-		tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
-		if tokenGroup != "" {
-			// check common.UserUsableGroups[userGroup]
-			if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
-				abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
-				return
-			}
-			// check group in common.GroupRatio
-			if !ratio_setting.ContainsGroupRatio(tokenGroup) {
-				if tokenGroup != "auto" {
-					abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
-					return
-				}
-			}
-			userGroup = tokenGroup
-		}
-		common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
 		if ok {
 			id, err := strconv.Atoi(channelId.(string))
 			if err != nil {
@@ -81,20 +54,19 @@ func Distribute() func(c *gin.Context) {
 			modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
 			if modelLimitEnable {
 				s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
+				if !ok {
+					// token model limit is empty, all models are not allowed
+					abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
+					return
+				}
 				var tokenModelLimit map[string]bool
-				if ok {
-					tokenModelLimit = s.(map[string]bool)
-				} else {
+				tokenModelLimit, ok = s.(map[string]bool)
+				if !ok {
 					tokenModelLimit = map[string]bool{}
 				}
-				if tokenModelLimit != nil {
-					if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
-						abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
-						return
-					}
-				} else {
-					// token model limit is empty, all models are not allowed
-					abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
+				matchName := ratio_setting.FormatMatchingModelName(modelRequest.Model) // match gpts & thinking-*
+				if _, ok := tokenModelLimit[matchName]; !ok {
+					abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
 					return
 				}
 			}
@@ -105,6 +77,7 @@ func Distribute() func(c *gin.Context) {
 					return
 				}
 				var selectGroup string
+				userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
 				channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
 				if err != nil {
 					showGroup := userGroup

+ 2 - 6
model/channel_cache.go

@@ -7,6 +7,7 @@ import (
 	"one-api/common"
 	"one-api/constant"
 	"one-api/setting"
+	"one-api/setting/ratio_setting"
 	"sort"
 	"strings"
 	"sync"
@@ -128,12 +129,7 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string,
 }
 
 func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
-	if strings.HasPrefix(model, "gpt-4-gizmo") {
-		model = "gpt-4-gizmo-*"
-	}
-	if strings.HasPrefix(model, "gpt-4o-gizmo") {
-		model = "gpt-4o-gizmo-*"
-	}
+	model = ratio_setting.FormatMatchingModelName(model)
 
 	// if memory cache is disabled, get channel directly from database
 	if !common.MemoryCacheEnabled {

+ 20 - 17
setting/ratio_setting/model_ratio.go

@@ -335,12 +335,8 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
 	modelPriceMapMutex.RLock()
 	defer modelPriceMapMutex.RUnlock()
 
-	if strings.HasPrefix(name, "gpt-4-gizmo") {
-		name = "gpt-4-gizmo-*"
-	}
-	if strings.HasPrefix(name, "gpt-4o-gizmo") {
-		name = "gpt-4o-gizmo-*"
-	}
+	name = FormatMatchingModelName(name)
+
 	price, ok := modelPriceMap[name]
 	if !ok {
 		if printErr {
@@ -374,11 +370,8 @@ func GetModelRatio(name string) (float64, bool, string) {
 	modelRatioMapMutex.RLock()
 	defer modelRatioMapMutex.RUnlock()
 
-	name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*")
-	name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*")
-	if strings.HasPrefix(name, "gpt-4-gizmo") {
-		name = "gpt-4-gizmo-*"
-	}
+	name = FormatMatchingModelName(name)
+
 	ratio, ok := modelRatioMap[name]
 	if !ok {
 		return 37.5, operation_setting.SelfUseModeEnabled, name
@@ -429,12 +422,9 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
 func GetCompletionRatio(name string) float64 {
 	CompletionRatioMutex.RLock()
 	defer CompletionRatioMutex.RUnlock()
-	if strings.HasPrefix(name, "gpt-4-gizmo") {
-		name = "gpt-4-gizmo-*"
-	}
-	if strings.HasPrefix(name, "gpt-4o-gizmo") {
-		name = "gpt-4o-gizmo-*"
-	}
+
+	name = FormatMatchingModelName(name)
+
 	if strings.Contains(name, "/") {
 		if ratio, ok := CompletionRatio[name]; ok {
 			return ratio
@@ -664,3 +654,16 @@ func GetCompletionRatioCopy() map[string]float64 {
 	}
 	return copyMap
 }
+
+// 转换模型名,减少渠道必须配置各种带参数模型
+func FormatMatchingModelName(name string) string {
+	name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*")
+	name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*")
+	if strings.HasPrefix(name, "gpt-4-gizmo") {
+		name = "gpt-4-gizmo-*"
+	}
+	if strings.HasPrefix(name, "gpt-4o-gizmo") {
+		name = "gpt-4o-gizmo-*"
+	}
+	return name
+}