Procházet zdrojové kódy

✨ feat: refactor environment variable initialization and introduce new constant types for API and context keys

CaIon před 6 měsíci
rodič
revize
7e298f8ad1

+ 71 - 0
common/api_type.go

@@ -0,0 +1,71 @@
+package common
+
+import "one-api/constant"
+
+func ChannelType2APIType(channelType int) (int, bool) {
+	apiType := -1
+	switch channelType {
+	case constant.ChannelTypeOpenAI:
+		apiType = constant.APITypeOpenAI
+	case constant.ChannelTypeAnthropic:
+		apiType = constant.APITypeAnthropic
+	case constant.ChannelTypeBaidu:
+		apiType = constant.APITypeBaidu
+	case constant.ChannelTypePaLM:
+		apiType = constant.APITypePaLM
+	case constant.ChannelTypeZhipu:
+		apiType = constant.APITypeZhipu
+	case constant.ChannelTypeAli:
+		apiType = constant.APITypeAli
+	case constant.ChannelTypeXunfei:
+		apiType = constant.APITypeXunfei
+	case constant.ChannelTypeAIProxyLibrary:
+		apiType = constant.APITypeAIProxyLibrary
+	case constant.ChannelTypeTencent:
+		apiType = constant.APITypeTencent
+	case constant.ChannelTypeGemini:
+		apiType = constant.APITypeGemini
+	case constant.ChannelTypeZhipu_v4:
+		apiType = constant.APITypeZhipuV4
+	case constant.ChannelTypeOllama:
+		apiType = constant.APITypeOllama
+	case constant.ChannelTypePerplexity:
+		apiType = constant.APITypePerplexity
+	case constant.ChannelTypeAws:
+		apiType = constant.APITypeAws
+	case constant.ChannelTypeCohere:
+		apiType = constant.APITypeCohere
+	case constant.ChannelTypeDify:
+		apiType = constant.APITypeDify
+	case constant.ChannelTypeJina:
+		apiType = constant.APITypeJina
+	case constant.ChannelCloudflare:
+		apiType = constant.APITypeCloudflare
+	case constant.ChannelTypeSiliconFlow:
+		apiType = constant.APITypeSiliconFlow
+	case constant.ChannelTypeVertexAi:
+		apiType = constant.APITypeVertexAi
+	case constant.ChannelTypeMistral:
+		apiType = constant.APITypeMistral
+	case constant.ChannelTypeDeepSeek:
+		apiType = constant.APITypeDeepSeek
+	case constant.ChannelTypeMokaAI:
+		apiType = constant.APITypeMokaAI
+	case constant.ChannelTypeVolcEngine:
+		apiType = constant.APITypeVolcEngine
+	case constant.ChannelTypeBaiduV2:
+		apiType = constant.APITypeBaiduV2
+	case constant.ChannelTypeOpenRouter:
+		apiType = constant.APITypeOpenRouter
+	case constant.ChannelTypeXinference:
+		apiType = constant.APITypeXinference
+	case constant.ChannelTypeXai:
+		apiType = constant.APITypeXai
+	case constant.ChannelTypeCoze:
+		apiType = constant.APITypeCoze
+	}
+	if apiType == -1 {
+		return constant.APITypeOpenAI, false
+	}
+	return apiType, true
+}

+ 0 - 108
common/constants.go

@@ -193,111 +193,3 @@ const (
 	ChannelStatusManuallyDisabled = 2 // also don't use 0
 	ChannelStatusAutoDisabled     = 3
 )
-
-const (
-	ChannelTypeUnknown        = 0
-	ChannelTypeOpenAI         = 1
-	ChannelTypeMidjourney     = 2
-	ChannelTypeAzure          = 3
-	ChannelTypeOllama         = 4
-	ChannelTypeMidjourneyPlus = 5
-	ChannelTypeOpenAIMax      = 6
-	ChannelTypeOhMyGPT        = 7
-	ChannelTypeCustom         = 8
-	ChannelTypeAILS           = 9
-	ChannelTypeAIProxy        = 10
-	ChannelTypePaLM           = 11
-	ChannelTypeAPI2GPT        = 12
-	ChannelTypeAIGC2D         = 13
-	ChannelTypeAnthropic      = 14
-	ChannelTypeBaidu          = 15
-	ChannelTypeZhipu          = 16
-	ChannelTypeAli            = 17
-	ChannelTypeXunfei         = 18
-	ChannelType360            = 19
-	ChannelTypeOpenRouter     = 20
-	ChannelTypeAIProxyLibrary = 21
-	ChannelTypeFastGPT        = 22
-	ChannelTypeTencent        = 23
-	ChannelTypeGemini         = 24
-	ChannelTypeMoonshot       = 25
-	ChannelTypeZhipu_v4       = 26
-	ChannelTypePerplexity     = 27
-	ChannelTypeLingYiWanWu    = 31
-	ChannelTypeAws            = 33
-	ChannelTypeCohere         = 34
-	ChannelTypeMiniMax        = 35
-	ChannelTypeSunoAPI        = 36
-	ChannelTypeDify           = 37
-	ChannelTypeJina           = 38
-	ChannelCloudflare         = 39
-	ChannelTypeSiliconFlow    = 40
-	ChannelTypeVertexAi       = 41
-	ChannelTypeMistral        = 42
-	ChannelTypeDeepSeek       = 43
-	ChannelTypeMokaAI         = 44
-	ChannelTypeVolcEngine     = 45
-	ChannelTypeBaiduV2        = 46
-	ChannelTypeXinference     = 47
-	ChannelTypeXai            = 48
-	ChannelTypeCoze           = 49
-	ChannelTypeKling          = 50
-	ChannelTypeJimeng         = 51
-	ChannelTypeDummy          // this one is only for count, do not add any channel after this
-
-)
-
-var ChannelBaseURLs = []string{
-	"",                                    // 0
-	"https://api.openai.com",              // 1
-	"https://oa.api2d.net",                // 2
-	"",                                    // 3
-	"http://localhost:11434",              // 4
-	"https://api.openai-sb.com",           // 5
-	"https://api.openaimax.com",           // 6
-	"https://api.ohmygpt.com",             // 7
-	"",                                    // 8
-	"https://api.caipacity.com",           // 9
-	"https://api.aiproxy.io",              // 10
-	"",                                    // 11
-	"https://api.api2gpt.com",             // 12
-	"https://api.aigc2d.com",              // 13
-	"https://api.anthropic.com",           // 14
-	"https://aip.baidubce.com",            // 15
-	"https://open.bigmodel.cn",            // 16
-	"https://dashscope.aliyuncs.com",      // 17
-	"",                                    // 18
-	"https://api.360.cn",                  // 19
-	"https://openrouter.ai/api",           // 20
-	"https://api.aiproxy.io",              // 21
-	"https://fastgpt.run/api/openapi",     // 22
-	"https://hunyuan.tencentcloudapi.com", //23
-	"https://generativelanguage.googleapis.com", //24
-	"https://api.moonshot.cn",                   //25
-	"https://open.bigmodel.cn",                  //26
-	"https://api.perplexity.ai",                 //27
-	"",                                          //28
-	"",                                          //29
-	"",                                          //30
-	"https://api.lingyiwanwu.com",               //31
-	"",                                          //32
-	"",                                          //33
-	"https://api.cohere.ai",                     //34
-	"https://api.minimax.chat",                  //35
-	"",                                          //36
-	"https://api.dify.ai",                       //37
-	"https://api.jina.ai",                       //38
-	"https://api.cloudflare.com",                //39
-	"https://api.siliconflow.cn",                //40
-	"",                                          //41
-	"https://api.mistral.ai",                    //42
-	"https://api.deepseek.com",                  //43
-	"https://api.moka.ai",                       //44
-	"https://ark.cn-beijing.volces.com",         //45
-	"https://qianfan.baidubce.com",              //46
-	"",                                          //47
-	"https://api.x.ai",                          //48
-	"https://api.coze.cn",                       //49
-	"https://api.klingai.com",                   //50
-	"https://visual.volcengineapi.com",          //51
-}

+ 29 - 0
common/endpoint_type.go

@@ -0,0 +1,29 @@
+package common
+
+import "one-api/constant"
+
+// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点)
+func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType {
+	var endpointTypes []constant.EndpointType
+	switch channelType {
+	case constant.ChannelTypeJina:
+		endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
+	case constant.ChannelTypeAws:
+		fallthrough
+	case constant.ChannelTypeAnthropic:
+		endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI}
+	case constant.ChannelTypeVertexAi:
+		fallthrough
+	case constant.ChannelTypeGemini:
+		endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
+	case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
+		endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
+	default:
+		if IsOpenAIResponseOnlyModel(modelName) {
+			endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
+		} else {
+			endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
+		}
+	}
+	return endpointTypes
+}

+ 34 - 0
common/gin.go

@@ -4,7 +4,9 @@ import (
 	"bytes"
 	"github.com/gin-gonic/gin"
 	"io"
+	"one-api/constant"
 	"strings"
+	"time"
 )
 
 const KeyRequestBody = "key_request_body"
@@ -42,3 +44,35 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 	return nil
 }
+
+func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
+	c.Set(string(key), value)
+}
+
+func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
+	return c.Get(string(key))
+}
+
+func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
+	return c.GetString(string(key))
+}
+
+func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
+	return c.GetInt(string(key))
+}
+
+func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
+	return c.GetBool(string(key))
+}
+
+func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
+	return c.GetStringSlice(string(key))
+}
+
+func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
+	return c.GetStringMap(string(key))
+}
+
+func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
+	return c.GetTime(string(key))
+}

+ 23 - 1
common/init.go

@@ -4,6 +4,7 @@ import (
 	"flag"
 	"fmt"
 	"log"
+	"one-api/constant"
 	"os"
 	"path/filepath"
 	"strconv"
@@ -24,7 +25,7 @@ func printHelp() {
 	fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
 }
 
-func InitCommonEnv() {
+func InitEnv() {
 	flag.Parse()
 
 	if *PrintVersion {
@@ -95,4 +96,25 @@ func InitCommonEnv() {
 	GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
 	GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
 	GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
+
+	initConstantEnv()
+}
+
+func initConstantEnv() {
+	constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120)
+	constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
+	constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
+	// ForceStreamOption 覆盖请求参数,强制返回usage信息
+	constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
+	constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
+	constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
+	constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
+	constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
+	constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
+	constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
+	constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
+	// GenerateDefaultToken 是否生成初始令牌,默认关闭。
+	constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
+	// 是否启用错误日志
+	constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
 }

+ 21 - 0
common/model.go

@@ -0,0 +1,21 @@
+package common
+
+import "strings"
+
+var (
+	// OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses.
+	OpenAIResponseOnlyModels = []string{
+		"o3-pro",
+		"o3-deep-research",
+		"o4-mini-deep-research",
+	}
+)
+
+func IsOpenAIResponseOnlyModel(modelName string) bool {
+	for _, m := range OpenAIResponseOnlyModels {
+		if strings.Contains(m, modelName) {
+			return true
+		}
+	}
+	return false
+}

+ 4 - 0
common/redis.go

@@ -16,6 +16,10 @@ import (
 var RDB *redis.Client
 var RedisEnabled = true
 
+func RedisKeyCacheSeconds() int {
+	return SyncFrequency
+}
+
 // InitRedisClient This function is called after init()
 func InitRedisClient() (err error) {
 	if os.Getenv("REDIS_CONN_STRING") == "" {

+ 26 - 0
constant/README.md

@@ -0,0 +1,26 @@
+# constant 包 (`/constant`)
+
+该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。
+
+## 当前文件
+
+| 文件                   | 说明                                                                  |
+|----------------------|---------------------------------------------------------------------|
+| `azure.go`           | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。         |
+| `cache_key.go`       | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。                                    |
+| `channel_setting.go` | Channel 级别的设置键,如 `proxy`、`force_format` 等。                          |
+| `context_key.go`     | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量(请求时间、Token/Channel/User 相关信息等)。 |
+| `env.go`             | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。                                     |
+| `finish_reason.go`   | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。                           |
+| `midjourney.go`      | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。                            |
+| `setup.go`           | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。                                       |
+| `task.go`            | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。                     |
+| `user_setting.go`    | 用户设置相关键常量以及通知类型(Email/Webhook)等。                                    |
+
+## 使用约定
+
+1. `constant` 包**只能被其他包引用**(import),**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**。
+2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。
+3. 新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。
+
+> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。

+ 34 - 0
constant/api_type.go

@@ -0,0 +1,34 @@
+package constant
+
+const (
+	APITypeOpenAI = iota
+	APITypeAnthropic
+	APITypePaLM
+	APITypeBaidu
+	APITypeZhipu
+	APITypeAli
+	APITypeXunfei
+	APITypeAIProxyLibrary
+	APITypeTencent
+	APITypeGemini
+	APITypeZhipuV4
+	APITypeOllama
+	APITypePerplexity
+	APITypeAws
+	APITypeCohere
+	APITypeDify
+	APITypeJina
+	APITypeCloudflare
+	APITypeSiliconFlow
+	APITypeVertexAi
+	APITypeMistral
+	APITypeDeepSeek
+	APITypeMokaAI
+	APITypeVolcEngine
+	APITypeBaiduV2
+	APITypeOpenRouter
+	APITypeXinference
+	APITypeXai
+	APITypeCoze
+	APITypeDummy // this one is only for count, do not add any channel after this
+)

+ 0 - 7
constant/cache_key.go

@@ -1,12 +1,5 @@
 package constant
 
-import "one-api/common"
-
-// 使用函数来避免初始化顺序带来的赋值问题
-func RedisKeyCacheSeconds() int {
-	return common.SyncFrequency
-}
-
 // Cache keys
 const (
 	UserGroupKeyFmt    = "user_group:%d"

+ 109 - 0
constant/channel.go

@@ -0,0 +1,109 @@
+package constant
+
+const (
+	ChannelTypeUnknown        = 0
+	ChannelTypeOpenAI         = 1
+	ChannelTypeMidjourney     = 2
+	ChannelTypeAzure          = 3
+	ChannelTypeOllama         = 4
+	ChannelTypeMidjourneyPlus = 5
+	ChannelTypeOpenAIMax      = 6
+	ChannelTypeOhMyGPT        = 7
+	ChannelTypeCustom         = 8
+	ChannelTypeAILS           = 9
+	ChannelTypeAIProxy        = 10
+	ChannelTypePaLM           = 11
+	ChannelTypeAPI2GPT        = 12
+	ChannelTypeAIGC2D         = 13
+	ChannelTypeAnthropic      = 14
+	ChannelTypeBaidu          = 15
+	ChannelTypeZhipu          = 16
+	ChannelTypeAli            = 17
+	ChannelTypeXunfei         = 18
+	ChannelType360            = 19
+	ChannelTypeOpenRouter     = 20
+	ChannelTypeAIProxyLibrary = 21
+	ChannelTypeFastGPT        = 22
+	ChannelTypeTencent        = 23
+	ChannelTypeGemini         = 24
+	ChannelTypeMoonshot       = 25
+	ChannelTypeZhipu_v4       = 26
+	ChannelTypePerplexity     = 27
+	ChannelTypeLingYiWanWu    = 31
+	ChannelTypeAws            = 33
+	ChannelTypeCohere         = 34
+	ChannelTypeMiniMax        = 35
+	ChannelTypeSunoAPI        = 36
+	ChannelTypeDify           = 37
+	ChannelTypeJina           = 38
+	ChannelCloudflare         = 39
+	ChannelTypeSiliconFlow    = 40
+	ChannelTypeVertexAi       = 41
+	ChannelTypeMistral        = 42
+	ChannelTypeDeepSeek       = 43
+	ChannelTypeMokaAI         = 44
+	ChannelTypeVolcEngine     = 45
+	ChannelTypeBaiduV2        = 46
+	ChannelTypeXinference     = 47
+	ChannelTypeXai            = 48
+	ChannelTypeCoze           = 49
+	ChannelTypeKling          = 50
+	ChannelTypeJimeng         = 51
+	ChannelTypeDummy          // this one is only for count, do not add any channel after this
+
+)
+
+var ChannelBaseURLs = []string{
+	"",                                    // 0
+	"https://api.openai.com",              // 1
+	"https://oa.api2d.net",                // 2
+	"",                                    // 3
+	"http://localhost:11434",              // 4
+	"https://api.openai-sb.com",           // 5
+	"https://api.openaimax.com",           // 6
+	"https://api.ohmygpt.com",             // 7
+	"",                                    // 8
+	"https://api.caipacity.com",           // 9
+	"https://api.aiproxy.io",              // 10
+	"",                                    // 11
+	"https://api.api2gpt.com",             // 12
+	"https://api.aigc2d.com",              // 13
+	"https://api.anthropic.com",           // 14
+	"https://aip.baidubce.com",            // 15
+	"https://open.bigmodel.cn",            // 16
+	"https://dashscope.aliyuncs.com",      // 17
+	"",                                    // 18
+	"https://api.360.cn",                  // 19
+	"https://openrouter.ai/api",           // 20
+	"https://api.aiproxy.io",              // 21
+	"https://fastgpt.run/api/openapi",     // 22
+	"https://hunyuan.tencentcloudapi.com", //23
+	"https://generativelanguage.googleapis.com", //24
+	"https://api.moonshot.cn",                   //25
+	"https://open.bigmodel.cn",                  //26
+	"https://api.perplexity.ai",                 //27
+	"",                                          //28
+	"",                                          //29
+	"",                                          //30
+	"https://api.lingyiwanwu.com",               //31
+	"",                                          //32
+	"",                                          //33
+	"https://api.cohere.ai",                     //34
+	"https://api.minimax.chat",                  //35
+	"",                                          //36
+	"https://api.dify.ai",                       //37
+	"https://api.jina.ai",                       //38
+	"https://api.cloudflare.com",                //39
+	"https://api.siliconflow.cn",                //40
+	"",                                          //41
+	"https://api.mistral.ai",                    //42
+	"https://api.deepseek.com",                  //43
+	"https://api.moka.ai",                       //44
+	"https://ark.cn-beijing.volces.com",         //45
+	"https://qianfan.baidubce.com",              //46
+	"",                                          //47
+	"https://api.x.ai",                          //48
+	"https://api.coze.cn",                       //49
+	"https://api.klingai.com",                   //50
+	"https://visual.volcengineapi.com",          //51
+}

+ 31 - 7
constant/context_key.go

@@ -1,11 +1,35 @@
 package constant
 
+type ContextKey string
+
 const (
-	ContextKeyRequestStartTime = "request_start_time"
-	ContextKeyUserSetting      = "user_setting"
-	ContextKeyUserQuota        = "user_quota"
-	ContextKeyUserStatus       = "user_status"
-	ContextKeyUserEmail        = "user_email"
-	ContextKeyUserGroup        = "user_group"
-	ContextKeyUsingGroup       = "group"
+	ContextKeyOriginalModel    ContextKey = "original_model"
+	ContextKeyRequestStartTime ContextKey = "request_start_time"
+
+	/* token related keys */
+	ContextKeyTokenUnlimited         ContextKey = "token_unlimited_quota"
+	ContextKeyTokenKey               ContextKey = "token_key"
+	ContextKeyTokenId                ContextKey = "token_id"
+	ContextKeyTokenGroup             ContextKey = "token_group"
+	ContextKeyTokenAllowIps          ContextKey = "allow_ips"
+	ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
+	ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
+	ContextKeyTokenModelLimit        ContextKey = "token_model_limit"
+
+	/* channel related keys */
+	ContextKeyBaseUrl        ContextKey = "base_url"
+	ContextKeyChannelType    ContextKey = "channel_type"
+	ContextKeyChannelId      ContextKey = "channel_id"
+	ContextKeyChannelSetting ContextKey = "channel_setting"
+	ContextKeyParamOverride  ContextKey = "param_override"
+
+	/* user related keys */
+	ContextKeyUserId      ContextKey = "id"
+	ContextKeyUserSetting ContextKey = "user_setting"
+	ContextKeyUserQuota   ContextKey = "user_quota"
+	ContextKeyUserStatus  ContextKey = "user_status"
+	ContextKeyUserEmail   ContextKey = "user_email"
+	ContextKeyUserGroup   ContextKey = "user_group"
+	ContextKeyUsingGroup  ContextKey = "group"
+	ContextKeyUserName    ContextKey = "username"
 )

+ 11 - 0
constant/endpoint_type.go

@@ -0,0 +1,11 @@
+package constant
+
+type EndpointType string
+
+const (
+	EndpointTypeOpenAI         EndpointType = "openai"
+	EndpointTypeOpenAIResponse EndpointType = "openai-response"
+	EndpointTypeAnthropic      EndpointType = "anthropic"
+	EndpointTypeGemini         EndpointType = "gemini"
+	EndpointTypeJinaRerank     EndpointType = "jina-rerank"
+)

+ 0 - 40
constant/env.go

@@ -1,9 +1,5 @@
 package constant
 
-import (
-	"one-api/common"
-)
-
 var StreamingTimeout int
 var DifyDebug bool
 var MaxFileDownloadMB int
@@ -17,39 +13,3 @@ var NotifyLimitCount int
 var NotificationLimitDurationMinute int
 var GenerateDefaultToken bool
 var ErrorLogEnabled bool
-
-//var GeminiModelMap = map[string]string{
-//	"gemini-1.0-pro": "v1",
-//}
-
-func InitEnv() {
-	StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 120)
-	DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
-	MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
-	// ForceStreamOption 覆盖请求参数,强制返回usage信息
-	ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
-	GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
-	GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
-	UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
-	AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
-	GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
-	NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
-	NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
-	// GenerateDefaultToken 是否生成初始令牌,默认关闭。
-	GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
-	// 是否启用错误日志
-	ErrorLogEnabled = common.GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
-
-	//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
-	//if modelVersionMapStr == "" {
-	//	return
-	//}
-	//for _, pair := range strings.Split(modelVersionMapStr, ",") {
-	//	parts := strings.Split(pair, ":")
-	//	if len(parts) == 2 {
-	//		GeminiModelMap[parts[0]] = parts[1]
-	//	} else {
-	//		common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
-	//	}
-	//}
-}

+ 12 - 11
controller/channel-billing.go

@@ -8,6 +8,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/model"
 	"one-api/service"
 	"one-api/setting"
@@ -341,34 +342,34 @@ func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
 }
 
 func updateChannelBalance(channel *model.Channel) (float64, error) {
-	baseURL := common.ChannelBaseURLs[channel.Type]
+	baseURL := constant.ChannelBaseURLs[channel.Type]
 	if channel.GetBaseURL() == "" {
 		channel.BaseURL = &baseURL
 	}
 	switch channel.Type {
-	case common.ChannelTypeOpenAI:
+	case constant.ChannelTypeOpenAI:
 		if channel.GetBaseURL() != "" {
 			baseURL = channel.GetBaseURL()
 		}
-	case common.ChannelTypeAzure:
+	case constant.ChannelTypeAzure:
 		return 0, errors.New("尚未实现")
-	case common.ChannelTypeCustom:
+	case constant.ChannelTypeCustom:
 		baseURL = channel.GetBaseURL()
 	//case common.ChannelTypeOpenAISB:
 	//	return updateChannelOpenAISBBalance(channel)
-	case common.ChannelTypeAIProxy:
+	case constant.ChannelTypeAIProxy:
 		return updateChannelAIProxyBalance(channel)
-	case common.ChannelTypeAPI2GPT:
+	case constant.ChannelTypeAPI2GPT:
 		return updateChannelAPI2GPTBalance(channel)
-	case common.ChannelTypeAIGC2D:
+	case constant.ChannelTypeAIGC2D:
 		return updateChannelAIGC2DBalance(channel)
-	case common.ChannelTypeSiliconFlow:
+	case constant.ChannelTypeSiliconFlow:
 		return updateChannelSiliconFlowBalance(channel)
-	case common.ChannelTypeDeepSeek:
+	case constant.ChannelTypeDeepSeek:
 		return updateChannelDeepSeekBalance(channel)
-	case common.ChannelTypeOpenRouter:
+	case constant.ChannelTypeOpenRouter:
 		return updateChannelOpenRouterBalance(channel)
-	case common.ChannelTypeMoonshot:
+	case constant.ChannelTypeMoonshot:
 		return updateChannelMoonshotBalance(channel)
 	default:
 		return 0, errors.New("尚未实现")

+ 9 - 9
controller/channel-test.go

@@ -11,12 +11,12 @@ import (
 	"net/http/httptest"
 	"net/url"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/middleware"
 	"one-api/model"
 	"one-api/relay"
 	relaycommon "one-api/relay/common"
-	"one-api/relay/constant"
 	"one-api/relay/helper"
 	"one-api/service"
 	"strconv"
@@ -31,19 +31,19 @@ import (
 
 func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
 	tik := time.Now()
-	if channel.Type == common.ChannelTypeMidjourney {
+	if channel.Type == constant.ChannelTypeMidjourney {
 		return errors.New("midjourney channel test is not supported"), nil
 	}
-	if channel.Type == common.ChannelTypeMidjourneyPlus {
-		return errors.New("midjourney plus channel test is not supported!!!"), nil
+	if channel.Type == constant.ChannelTypeMidjourneyPlus {
+		return errors.New("midjourney plus channel test is not supported"), nil
 	}
-	if channel.Type == common.ChannelTypeSunoAPI {
+	if channel.Type == constant.ChannelTypeSunoAPI {
 		return errors.New("suno channel test is not supported"), nil
 	}
-	if channel.Type == common.ChannelTypeKling {
+	if channel.Type == constant.ChannelTypeKling {
 		return errors.New("kling channel test is not supported"), nil
 	}
-	if channel.Type == common.ChannelTypeJimeng {
+	if channel.Type == constant.ChannelTypeJimeng {
 		return errors.New("jimeng channel test is not supported"), nil
 	}
 	w := httptest.NewRecorder()
@@ -56,7 +56,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 		strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
 		strings.Contains(testModel, "bge-") || // bge 系列模型
 		strings.Contains(testModel, "embed") ||
-		channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
+		channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
 		requestPath = "/v1/embeddings" // 修改请求路径
 	}
 
@@ -102,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	}
 	testModel = info.UpstreamModelName
 
-	apiType, _ := constant.ChannelType2APIType(channel.Type)
+	apiType, _ := common.ChannelType2APIType(channel.Type)
 	adaptor := relay.GetAdaptor(apiType)
 	if adaptor == nil {
 		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil

+ 9 - 8
controller/channel.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/model"
 	"strconv"
 	"strings"
@@ -125,7 +126,7 @@ func GetAllChannels(c *gin.Context) {
 			order = "id desc"
 		}
 
-		err := baseQuery.Order(order).Limit(pageSize).Offset((p-1)*pageSize).Omit("key").Find(&channelData).Error
+		err := baseQuery.Order(order).Limit(pageSize).Offset((p - 1) * pageSize).Omit("key").Find(&channelData).Error
 		if err != nil {
 			c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
 			return
@@ -181,15 +182,15 @@ func FetchUpstreamModels(c *gin.Context) {
 		return
 	}
 
-	baseURL := common.ChannelBaseURLs[channel.Type]
+	baseURL := constant.ChannelBaseURLs[channel.Type]
 	if channel.GetBaseURL() != "" {
 		baseURL = channel.GetBaseURL()
 	}
 	url := fmt.Sprintf("%s/v1/models", baseURL)
 	switch channel.Type {
-	case common.ChannelTypeGemini:
+	case constant.ChannelTypeGemini:
 		url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
-	case common.ChannelTypeAli:
+	case constant.ChannelTypeAli:
 		url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
 	}
 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
@@ -213,7 +214,7 @@ func FetchUpstreamModels(c *gin.Context) {
 	var ids []string
 	for _, model := range result.Data {
 		id := model.ID
-		if channel.Type == common.ChannelTypeGemini {
+		if channel.Type == constant.ChannelTypeGemini {
 			id = strings.TrimPrefix(id, "models/")
 		}
 		ids = append(ids, id)
@@ -388,7 +389,7 @@ func AddChannel(c *gin.Context) {
 	}
 	channel.CreatedTime = common.GetTimestamp()
 	keys := strings.Split(channel.Key, "\n")
-	if channel.Type == common.ChannelTypeVertexAi {
+	if channel.Type == constant.ChannelTypeVertexAi {
 		if channel.Other == "" {
 			c.JSON(http.StatusOK, gin.H{
 				"success": false,
@@ -613,7 +614,7 @@ func UpdateChannel(c *gin.Context) {
 		})
 		return
 	}
-	if channel.Type == common.ChannelTypeVertexAi {
+	if channel.Type == constant.ChannelTypeVertexAi {
 		if channel.Other == "" {
 			c.JSON(http.StatusOK, gin.H{
 				"success": false,
@@ -668,7 +669,7 @@ func FetchModels(c *gin.Context) {
 
 	baseURL := req.BaseURL
 	if baseURL == "" {
-		baseURL = common.ChannelBaseURLs[req.Type]
+		baseURL = constant.ChannelBaseURLs[req.Type]
 	}
 
 	client := &http.Client{}

+ 52 - 95
controller/model.go

@@ -2,6 +2,7 @@ package controller
 
 import (
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"github.com/samber/lo"
 	"net/http"
 	"one-api/common"
@@ -14,10 +15,7 @@ import (
 	"one-api/relay/channel/minimax"
 	"one-api/relay/channel/moonshot"
 	relaycommon "one-api/relay/common"
-	relayconstant "one-api/relay/constant"
 	"one-api/setting"
-
-	"github.com/gin-gonic/gin"
 )
 
 // https://platform.openai.com/docs/api-reference/models/list
@@ -26,30 +24,10 @@ var openAIModels []dto.OpenAIModels
 var openAIModelsMap map[string]dto.OpenAIModels
 var channelId2Models map[int][]string
 
-func getPermission() []dto.OpenAIModelPermission {
-	var permission []dto.OpenAIModelPermission
-	permission = append(permission, dto.OpenAIModelPermission{
-		Id:                 "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
-		Object:             "model_permission",
-		Created:            1626777600,
-		AllowCreateEngine:  true,
-		AllowSampling:      true,
-		AllowLogprobs:      true,
-		AllowSearchIndices: false,
-		AllowView:          true,
-		AllowFineTuning:    false,
-		Organization:       "*",
-		Group:              nil,
-		IsBlocking:         false,
-	})
-	return permission
-}
-
 func init() {
 	// https://platform.openai.com/docs/models/model-endpoint-compatibility
-	permission := getPermission()
-	for i := 0; i < relayconstant.APITypeDummy; i++ {
-		if i == relayconstant.APITypeAIProxyLibrary {
+	for i := 0; i < constant.APITypeDummy; i++ {
+		if i == constant.APITypeAIProxyLibrary {
 			continue
 		}
 		adaptor := relay.GetAdaptor(i)
@@ -57,69 +35,51 @@ func init() {
 		modelNames := adaptor.GetModelList()
 		for _, modelName := range modelNames {
 			openAIModels = append(openAIModels, dto.OpenAIModels{
-				Id:         modelName,
-				Object:     "model",
-				Created:    1626777600,
-				OwnedBy:    channelName,
-				Permission: permission,
-				Root:       modelName,
-				Parent:     nil,
+				Id:      modelName,
+				Object:  "model",
+				Created: 1626777600,
+				OwnedBy: channelName,
 			})
 		}
 	}
 	for _, modelName := range ai360.ModelList {
 		openAIModels = append(openAIModels, dto.OpenAIModels{
-			Id:         modelName,
-			Object:     "model",
-			Created:    1626777600,
-			OwnedBy:    ai360.ChannelName,
-			Permission: permission,
-			Root:       modelName,
-			Parent:     nil,
+			Id:      modelName,
+			Object:  "model",
+			Created: 1626777600,
+			OwnedBy: ai360.ChannelName,
 		})
 	}
 	for _, modelName := range moonshot.ModelList {
 		openAIModels = append(openAIModels, dto.OpenAIModels{
-			Id:         modelName,
-			Object:     "model",
-			Created:    1626777600,
-			OwnedBy:    moonshot.ChannelName,
-			Permission: permission,
-			Root:       modelName,
-			Parent:     nil,
+			Id:      modelName,
+			Object:  "model",
+			Created: 1626777600,
+			OwnedBy: moonshot.ChannelName,
 		})
 	}
 	for _, modelName := range lingyiwanwu.ModelList {
 		openAIModels = append(openAIModels, dto.OpenAIModels{
-			Id:         modelName,
-			Object:     "model",
-			Created:    1626777600,
-			OwnedBy:    lingyiwanwu.ChannelName,
-			Permission: permission,
-			Root:       modelName,
-			Parent:     nil,
+			Id:      modelName,
+			Object:  "model",
+			Created: 1626777600,
+			OwnedBy: lingyiwanwu.ChannelName,
 		})
 	}
 	for _, modelName := range minimax.ModelList {
 		openAIModels = append(openAIModels, dto.OpenAIModels{
-			Id:         modelName,
-			Object:     "model",
-			Created:    1626777600,
-			OwnedBy:    minimax.ChannelName,
-			Permission: permission,
-			Root:       modelName,
-			Parent:     nil,
+			Id:      modelName,
+			Object:  "model",
+			Created: 1626777600,
+			OwnedBy: minimax.ChannelName,
 		})
 	}
 	for modelName, _ := range constant.MidjourneyModel2Action {
 		openAIModels = append(openAIModels, dto.OpenAIModels{
-			Id:         modelName,
-			Object:     "model",
-			Created:    1626777600,
-			OwnedBy:    "midjourney",
-			Permission: permission,
-			Root:       modelName,
-			Parent:     nil,
+			Id:      modelName,
+			Object:  "model",
+			Created: 1626777600,
+			OwnedBy: "midjourney",
 		})
 	}
 	openAIModelsMap = make(map[string]dto.OpenAIModels)
@@ -127,9 +87,9 @@ func init() {
 		openAIModelsMap[aiModel.Id] = aiModel
 	}
 	channelId2Models = make(map[int][]string)
-	for i := 1; i <= common.ChannelTypeDummy; i++ {
-		apiType, success := relayconstant.ChannelType2APIType(i)
-		if !success || apiType == relayconstant.APITypeAIProxyLibrary {
+	for i := 1; i <= constant.ChannelTypeDummy; i++ {
+		apiType, success := common.ChannelType2APIType(i)
+		if !success || apiType == constant.APITypeAIProxyLibrary {
 			continue
 		}
 		meta := &relaycommon.RelayInfo{ChannelType: i}
@@ -144,11 +104,10 @@ func init() {
 
 func ListModels(c *gin.Context) {
 	userOpenAiModels := make([]dto.OpenAIModels, 0)
-	permission := getPermission()
 
-	modelLimitEnable := c.GetBool("token_model_limit_enabled")
+	modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
 	if modelLimitEnable {
-		s, ok := c.Get("token_model_limit")
+		s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
 		var tokenModelLimit map[string]bool
 		if ok {
 			tokenModelLimit = s.(map[string]bool)
@@ -156,17 +115,16 @@ func ListModels(c *gin.Context) {
 			tokenModelLimit = map[string]bool{}
 		}
 		for allowModel, _ := range tokenModelLimit {
-			if _, ok := openAIModelsMap[allowModel]; ok {
-				userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
+			if oaiModel, ok := openAIModelsMap[allowModel]; ok {
+				oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
+				userOpenAiModels = append(userOpenAiModels, oaiModel)
 			} else {
 				userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
-					Id:         allowModel,
-					Object:     "model",
-					Created:    1626777600,
-					OwnedBy:    "custom",
-					Permission: permission,
-					Root:       allowModel,
-					Parent:     nil,
+					Id:                     allowModel,
+					Object:                 "model",
+					Created:                1626777600,
+					OwnedBy:                "custom",
+					SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
 				})
 			}
 		}
@@ -181,14 +139,14 @@ func ListModels(c *gin.Context) {
 			return
 		}
 		group := userGroup
-		tokenGroup := c.GetString("token_group")
+		tokenGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
 		if tokenGroup != "" {
 			group = tokenGroup
 		}
 		var models []string
 		if tokenGroup == "auto" {
 			for _, autoGroup := range setting.AutoGroups {
-				groupModels := model.GetGroupModels(autoGroup)
+				groupModels := model.GetGroupEnabledModels(autoGroup)
 				for _, g := range groupModels {
 					if !common.StringsContains(models, g) {
 						models = append(models, g)
@@ -196,20 +154,19 @@ func ListModels(c *gin.Context) {
 				}
 			}
 		} else {
-			models = model.GetGroupModels(group)
+			models = model.GetGroupEnabledModels(group)
 		}
-		for _, s := range models {
-			if _, ok := openAIModelsMap[s]; ok {
-				userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
+		for _, modelName := range models {
+			if oaiModel, ok := openAIModelsMap[modelName]; ok {
+				oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
+				userOpenAiModels = append(userOpenAiModels, oaiModel)
 			} else {
 				userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
-					Id:         s,
-					Object:     "model",
-					Created:    1626777600,
-					OwnedBy:    "custom",
-					Permission: permission,
-					Root:       s,
-					Parent:     nil,
+					Id:                     modelName,
+					Object:                 "model",
+					Created:                1626777600,
+					OwnedBy:                "custom",
+					SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
 				})
 			}
 		}

+ 1 - 1
controller/playground.go

@@ -65,7 +65,7 @@ func Playground(c *gin.Context) {
 		return
 	}
 	middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
-	c.Set(constant.ContextKeyRequestStartTime, time.Now())
+	common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
 
 	// Write user context to ensure acceptUnsetRatio is available
 	userId := c.GetInt("id")

+ 4 - 4
controller/relay.go

@@ -8,12 +8,12 @@ import (
 	"log"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	constant2 "one-api/constant"
 	"one-api/dto"
 	"one-api/middleware"
 	"one-api/model"
 	"one-api/relay"
-	"one-api/relay/constant"
 	relayconstant "one-api/relay/constant"
 	"one-api/relay/helper"
 	"one-api/service"
@@ -69,7 +69,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 }
 
 func Relay(c *gin.Context) {
-	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+	relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
 	requestId := c.GetString(common.RequestIdKey)
 	group := c.GetString("group")
 	originalModel := c.GetString("original_model")
@@ -132,7 +132,7 @@ func WssRelay(c *gin.Context) {
 		return
 	}
 
-	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+	relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
 	requestId := c.GetString(common.RequestIdKey)
 	group := c.GetString("group")
 	//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
@@ -295,7 +295,7 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
 	}
 	if openaiErr.StatusCode == http.StatusBadRequest {
 		channelType := c.GetInt("channel_type")
-		if channelType == common.ChannelTypeAnthropic {
+		if channelType == constant.ChannelTypeAnthropic {
 			return true
 		}
 		return false

+ 1 - 1
controller/task_video.go

@@ -51,7 +51,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
 }
 
 func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
-	baseURL := common.ChannelBaseURLs[channel.Type]
+	baseURL := constant.ChannelBaseURLs[channel.Type]
 	if channel.GetBaseURL() != "" {
 		baseURL = channel.GetBaseURL()
 	}

+ 1 - 1
controller/user.go

@@ -487,7 +487,7 @@ func GetUserModels(c *gin.Context) {
 	groups := setting.GetUserUsableGroups(user.Group)
 	var models []string
 	for group := range groups {
-		for _, g := range model.GetGroupModels(group) {
+		for _, g := range model.GetGroupEnabledModels(group) {
 			if !common.StringsContains(models, g) {
 				models = append(models, g)
 			}

+ 6 - 21
dto/pricing.go

@@ -1,26 +1,11 @@
 package dto
 
-type OpenAIModelPermission struct {
-	Id                 string  `json:"id"`
-	Object             string  `json:"object"`
-	Created            int     `json:"created"`
-	AllowCreateEngine  bool    `json:"allow_create_engine"`
-	AllowSampling      bool    `json:"allow_sampling"`
-	AllowLogprobs      bool    `json:"allow_logprobs"`
-	AllowSearchIndices bool    `json:"allow_search_indices"`
-	AllowView          bool    `json:"allow_view"`
-	AllowFineTuning    bool    `json:"allow_fine_tuning"`
-	Organization       string  `json:"organization"`
-	Group              *string `json:"group"`
-	IsBlocking         bool    `json:"is_blocking"`
-}
+import "one-api/constant"
 
 type OpenAIModels struct {
-	Id         string                  `json:"id"`
-	Object     string                  `json:"object"`
-	Created    int                     `json:"created"`
-	OwnedBy    string                  `json:"owned_by"`
-	Permission []OpenAIModelPermission `json:"permission"`
-	Root       string                  `json:"root"`
-	Parent     *string                 `json:"parent"`
+	Id                     string                  `json:"id"`
+	Object                 string                  `json:"object"`
+	Created                int                     `json:"created"`
+	OwnedBy                string                  `json:"owned_by"`
+	SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
 }

+ 5 - 4
main.go

@@ -169,10 +169,8 @@ func InitResources() error {
 		common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
 	}
 
-	// 加载旧的(common)环境变量
-	common.InitCommonEnv()
-	// 加载constants的环境变量
-	constant.InitEnv()
+	// 加载环境变量
+	common.InitEnv()
 
 	// Initialize model settings
 	ratio_setting.InitRatioSettings()
@@ -193,6 +191,9 @@ func InitResources() error {
 	// Initialize options, should after model.InitDB()
 	model.InitOptionMap()
 
+	// 初始化模型
+	model.GetPricing()
+
 	// Initialize SQL Database
 	err = model.InitLogDB()
 	if err != nil {

+ 16 - 16
middleware/distributor.go

@@ -25,7 +25,7 @@ type ModelRequest struct {
 
 func Distribute() func(c *gin.Context) {
 	return func(c *gin.Context) {
-		allowIpsMap := c.GetStringMap("allow_ips")
+		allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
 		if len(allowIpsMap) != 0 {
 			clientIp := c.ClientIP()
 			if _, ok := allowIpsMap[clientIp]; !ok {
@@ -34,14 +34,14 @@ func Distribute() func(c *gin.Context) {
 			}
 		}
 		var channel *model.Channel
-		channelId, ok := c.Get("specific_channel_id")
+		channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
 		modelRequest, shouldSelectChannel, err := getModelRequest(c)
 		if err != nil {
 			abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
 			return
 		}
-		userGroup := c.GetString(constant.ContextKeyUserGroup)
-		tokenGroup := c.GetString("token_group")
+		userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
+		tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
 		if tokenGroup != "" {
 			// check common.UserUsableGroups[userGroup]
 			if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
@@ -57,7 +57,7 @@ func Distribute() func(c *gin.Context) {
 			}
 			userGroup = tokenGroup
 		}
-		c.Set(constant.ContextKeyUsingGroup, userGroup)
+		common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
 		if ok {
 			id, err := strconv.Atoi(channelId.(string))
 			if err != nil {
@@ -76,9 +76,9 @@ func Distribute() func(c *gin.Context) {
 		} else {
 			// Select a channel for the user
 			// check token model mapping
-			modelLimitEnable := c.GetBool("token_model_limit_enabled")
+			modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
 			if modelLimitEnable {
-				s, ok := c.Get("token_model_limit")
+				s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
 				var tokenModelLimit map[string]bool
 				if ok {
 					tokenModelLimit = s.(map[string]bool)
@@ -121,7 +121,7 @@ func Distribute() func(c *gin.Context) {
 				}
 			}
 		}
-		c.Set(constant.ContextKeyRequestStartTime, time.Now())
+		common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
 		SetupContextForSelectedChannel(c, channel, modelRequest.Model)
 		c.Next()
 	}
@@ -261,21 +261,21 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 	c.Set("base_url", channel.GetBaseURL())
 	// TODO: api_version统一
 	switch channel.Type {
-	case common.ChannelTypeAzure:
+	case constant.ChannelTypeAzure:
 		c.Set("api_version", channel.Other)
-	case common.ChannelTypeVertexAi:
+	case constant.ChannelTypeVertexAi:
 		c.Set("region", channel.Other)
-	case common.ChannelTypeXunfei:
+	case constant.ChannelTypeXunfei:
 		c.Set("api_version", channel.Other)
-	case common.ChannelTypeGemini:
+	case constant.ChannelTypeGemini:
 		c.Set("api_version", channel.Other)
-	case common.ChannelTypeAli:
+	case constant.ChannelTypeAli:
 		c.Set("plugin", channel.Other)
-	case common.ChannelCloudflare:
+	case constant.ChannelCloudflare:
 		c.Set("api_version", channel.Other)
-	case common.ChannelTypeMokaAI:
+	case constant.ChannelTypeMokaAI:
 		c.Set("api_version", channel.Other)
-	case common.ChannelTypeCoze:
+	case constant.ChannelTypeCoze:
 		c.Set("bot_id", channel.Other)
 	}
 }

+ 2 - 2
middleware/model-rate-limit.go

@@ -177,9 +177,9 @@ func ModelRequestRateLimit() func(c *gin.Context) {
 		successMaxCount := setting.ModelRequestRateLimitSuccessCount
 
 		// 获取分组
-		group := c.GetString("token_group")
+		group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
 		if group == "" {
-			group = c.GetString(constant.ContextKeyUserGroup)
+			group = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
 		}
 
 		//获取分组的限流配置

+ 20 - 5
model/ability.go

@@ -21,7 +21,22 @@ type Ability struct {
 	Tag       *string `json:"tag" gorm:"index"`
 }
 
-func GetGroupModels(group string) []string {
+type AbilityWithChannel struct {
+	Ability
+	ChannelType int `json:"channel_type"`
+}
+
+func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
+	var abilities []AbilityWithChannel
+	err := DB.Table("abilities").
+		Select("abilities.*, channels.type as channel_type").
+		Joins("left join channels on abilities.channel_id = channels.id").
+		Where("abilities.enabled = ?", true).
+		Scan(&abilities).Error
+	return abilities, err
+}
+
+func GetGroupEnabledModels(group string) []string {
 	var models []string
 	// Find distinct models
 	DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
@@ -46,7 +61,7 @@ func getPriority(group string, model string, retry int) (int, error) {
 	var priorities []int
 	err := DB.Model(&Ability{}).
 		Select("DISTINCT(priority)").
-		Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
+		Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
 		Order("priority DESC").              // 按优先级降序排序
 		Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
 
@@ -72,14 +87,14 @@ func getPriority(group string, model string, retry int) (int, error) {
 }
 
 func getChannelQuery(group string, model string, retry int) *gorm.DB {
-	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
-	channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
+	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
+	channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
 	if retry != 0 {
 		priority, err := getPriority(group, model, retry)
 		if err != nil {
 			common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
 		} else {
-			channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
+			channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
 		}
 	}
 

+ 78 - 32
model/pricing.go

@@ -1,20 +1,24 @@
 package model
 
 import (
+	"fmt"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/setting/ratio_setting"
+	"one-api/types"
 	"sync"
 	"time"
 )
 
 type Pricing struct {
-	ModelName       string   `json:"model_name"`
-	QuotaType       int      `json:"quota_type"`
-	ModelRatio      float64  `json:"model_ratio"`
-	ModelPrice      float64  `json:"model_price"`
-	OwnerBy         string   `json:"owner_by"`
-	CompletionRatio float64  `json:"completion_ratio"`
-	EnableGroup     []string `json:"enable_groups,omitempty"`
+	ModelName              string                  `json:"model_name"`
+	QuotaType              int                     `json:"quota_type"`
+	ModelRatio             float64                 `json:"model_ratio"`
+	ModelPrice             float64                 `json:"model_price"`
+	OwnerBy                string                  `json:"owner_by"`
+	CompletionRatio        float64                 `json:"completion_ratio"`
+	EnableGroup            []string                `json:"enable_groups"`
+	SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
 }
 
 var (
@@ -23,47 +27,89 @@ var (
 	updatePricingLock  sync.Mutex
 )
 
-func GetPricing() []Pricing {
-	updatePricingLock.Lock()
-	defer updatePricingLock.Unlock()
+var (
+	modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
+	modelSupportEndpointsLock = sync.RWMutex{}
+)
 
+func GetPricing() []Pricing {
 	if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
-		updatePricing()
+		updatePricingLock.Lock()
+		defer updatePricingLock.Unlock()
+		// Double check after acquiring the lock
+		if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
+			modelSupportEndpointsLock.Lock()
+			defer modelSupportEndpointsLock.Unlock()
+			updatePricing()
+		}
 	}
-	//if group != "" {
-	//	userPricingMap := make([]Pricing, 0)
-	//	models := GetGroupModels(group)
-	//	for _, pricing := range pricingMap {
-	//		if !common.StringsContains(models, pricing.ModelName) {
-	//			pricing.Available = false
-	//		}
-	//		userPricingMap = append(userPricingMap, pricing)
-	//	}
-	//	return userPricingMap
-	//}
 	return pricingMap
 }
 
+func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
+	if model == "" {
+		return make([]constant.EndpointType, 0)
+	}
+	modelSupportEndpointsLock.RLock()
+	defer modelSupportEndpointsLock.RUnlock()
+	if endpoints, ok := modelSupportEndpointTypes[model]; ok {
+		return endpoints
+	}
+	return make([]constant.EndpointType, 0)
+}
+
 func updatePricing() {
 	//modelRatios := common.GetModelRatios()
-	enableAbilities := GetAllEnableAbilities()
-	modelGroupsMap := make(map[string][]string)
+	enableAbilities, err := GetAllEnableAbilityWithChannels()
+	if err != nil {
+		common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
+		return
+	}
+	modelGroupsMap := make(map[string]*types.Set[string])
+
 	for _, ability := range enableAbilities {
-		groups := modelGroupsMap[ability.Model]
-		if groups == nil {
-			groups = make([]string, 0)
+		groups, ok := modelGroupsMap[ability.Model]
+		if !ok {
+			groups = types.NewSet[string]()
+			modelGroupsMap[ability.Model] = groups
 		}
-		if !common.StringsContains(groups, ability.Group) {
-			groups = append(groups, ability.Group)
+		groups.Add(ability.Group)
+	}
+
+	//这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
+	modelSupportEndpointsStr := make(map[string][]string)
+
+	for _, ability := range enableAbilities {
+		endpoints, ok := modelSupportEndpointsStr[ability.Model]
+		if !ok {
+			endpoints = make([]string, 0)
+			modelSupportEndpointsStr[ability.Model] = endpoints
+		}
+		channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
+		for _, channelType := range channelTypes {
+			if !common.StringsContains(endpoints, string(channelType)) {
+				endpoints = append(endpoints, string(channelType))
+			}
+		}
+		modelSupportEndpointsStr[ability.Model] = endpoints
+	}
+
+	modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
+	for model, endpoints := range modelSupportEndpointsStr {
+		supportedEndpoints := make([]constant.EndpointType, 0)
+		for _, endpointStr := range endpoints {
+			endpointType := constant.EndpointType(endpointStr)
+			supportedEndpoints = append(supportedEndpoints, endpointType)
 		}
-		modelGroupsMap[ability.Model] = groups
+		modelSupportEndpointTypes[model] = supportedEndpoints
 	}
 
 	pricingMap = make([]Pricing, 0)
 	for model, groups := range modelGroupsMap {
 		pricing := Pricing{
-			ModelName:   model,
-			EnableGroup: groups,
+			ModelName:              model,
+			EnableGroup:            groups.Items(),
+			SupportedEndpointTypes: modelSupportEndpointTypes[model],
 		}
 		modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
 		if findPrice {

+ 1 - 1
model/token_cache.go

@@ -10,7 +10,7 @@ import (
 func cacheSetToken(token Token) error {
 	key := common.GenerateHMAC(token.Key)
 	token.Clean()
-	err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.RedisKeyCacheSeconds())*time.Second)
+	err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(common.RedisKeyCacheSeconds())*time.Second)
 	if err != nil {
 		return err
 	}

+ 7 - 7
model/user_cache.go

@@ -24,12 +24,12 @@ type UserBase struct {
 }
 
 func (user *UserBase) WriteContext(c *gin.Context) {
-	c.Set(constant.ContextKeyUserGroup, user.Group)
-	c.Set(constant.ContextKeyUserQuota, user.Quota)
-	c.Set(constant.ContextKeyUserStatus, user.Status)
-	c.Set(constant.ContextKeyUserEmail, user.Email)
-	c.Set("username", user.Username)
-	c.Set(constant.ContextKeyUserSetting, user.GetSetting())
+	common.SetContextKey(c, constant.ContextKeyUserGroup, user.Group)
+	common.SetContextKey(c, constant.ContextKeyUserQuota, user.Quota)
+	common.SetContextKey(c, constant.ContextKeyUserStatus, user.Status)
+	common.SetContextKey(c, constant.ContextKeyUserEmail, user.Email)
+	common.SetContextKey(c, constant.ContextKeyUserName, user.Username)
+	common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
 }
 
 func (user *UserBase) GetSetting() map[string]interface{} {
@@ -70,7 +70,7 @@ func updateUserCache(user User) error {
 	return common.RedisHSetObj(
 		getUserCacheKey(user.Id),
 		user.ToBaseUser(),
-		time.Duration(constant.RedisKeyCacheSeconds())*time.Second,
+		time.Duration(common.RedisKeyCacheSeconds())*time.Second,
 	)
 }
 

+ 42 - 43
relay/channel/openai/adaptor.go

@@ -9,8 +9,7 @@ import (
 	"mime/multipart"
 	"net/http"
 	"net/textproto"
-	"one-api/common"
-	constant2 "one-api/constant"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel/ai360"
@@ -21,7 +20,7 @@ import (
 	"one-api/relay/channel/xinference"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/common_handler"
-	"one-api/relay/constant"
+	relayconstant "one-api/relay/constant"
 	"one-api/service"
 	"path/filepath"
 	"strings"
@@ -54,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 	a.ChannelType = info.ChannelType
 
 	// initialize ThinkingContentInfo when thinking_to_content is enabled
-	if think2Content, ok := info.ChannelSetting[constant2.ChannelSettingThinkingToContent].(bool); ok && think2Content {
+	if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok && think2Content {
 		info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
 			IsFirstThinkingContent:  true,
 			SendLastThinkingContent: false,
@@ -67,7 +66,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	if info.RelayFormat == relaycommon.RelayFormatClaude {
 		return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
 	}
-	if info.RelayMode == constant.RelayModeRealtime {
+	if info.RelayMode == relayconstant.RelayModeRealtime {
 		if strings.HasPrefix(info.BaseUrl, "https://") {
 			baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
 			baseUrl = "wss://" + baseUrl
@@ -79,10 +78,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		}
 	}
 	switch info.ChannelType {
-	case common.ChannelTypeAzure:
+	case constant.ChannelTypeAzure:
 		apiVersion := info.ApiVersion
 		if apiVersion == "" {
-			apiVersion = constant2.AzureDefaultAPIVersion
+			apiVersion = constant.AzureDefaultAPIVersion
 		}
 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
 		requestURL := strings.Split(info.RequestURLPath, "?")[0]
@@ -90,25 +89,25 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		task := strings.TrimPrefix(requestURL, "/v1/")
 
 		// 特殊处理 responses API
-		if info.RelayMode == constant.RelayModeResponses {
+		if info.RelayMode == relayconstant.RelayModeResponses {
 			requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview")
 			return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
 		}
 
 		model_ := info.UpstreamModelName
 		// 2025年5月10日后创建的渠道不移除.
-		if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {
+		if info.ChannelCreateTime < constant.AzureNoRemoveDotTime {
 			model_ = strings.Replace(model_, ".", "", -1)
 		}
 		// https://github.com/songquanpeng/one-api/issues/67
 		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
-		if info.RelayMode == constant.RelayModeRealtime {
+		if info.RelayMode == relayconstant.RelayModeRealtime {
 			requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
 		}
 		return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
-	case common.ChannelTypeMiniMax:
+	case constant.ChannelTypeMiniMax:
 		return minimax.GetRequestURL(info)
-	case common.ChannelTypeCustom:
+	case constant.ChannelTypeCustom:
 		url := info.BaseUrl
 		url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
 		return url, nil
@@ -119,14 +118,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, header)
-	if info.ChannelType == common.ChannelTypeAzure {
+	if info.ChannelType == constant.ChannelTypeAzure {
 		header.Set("api-key", info.ApiKey)
 		return nil
 	}
-	if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
+	if info.ChannelType == constant.ChannelTypeOpenAI && "" != info.Organization {
 		header.Set("OpenAI-Organization", info.Organization)
 	}
-	if info.RelayMode == constant.RelayModeRealtime {
+	if info.RelayMode == relayconstant.RelayModeRealtime {
 		swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
 		if swp != "" {
 			items := []string{
@@ -145,7 +144,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
 	} else {
 		header.Set("Authorization", "Bearer "+info.ApiKey)
 	}
-	if info.ChannelType == common.ChannelTypeOpenRouter {
+	if info.ChannelType == constant.ChannelTypeOpenRouter {
 		header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
 		header.Set("X-Title", "New API")
 	}
@@ -156,10 +155,10 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
+	if info.ChannelType != constant.ChannelTypeOpenAI && info.ChannelType != constant.ChannelTypeAzure {
 		request.StreamOptions = nil
 	}
-	if info.ChannelType == common.ChannelTypeOpenRouter {
+	if info.ChannelType == constant.ChannelTypeOpenRouter {
 		if len(request.Usage) == 0 {
 			request.Usage = json.RawMessage(`{"include":true}`)
 		}
@@ -205,7 +204,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	a.ResponseFormat = request.ResponseFormat
-	if info.RelayMode == constant.RelayModeAudioSpeech {
+	if info.RelayMode == relayconstant.RelayModeAudioSpeech {
 		jsonData, err := json.Marshal(request)
 		if err != nil {
 			return nil, fmt.Errorf("error marshalling object: %w", err)
@@ -254,7 +253,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
 	switch info.RelayMode {
-	case constant.RelayModeImagesEdits:
+	case relayconstant.RelayModeImagesEdits:
 
 		var requestBody bytes.Buffer
 		writer := multipart.NewWriter(&requestBody)
@@ -411,11 +410,11 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
-	if info.RelayMode == constant.RelayModeAudioTranscription ||
-		info.RelayMode == constant.RelayModeAudioTranslation ||
-		info.RelayMode == constant.RelayModeImagesEdits {
+	if info.RelayMode == relayconstant.RelayModeAudioTranscription ||
+		info.RelayMode == relayconstant.RelayModeAudioTranslation ||
+		info.RelayMode == relayconstant.RelayModeImagesEdits {
 		return channel.DoFormRequest(a, c, info, requestBody)
-	} else if info.RelayMode == constant.RelayModeRealtime {
+	} else if info.RelayMode == relayconstant.RelayModeRealtime {
 		return channel.DoWssRequest(a, c, info, requestBody)
 	} else {
 		return channel.DoApiRequest(a, c, info, requestBody)
@@ -424,19 +423,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	switch info.RelayMode {
-	case constant.RelayModeRealtime:
+	case relayconstant.RelayModeRealtime:
 		err, usage = OpenaiRealtimeHandler(c, info)
-	case constant.RelayModeAudioSpeech:
+	case relayconstant.RelayModeAudioSpeech:
 		err, usage = OpenaiTTSHandler(c, resp, info)
-	case constant.RelayModeAudioTranslation:
+	case relayconstant.RelayModeAudioTranslation:
 		fallthrough
-	case constant.RelayModeAudioTranscription:
+	case relayconstant.RelayModeAudioTranscription:
 		err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
-	case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
+	case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
 		err, usage = OpenaiHandlerWithUsage(c, resp, info)
-	case constant.RelayModeRerank:
+	case relayconstant.RelayModeRerank:
 		err, usage = common_handler.RerankHandler(c, info, resp)
-	case constant.RelayModeResponses:
+	case relayconstant.RelayModeResponses:
 		if info.IsStream {
 			err, usage = OaiResponsesStreamHandler(c, resp, info)
 		} else {
@@ -454,17 +453,17 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 
 func (a *Adaptor) GetModelList() []string {
 	switch a.ChannelType {
-	case common.ChannelType360:
+	case constant.ChannelType360:
 		return ai360.ModelList
-	case common.ChannelTypeMoonshot:
+	case constant.ChannelTypeMoonshot:
 		return moonshot.ModelList
-	case common.ChannelTypeLingYiWanWu:
+	case constant.ChannelTypeLingYiWanWu:
 		return lingyiwanwu.ModelList
-	case common.ChannelTypeMiniMax:
+	case constant.ChannelTypeMiniMax:
 		return minimax.ModelList
-	case common.ChannelTypeXinference:
+	case constant.ChannelTypeXinference:
 		return xinference.ModelList
-	case common.ChannelTypeOpenRouter:
+	case constant.ChannelTypeOpenRouter:
 		return openrouter.ModelList
 	default:
 		return ModelList
@@ -473,17 +472,17 @@ func (a *Adaptor) GetModelList() []string {
 
 func (a *Adaptor) GetChannelName() string {
 	switch a.ChannelType {
-	case common.ChannelType360:
+	case constant.ChannelType360:
 		return ai360.ChannelName
-	case common.ChannelTypeMoonshot:
+	case constant.ChannelTypeMoonshot:
 		return moonshot.ChannelName
-	case common.ChannelTypeLingYiWanWu:
+	case constant.ChannelTypeLingYiWanWu:
 		return lingyiwanwu.ChannelName
-	case common.ChannelTypeMiniMax:
+	case constant.ChannelTypeMiniMax:
 		return minimax.ChannelName
-	case common.ChannelTypeXinference:
+	case constant.ChannelTypeXinference:
 		return xinference.ChannelName
-	case common.ChannelTypeOpenRouter:
+	case constant.ChannelTypeOpenRouter:
 		return openrouter.ChannelName
 	default:
 		return ChannelName

+ 1 - 1
relay/channel/openai/relay-openai.go

@@ -168,7 +168,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
 		usage.CompletionTokens += toolCount * 7
 	} else {
-		if info.ChannelType == common.ChannelTypeDeepSeek {
+		if info.ChannelType == constant.ChannelTypeDeepSeek {
 			if usage.PromptCacheHitTokens != 0 {
 				usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
 			}

+ 33 - 33
relay/common/relay_info.go

@@ -113,17 +113,17 @@ type RelayInfo struct {
 
 // 定义支持流式选项的通道类型
 var streamSupportedChannels = map[int]bool{
-	common.ChannelTypeOpenAI:     true,
-	common.ChannelTypeAnthropic:  true,
-	common.ChannelTypeAws:        true,
-	common.ChannelTypeGemini:     true,
-	common.ChannelCloudflare:     true,
-	common.ChannelTypeAzure:      true,
-	common.ChannelTypeVolcEngine: true,
-	common.ChannelTypeOllama:     true,
-	common.ChannelTypeXai:        true,
-	common.ChannelTypeDeepSeek:   true,
-	common.ChannelTypeBaiduV2:    true,
+	constant.ChannelTypeOpenAI:     true,
+	constant.ChannelTypeAnthropic:  true,
+	constant.ChannelTypeAws:        true,
+	constant.ChannelTypeGemini:     true,
+	constant.ChannelCloudflare:     true,
+	constant.ChannelTypeAzure:      true,
+	constant.ChannelTypeVolcEngine: true,
+	constant.ChannelTypeOllama:     true,
+	constant.ChannelTypeXai:        true,
+	constant.ChannelTypeDeepSeek:   true,
+	constant.ChannelTypeBaiduV2:    true,
 }
 
 func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
@@ -211,40 +211,40 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo {
 }
 
 func GenRelayInfo(c *gin.Context) *RelayInfo {
-	channelType := c.GetInt("channel_type")
-	channelId := c.GetInt("channel_id")
-	channelSetting := c.GetStringMap("channel_setting")
-	paramOverride := c.GetStringMap("param_override")
-
-	tokenId := c.GetInt("token_id")
-	tokenKey := c.GetString("token_key")
-	userId := c.GetInt("id")
-	tokenUnlimited := c.GetBool("token_unlimited_quota")
-	startTime := c.GetTime(constant.ContextKeyRequestStartTime)
+	channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
+	channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
+	channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting)
+	paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride)
+
+	tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
+	tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
+	userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
+	tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
+	startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
 	// firstResponseTime = time.Now() - 1 second
 
-	apiType, _ := relayconstant.ChannelType2APIType(channelType)
+	apiType, _ := common.ChannelType2APIType(channelType)
 
 	info := &RelayInfo{
-		UserQuota:         c.GetInt(constant.ContextKeyUserQuota),
-		UserSetting:       c.GetStringMap(constant.ContextKeyUserSetting),
-		UserEmail:         c.GetString(constant.ContextKeyUserEmail),
+		UserQuota:         common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
+		UserSetting:       common.GetContextKeyStringMap(c, constant.ContextKeyUserSetting),
+		UserEmail:         common.GetContextKeyString(c, constant.ContextKeyUserEmail),
 		isFirstResponse:   true,
 		RelayMode:         relayconstant.Path2RelayMode(c.Request.URL.Path),
-		BaseUrl:           c.GetString("base_url"),
+		BaseUrl:           common.GetContextKeyString(c, constant.ContextKeyBaseUrl),
 		RequestURLPath:    c.Request.URL.String(),
 		ChannelType:       channelType,
 		ChannelId:         channelId,
 		TokenId:           tokenId,
 		TokenKey:          tokenKey,
 		UserId:            userId,
-		UsingGroup:        c.GetString(constant.ContextKeyUsingGroup),
-		UserGroup:         c.GetString(constant.ContextKeyUserGroup),
+		UsingGroup:        common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
+		UserGroup:         common.GetContextKeyString(c, constant.ContextKeyUserGroup),
 		TokenUnlimited:    tokenUnlimited,
 		StartTime:         startTime,
 		FirstResponseTime: startTime.Add(-time.Second),
-		OriginModelName:   c.GetString("original_model"),
-		UpstreamModelName: c.GetString("original_model"),
+		OriginModelName:   common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
+		UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
 		//RecodeModelName:   c.GetString("original_model"),
 		IsModelMapped:     false,
 		ApiType:           apiType,
@@ -266,12 +266,12 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 		info.RequestURLPath = "/v1" + info.RequestURLPath
 	}
 	if info.BaseUrl == "" {
-		info.BaseUrl = common.ChannelBaseURLs[channelType]
+		info.BaseUrl = constant.ChannelBaseURLs[channelType]
 	}
-	if info.ChannelType == common.ChannelTypeAzure {
+	if info.ChannelType == constant.ChannelTypeAzure {
 		info.ApiVersion = GetAPIVersion(c)
 	}
-	if info.ChannelType == common.ChannelTypeVertexAi {
+	if info.ChannelType == constant.ChannelTypeVertexAi {
 		info.ApiVersion = c.GetString("region")
 	}
 	if streamSupportedChannels[info.ChannelType] {

+ 3 - 3
relay/common/relay_utils.go

@@ -6,7 +6,7 @@ import (
 	_ "image/gif"
 	_ "image/jpeg"
 	_ "image/png"
-	"one-api/common"
+	"one-api/constant"
 	"strings"
 )
 
@@ -15,9 +15,9 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin
 
 	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
 		switch channelType {
-		case common.ChannelTypeOpenAI:
+		case constant.ChannelTypeOpenAI:
 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
-		case common.ChannelTypeAzure:
+		case constant.ChannelTypeAzure:
 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
 		}
 	}

+ 2 - 1
relay/common_handler/rerank.go

@@ -5,6 +5,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/relay/channel/xinference"
 	relaycommon "one-api/relay/common"
@@ -21,7 +22,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 		println("reranker response body: ", string(responseBody))
 	}
 	var jinaResp dto.RerankResponse
-	if info.ChannelType == common.ChannelTypeXinference {
+	if info.ChannelType == constant.ChannelTypeXinference {
 		var xinRerankResponse xinference.XinRerankResponse
 		err = common.UnmarshalJson(responseBody, &xinRerankResponse)
 		if err != nil {

+ 0 - 106
relay/constant/api_type.go

@@ -1,106 +0,0 @@
-package constant
-
-import (
-	"one-api/common"
-)
-
-const (
-	APITypeOpenAI = iota
-	APITypeAnthropic
-	APITypePaLM
-	APITypeBaidu
-	APITypeZhipu
-	APITypeAli
-	APITypeXunfei
-	APITypeAIProxyLibrary
-	APITypeTencent
-	APITypeGemini
-	APITypeZhipuV4
-	APITypeOllama
-	APITypePerplexity
-	APITypeAws
-	APITypeCohere
-	APITypeDify
-	APITypeJina
-	APITypeCloudflare
-	APITypeSiliconFlow
-	APITypeVertexAi
-	APITypeMistral
-	APITypeDeepSeek
-	APITypeMokaAI
-	APITypeVolcEngine
-	APITypeBaiduV2
-	APITypeOpenRouter
-	APITypeXinference
-	APITypeXai
-	APITypeCoze
-	APITypeDummy // this one is only for count, do not add any channel after this
-)
-
-func ChannelType2APIType(channelType int) (int, bool) {
-	apiType := -1
-	switch channelType {
-	case common.ChannelTypeOpenAI:
-		apiType = APITypeOpenAI
-	case common.ChannelTypeAnthropic:
-		apiType = APITypeAnthropic
-	case common.ChannelTypeBaidu:
-		apiType = APITypeBaidu
-	case common.ChannelTypePaLM:
-		apiType = APITypePaLM
-	case common.ChannelTypeZhipu:
-		apiType = APITypeZhipu
-	case common.ChannelTypeAli:
-		apiType = APITypeAli
-	case common.ChannelTypeXunfei:
-		apiType = APITypeXunfei
-	case common.ChannelTypeAIProxyLibrary:
-		apiType = APITypeAIProxyLibrary
-	case common.ChannelTypeTencent:
-		apiType = APITypeTencent
-	case common.ChannelTypeGemini:
-		apiType = APITypeGemini
-	case common.ChannelTypeZhipu_v4:
-		apiType = APITypeZhipuV4
-	case common.ChannelTypeOllama:
-		apiType = APITypeOllama
-	case common.ChannelTypePerplexity:
-		apiType = APITypePerplexity
-	case common.ChannelTypeAws:
-		apiType = APITypeAws
-	case common.ChannelTypeCohere:
-		apiType = APITypeCohere
-	case common.ChannelTypeDify:
-		apiType = APITypeDify
-	case common.ChannelTypeJina:
-		apiType = APITypeJina
-	case common.ChannelCloudflare:
-		apiType = APITypeCloudflare
-	case common.ChannelTypeSiliconFlow:
-		apiType = APITypeSiliconFlow
-	case common.ChannelTypeVertexAi:
-		apiType = APITypeVertexAi
-	case common.ChannelTypeMistral:
-		apiType = APITypeMistral
-	case common.ChannelTypeDeepSeek:
-		apiType = APITypeDeepSeek
-	case common.ChannelTypeMokaAI:
-		apiType = APITypeMokaAI
-	case common.ChannelTypeVolcEngine:
-		apiType = APITypeVolcEngine
-	case common.ChannelTypeBaiduV2:
-		apiType = APITypeBaiduV2
-	case common.ChannelTypeOpenRouter:
-		apiType = APITypeOpenRouter
-	case common.ChannelTypeXinference:
-		apiType = APITypeXinference
-	case common.ChannelTypeXai:
-		apiType = APITypeXai
-	case common.ChannelTypeCoze:
-		apiType = APITypeCoze
-	}
-	if apiType == -1 {
-		return APITypeOpenAI, false
-	}
-	return apiType, true
-}

+ 1 - 2
relay/image_handler.go

@@ -8,6 +8,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/model"
 	relaycommon "one-api/relay/common"
@@ -17,8 +18,6 @@ import (
 	"one-api/setting"
 	"strings"
 
-	"one-api/relay/constant"
-
 	"github.com/gin-gonic/gin"
 )
 

+ 1 - 1
relay/relay_adaptor.go

@@ -1,6 +1,7 @@
 package relay
 
 import (
+	"one-api/constant"
 	commonconstant "one-api/constant"
 	"one-api/relay/channel"
 	"one-api/relay/channel/ali"
@@ -32,7 +33,6 @@ import (
 	"one-api/relay/channel/xunfei"
 	"one-api/relay/channel/zhipu"
 	"one-api/relay/channel/zhipu_4v"
-	"one-api/relay/constant"
 )
 
 func GetAdaptor(apiType int) channel.Adaptor {

+ 2 - 1
service/channel.go

@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/model"
 	"one-api/setting/operation_setting"
@@ -48,7 +49,7 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
 	}
 	if err.StatusCode == http.StatusForbidden {
 		switch channelType {
-		case common.ChannelTypeGemini:
+		case constant.ChannelTypeGemini:
 			return true
 		}
 	}

+ 2 - 1
service/convert.go

@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/relay/channel/openrouter"
 	relaycommon "one-api/relay/common"
@@ -19,7 +20,7 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
 		Stream:      claudeRequest.Stream,
 	}
 
-	isOpenRouter := info.ChannelType == common.ChannelTypeOpenRouter
+	isOpenRouter := info.ChannelType == constant.ChannelTypeOpenRouter
 
 	if claudeRequest.Thinking != nil && claudeRequest.Thinking.Type == "enabled" {
 		if isOpenRouter {

+ 3 - 3
service/quota.go

@@ -6,7 +6,7 @@ import (
 	"log"
 	"math"
 	"one-api/common"
-	constant2 "one-api/constant"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/model"
 	relaycommon "one-api/relay/common"
@@ -232,7 +232,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	cacheCreationRatio := priceData.CacheCreationRatio
 	cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
 
-	if relayInfo.ChannelType == common.ChannelTypeOpenRouter {
+	if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
 		promptTokens -= cacheTokens
 		if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 {
 			maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData)
@@ -447,7 +447,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
 	gopool.Go(func() {
 		userSetting := relayInfo.UserSetting
 		threshold := common.QuotaRemindThreshold
-		if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
+		if userCustomThreshold, ok := userSetting[constant.UserSettingQuotaWarningThreshold]; ok {
 			threshold = int(userCustomThreshold.(float64))
 		}
 

+ 1 - 1
service/token_counter.go

@@ -101,7 +101,7 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
 	if !constant.GetMediaToken {
 		return 3 * baseTokens, nil
 	}
-	if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic {
+	if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic {
 		return 3 * baseTokens, nil
 	}
 	var config image.Config

+ 42 - 0
types/set.go

@@ -0,0 +1,42 @@
+package types
+
+type Set[T comparable] struct {
+	items map[T]struct{}
+}
+
+// NewSet 创建并返回一个新的 Set
+func NewSet[T comparable]() *Set[T] {
+	return &Set[T]{
+		items: make(map[T]struct{}),
+	}
+}
+
+func (s *Set[T]) Add(item T) {
+	s.items[item] = struct{}{}
+}
+
+// Remove 从 Set 中移除一个元素
+func (s *Set[T]) Remove(item T) {
+	delete(s.items, item)
+}
+
+// Contains 检查 Set 是否包含某个元素
+func (s *Set[T]) Contains(item T) bool {
+	_, exists := s.items[item]
+	return exists
+}
+
+// Len 返回 Set 中元素的数量
+func (s *Set[T]) Len() int {
+	return len(s.items)
+}
+
+// Items 返回 Set 中所有元素组成的切片
+// 注意:由于 map 的无序性,返回的切片元素顺序是随机的
+func (s *Set[T]) Items() []T {
+	items := make([]T, 0, s.Len())
+	for item := range s.items {
+		items = append(items, item)
+	}
+	return items
+}