Browse Source

🔧 refactor(auth, channel, context): improve context setup and validation for multi-key channels

CaIon 5 months ago
parent
commit
f0f277dc2a

+ 19 - 5
common/str.go

@@ -31,16 +31,30 @@ func MapToJsonStr(m map[string]interface{}) string {
 	return string(bytes)
 }
 
-func StrToMap(str string) map[string]interface{} {
+func StrToMap(str string) (map[string]interface{}, error) {
 	m := make(map[string]interface{})
-	err := json.Unmarshal([]byte(str), &m)
+	err := UnmarshalJson([]byte(str), &m)
 	if err != nil {
-		return nil
+		return nil, err
 	}
-	return m
+	return m, nil
 }
 
-func IsJsonStr(str string) bool {
+func StrToJsonArray(str string) ([]interface{}, error) {
+	var js []interface{}
+	err := json.Unmarshal([]byte(str), &js)
+	if err != nil {
+		return nil, err
+	}
+	return js, nil
+}
+
+func IsJsonArray(str string) bool {
+	var js []interface{}
+	return json.Unmarshal([]byte(str), &js) == nil
+}
+
+func IsJsonObject(str string) bool {
 	var js map[string]interface{}
 	return json.Unmarshal([]byte(str), &js) == nil
 }

+ 12 - 5
constant/context_key.go

@@ -17,11 +17,18 @@ const (
 	ContextKeyTokenModelLimit        ContextKey = "token_model_limit"
 
 	/* channel related keys */
-	ContextKeyBaseUrl        ContextKey = "base_url"
-	ContextKeyChannelType    ContextKey = "channel_type"
-	ContextKeyChannelId      ContextKey = "channel_id"
-	ContextKeyChannelSetting ContextKey = "channel_setting"
-	ContextKeyParamOverride  ContextKey = "param_override"
+	ContextKeyChannelId                ContextKey = "channel_id"
+	ContextKeyChannelName              ContextKey = "channel_name"
+	ContextKeyChannelCreateTime        ContextKey = "channel_create_name"
+	ContextKeyChannelBaseUrl           ContextKey = "base_url"
+	ContextKeyChannelType              ContextKey = "channel_type"
+	ContextKeyChannelSetting           ContextKey = "channel_setting"
+	ContextKeyChannelParamOverride     ContextKey = "param_override"
+	ContextKeyChannelOrganization      ContextKey = "channel_organization"
+	ContextKeyChannelAutoBan           ContextKey = "auto_ban"
+	ContextKeyChannelModelMapping      ContextKey = "model_mapping"
+	ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
+	ContextKeyChannelIsMultiKey        ContextKey = "channel_is_multi_key"
 
 	/* user related keys */
 	ContextKeyUserId      ContextKey = "id"

+ 8 - 0
constant/multi_key_mode.go

@@ -0,0 +1,8 @@
+package constant
+
+type MultiKeyMode string
+
+const (
+	MultiKeyModeRandom  MultiKeyMode = "random"  // 随机
+	MultiKeyModePolling MultiKeyMode = "polling" // 轮询
+)

+ 63 - 42
controller/channel.go

@@ -378,8 +378,31 @@ func GetChannel(c *gin.Context) {
 }
 
 type AddChannelRequest struct {
-	Mode    string         `json:"mode"`
-	Channel *model.Channel `json:"channel"`
+	Mode         string                `json:"mode"`
+	MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
+	Channel      *model.Channel        `json:"channel"`
+}
+
+func getVertexArrayKeys(keys string) ([]string, error) {
+	if keys == "" {
+		return nil, nil
+	}
+	var keyArray []interface{}
+	err := common.UnmarshalJson([]byte(keys), &keyArray)
+	if err != nil {
+		return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
+	}
+	cleanKeys := make([]string, 0, len(keyArray))
+	for _, key := range keyArray {
+		keyStr := fmt.Sprintf("%v", key)
+		if keyStr != "" {
+			cleanKeys = append(cleanKeys, strings.TrimSpace(keyStr))
+		}
+	}
+	if len(cleanKeys) == 0 {
+		return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
+	}
+	return cleanKeys, nil
 }
 
 func AddChannel(c *gin.Context) {
@@ -418,16 +441,20 @@ func AddChannel(c *gin.Context) {
 			})
 			return
 		} else {
-			if common.IsJsonStr(addChannelRequest.Channel.Other) {
-				// must have default
-				regionMap := common.StrToMap(addChannelRequest.Channel.Other)
-				if regionMap["default"] == nil {
-					c.JSON(http.StatusOK, gin.H{
-						"success": false,
-						"message": "部署地区必须包含default字段",
-					})
-					return
-				}
+			regionMap, err := common.StrToMap(addChannelRequest.Channel.Other)
+			if err != nil {
+				c.JSON(http.StatusOK, gin.H{
+					"success": false,
+					"message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
+				})
+				return
+			}
+			if regionMap["default"] == nil {
+				c.JSON(http.StatusOK, gin.H{
+					"success": false,
+					"message": "部署地区必须包含default字段",
+				})
+				return
 			}
 		}
 	}
@@ -436,51 +463,41 @@ func AddChannel(c *gin.Context) {
 	keys := make([]string, 0)
 	switch addChannelRequest.Mode {
 	case "multi_to_single":
-		addChannelRequest.Channel.ChannelInfo.MultiKeyMode = true
+		addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
+		addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
 		if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
-			if !common.IsJsonStr(addChannelRequest.Channel.Key) {
+			array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
+			if err != nil {
 				c.JSON(http.StatusOK, gin.H{
 					"success": false,
-					"message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入",
+					"message": err.Error(),
 				})
 				return
 			}
-			toMap := common.StrToMap(addChannelRequest.Channel.Key)
-			if toMap != nil {
-				addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(toMap)
-			} else {
-				addChannelRequest.Channel.ChannelInfo.MultiKeySize = 0
-			}
+			addChannelRequest.Channel.Key = strings.Join(array, "\n")
 		} else {
 			cleanKeys := make([]string, 0)
 			for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
 				if key == "" {
 					continue
 				}
+				key = strings.TrimSpace(key)
 				cleanKeys = append(cleanKeys, key)
 			}
-			addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
 			addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
 		}
 		keys = []string{addChannelRequest.Channel.Key}
 	case "batch":
 		if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
 			// multi json
-			toMap := common.StrToMap(addChannelRequest.Channel.Key)
-			if toMap == nil {
+			keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
+			if err != nil {
 				c.JSON(http.StatusOK, gin.H{
 					"success": false,
-					"message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入",
+					"message": err.Error(),
 				})
 				return
 			}
-			keys = make([]string, 0, len(toMap))
-			for k := range toMap {
-				if k == "" {
-					continue
-				}
-				keys = append(keys, k)
-			}
 		} else {
 			keys = strings.Split(addChannelRequest.Channel.Key, "\n")
 		}
@@ -694,16 +711,20 @@ func UpdateChannel(c *gin.Context) {
 			})
 			return
 		} else {
-			if common.IsJsonStr(channel.Other) {
-				// must have default
-				regionMap := common.StrToMap(channel.Other)
-				if regionMap["default"] == nil {
-					c.JSON(http.StatusOK, gin.H{
-						"success": false,
-						"message": "部署地区必须包含default字段",
-					})
-					return
-				}
+			regionMap, err := common.StrToMap(channel.Other)
+			if err != nil {
+				c.JSON(http.StatusOK, gin.H{
+					"success": false,
+					"message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
+				})
+				return
+			}
+			if regionMap["default"] == nil {
+				c.JSON(http.StatusOK, gin.H{
+					"success": false,
+					"message": "部署地区必须包含default字段",
+				})
+				return
 			}
 		}
 	}

+ 12 - 6
controller/playground.go

@@ -57,18 +57,24 @@ func Playground(c *gin.Context) {
 		}
 		c.Set("group", group)
 	}
-	c.Set("token_name", "playground-"+group)
-	channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
+
+	userId := c.GetInt("id")
+	//c.Set("token_name", "playground-"+group)
+	tempToken := &model.Token{
+		UserId: userId,
+		Name:   fmt.Sprintf("playground-%s", group),
+		Group:  group,
+	}
+	_ = middleware.SetupContextForToken(c, tempToken)
+	_, err = getChannel(c, group, playgroundRequest.Model, 0)
 	if err != nil {
-		message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
-		openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
+		openaiErr = service.OpenAIErrorWrapperLocal(err, "get_playground_channel_failed", http.StatusInternalServerError)
 		return
 	}
-	middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
+	//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
 	common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
 
 	// Write user context to ensure acceptUnsetRatio is available
-	userId := c.GetInt("id")
 	userCache, err := model.GetUserCache(userId)
 	if err != nil {
 		openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)

+ 8 - 4
controller/relay.go

@@ -259,9 +259,12 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
 			AutoBan: &autoBanInt,
 		}, nil
 	}
-	channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
+	channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
 	if err != nil {
-		return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
+		if group == "auto" {
+			return nil, errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error()))
+		}
+		return nil, errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error()))
 	}
 	middleware.SetupContextForSelectedChannel(c, channel, originalModel)
 	return channel, nil
@@ -388,9 +391,10 @@ func RelayTask(c *gin.Context) {
 		retryTimes = 0
 	}
 	for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
-		channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
+		channel, err := getChannel(c, group, originalModel, i)
 		if err != nil {
 			common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
+			taskErr = service.TaskErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
 			break
 		}
 		channelId = channel.Id
@@ -398,7 +402,7 @@ func RelayTask(c *gin.Context) {
 		useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
 		c.Set("use_channel", useChannel)
 		common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
-		middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+		//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
 
 		requestBody, err := common.GetRequestBody(c)
 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))

+ 34 - 22
middleware/auth.go

@@ -1,6 +1,7 @@
 package middleware
 
 import (
+	"fmt"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
@@ -233,30 +234,41 @@ func TokenAuth() func(c *gin.Context) {
 
 		userCache.WriteContext(c)
 
-		c.Set("id", token.UserId)
-		c.Set("token_id", token.Id)
-		c.Set("token_key", token.Key)
-		c.Set("token_name", token.Name)
-		c.Set("token_unlimited_quota", token.UnlimitedQuota)
-		if !token.UnlimitedQuota {
-			c.Set("token_quota", token.RemainQuota)
+		err = SetupContextForToken(c, token, parts...)
+		if err != nil {
+			return
 		}
-		if token.ModelLimitsEnabled {
-			c.Set("token_model_limit_enabled", true)
-			c.Set("token_model_limit", token.GetModelLimitsMap())
+		c.Next()
+	}
+}
+
+func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
+	if token == nil {
+		return fmt.Errorf("token is nil")
+	}
+	c.Set("id", token.UserId)
+	c.Set("token_id", token.Id)
+	c.Set("token_key", token.Key)
+	c.Set("token_name", token.Name)
+	c.Set("token_unlimited_quota", token.UnlimitedQuota)
+	if !token.UnlimitedQuota {
+		c.Set("token_quota", token.RemainQuota)
+	}
+	if token.ModelLimitsEnabled {
+		c.Set("token_model_limit_enabled", true)
+		c.Set("token_model_limit", token.GetModelLimitsMap())
+	} else {
+		c.Set("token_model_limit_enabled", false)
+	}
+	c.Set("allow_ips", token.GetIpLimitsMap())
+	c.Set("token_group", token.Group)
+	if len(parts) > 1 {
+		if model.IsAdmin(token.UserId) {
+			c.Set("specific_channel_id", parts[1])
 		} else {
-			c.Set("token_model_limit_enabled", false)
-		}
-		c.Set("allow_ips", token.GetIpLimitsMap())
-		c.Set("token_group", token.Group)
-		if len(parts) > 1 {
-			if model.IsAdmin(token.UserId) {
-				c.Set("specific_channel_id", parts[1])
-			} else {
-				abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
-				return
-			}
+			abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
+			return fmt.Errorf("普通用户不支持指定渠道")
 		}
-		c.Next()
 	}
+	return nil
 }

+ 26 - 12
middleware/distributor.go

@@ -21,6 +21,7 @@ import (
 
 type ModelRequest struct {
 	Model string `json:"model"`
+	Group string `json:"group,omitempty"`
 }
 
 func Distribute() func(c *gin.Context) {
@@ -237,6 +238,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		}
 		c.Set("relay_mode", relayMode)
 	}
+	if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
+		// playground chat completions
+		err = common.UnmarshalBodyReusable(c, &modelRequest)
+		if err != nil {
+			return nil, false, errors.New("无效的请求, " + err.Error())
+		}
+		common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
+	}
 	return &modelRequest, shouldSelectChannel, nil
 }
 
@@ -245,20 +254,25 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 	if channel == nil {
 		return
 	}
-	c.Set("channel_id", channel.Id)
-	c.Set("channel_name", channel.Name)
-	c.Set("channel_type", channel.Type)
-	c.Set("channel_create_time", channel.CreatedTime)
-	c.Set("channel_setting", channel.GetSetting())
-	c.Set("param_override", channel.GetParamOverride())
-	if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
-		c.Set("channel_organization", *channel.OpenAIOrganization)
+	common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
+	common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
+	common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
+	common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
+	common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
+	common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
+	if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
+		common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
+	}
+	common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
+	common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
+	common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
+	if channel.ChannelInfo.IsMultiKey {
+		common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
+
 	}
-	c.Set("auto_ban", channel.GetAutoBan())
-	c.Set("model_mapping", channel.GetModelMapping())
-	c.Set("status_code_mapping", channel.GetStatusCodeMapping())
 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
-	c.Set("base_url", channel.GetBaseURL())
+	common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
+
 	// TODO: api_version统一
 	switch channel.Type {
 	case constant.ChannelTypeAzure:

+ 82 - 6
model/channel.go

@@ -3,7 +3,10 @@ package model
 import (
 	"database/sql/driver"
 	"encoding/json"
+	"fmt"
+	"math/rand"
 	"one-api/common"
+	"one-api/constant"
 	"strings"
 	"sync"
 
@@ -43,20 +46,93 @@ type Channel struct {
 }
 
 type ChannelInfo struct {
-	MultiKeyMode       bool        `json:"multi_key_mode"`        // 是否多Key模式
-	MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
-	MultiKeySize       int         `json:"multi_key_size"`        // 多Key模式下的key数量
+	IsMultiKey           bool                  `json:"is_multi_key"`            // 是否多Key模式
+	MultiKeyStatusList   map[int]int           `json:"multi_key_status_list"`   // key状态列表,key index -> status
+	MultiKeyPollingIndex int                   `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
+	MultiKeyMode         constant.MultiKeyMode `json:"multi_key_mode"`
 }
 
 // Value implements driver.Valuer interface
-func (c ChannelInfo) Value() (driver.Value, error) {
-	return json.Marshal(c)
+func (c *ChannelInfo) Value() (driver.Value, error) {
+	return common.EncodeJson(c)
 }
 
 // Scan implements sql.Scanner interface
 func (c *ChannelInfo) Scan(value interface{}) error {
 	bytesValue, _ := value.([]byte)
-	return json.Unmarshal(bytesValue, c)
+	return common.UnmarshalJson(bytesValue, c)
+}
+
+func (channel *Channel) getKeys() []string {
+	if channel.Key == "" {
+		return []string{}
+	}
+	// use \n to split keys
+	keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
+	return keys
+}
+
+func (channel *Channel) GetNextEnabledKey() (string, error) {
+	// If not in multi-key mode, return the original key string directly.
+	if !channel.ChannelInfo.IsMultiKey {
+		return channel.Key, nil
+	}
+
+	// Obtain all keys (split by \n)
+	keys := channel.getKeys()
+	if len(keys) == 0 {
+		// No keys available, return error, should disable the channel
+		return "", fmt.Errorf("no valid keys in channel")
+	}
+
+	statusList := channel.ChannelInfo.MultiKeyStatusList
+	// helper to get key status, default to enabled when missing
+	getStatus := func(idx int) int {
+		if statusList == nil {
+			return common.ChannelStatusEnabled
+		}
+		if status, ok := statusList[idx]; ok {
+			return status
+		}
+		return common.ChannelStatusEnabled
+	}
+
+	// Collect indexes of enabled keys
+	enabledIdx := make([]int, 0, len(keys))
+	for i := range keys {
+		if getStatus(i) == common.ChannelStatusEnabled {
+			enabledIdx = append(enabledIdx, i)
+		}
+	}
+	// If no specific status list or none enabled, fall back to first key
+	if len(enabledIdx) == 0 {
+		return keys[0], nil
+	}
+
+	switch channel.ChannelInfo.MultiKeyMode {
+	case constant.MultiKeyModeRandom:
+		// Randomly pick one enabled key
+		return keys[enabledIdx[rand.Intn(len(enabledIdx))]], nil
+	case constant.MultiKeyModePolling:
+		// Start from the saved polling index and look for the next enabled key
+		start := channel.ChannelInfo.MultiKeyPollingIndex
+		if start < 0 || start >= len(keys) {
+			start = 0
+		}
+		for i := 0; i < len(keys); i++ {
+			idx := (start + i) % len(keys)
+			if getStatus(idx) == common.ChannelStatusEnabled {
+				// update polling index for next call (point to the next position)
+				channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
+				return keys[idx], nil
+			}
+		}
+		// Fallback – should not happen, but return first enabled key
+		return keys[enabledIdx[0]], nil
+	default:
+		// Unknown mode, default to first enabled key (or original key string)
+		return keys[enabledIdx[0]], nil
+	}
 }
 
 func (channel *Channel) GetModels() []string {

+ 1 - 1
model/log.go

@@ -50,7 +50,7 @@ func formatUserLogs(logs []*Log) {
 	for i := range logs {
 		logs[i].ChannelName = ""
 		var otherMap map[string]interface{}
-		otherMap = common.StrToMap(logs[i].Other)
+		otherMap, _ = common.StrToMap(logs[i].Other)
 		if otherMap != nil {
 			// delete admin
 			delete(otherMap, "admin_info")

+ 14 - 4
model/user.go

@@ -68,11 +68,16 @@ func (user *User) SetAccessToken(token string) {
 	user.AccessToken = &token
 }
 
-func (user *User) GetSetting() map[string]interface{} {
+func (user *User) GetSetting() (map[string]interface{}, error) {
 	if user.Setting == "" {
-		return nil
+		return map[string]interface{}{}, nil
+	}
+	toMap, err := common.StrToMap(user.Setting)
+	if err != nil {
+		common.SysError("failed to convert setting to map: " + err.Error())
+		return nil, fmt.Errorf("failed to convert setting to map")
 	}
-	return common.StrToMap(user.Setting)
+	return toMap, nil
 }
 
 func (user *User) SetSetting(setting map[string]interface{}) {
@@ -651,7 +656,12 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
 		return map[string]interface{}{}, err
 	}
 
-	return common.StrToMap(setting), nil
+	toMap, err := common.StrToMap(setting)
+	if err != nil {
+		common.SysError("failed to convert setting to map: " + err.Error())
+		return nil, fmt.Errorf("failed to convert setting to map")
+	}
+	return toMap, nil
 }
 
 func IncreaseUserQuota(id int, quota int, db bool) (err error) {

+ 6 - 1
model/user_cache.go

@@ -36,7 +36,12 @@ func (user *UserBase) GetSetting() map[string]interface{} {
 	if user.Setting == "" {
 		return nil
 	}
-	return common.StrToMap(user.Setting)
+	toMap, err := common.StrToMap(user.Setting)
+	if err != nil {
+		common.SysError("failed to convert user setting to map: " + err.Error())
+		return nil
+	}
+	return toMap
 }
 
 func (user *UserBase) SetSetting(setting map[string]interface{}) {

+ 5 - 2
relay/channel/vertex/relay-vertex.go

@@ -4,8 +4,11 @@ import "one-api/common"
 
 func GetModelRegion(other string, localModelName string) string {
 	// if other is json string
-	if common.IsJsonStr(other) {
-		m := common.StrToMap(other)
+	if common.IsJsonObject(other) {
+		m, err := common.StrToMap(other)
+		if err != nil {
+			return other // return original if parsing fails
+		}
 		if m[localModelName] != nil {
 			return m[localModelName].(string)
 		} else {

+ 2 - 2
relay/common/relay_info.go

@@ -214,7 +214,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
 	channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
 	channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting)
-	paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride)
+	paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
 
 	tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
 	tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
@@ -231,7 +231,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 		UserEmail:         common.GetContextKeyString(c, constant.ContextKeyUserEmail),
 		isFirstResponse:   true,
 		RelayMode:         relayconstant.Path2RelayMode(c.Request.URL.Path),
-		BaseUrl:           common.GetContextKeyString(c, constant.ContextKeyBaseUrl),
+		BaseUrl:           common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
 		RequestURLPath:    c.Request.URL.String(),
 		ChannelType:       channelType,
 		ChannelId:         channelId,

+ 1 - 1
router/relay-router.go

@@ -20,7 +20,7 @@ func SetRelayRouter(router *gin.Engine) {
 		modelsRouter.GET("/:model", controller.RetrieveModel)
 	}
 	playgroundRouter := router.Group("/pg")
-	playgroundRouter.Use(middleware.UserAuth())
+	playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute())
 	{
 		playgroundRouter.POST("/chat/completions", controller.Playground)
 	}