package controller import ( "fmt" "github.com/gin-gonic/gin" "github.com/samber/lo" "net/http" "one-api/common" "one-api/constant" "one-api/dto" "one-api/model" "one-api/relay" "one-api/relay/channel/ai360" "one-api/relay/channel/lingyiwanwu" "one-api/relay/channel/minimax" "one-api/relay/channel/moonshot" relaycommon "one-api/relay/common" "one-api/setting" ) // https://platform.openai.com/docs/api-reference/models/list var openAIModels []dto.OpenAIModels var openAIModelsMap map[string]dto.OpenAIModels var channelId2Models map[int][]string func init() { // https://platform.openai.com/docs/models/model-endpoint-compatibility for i := 0; i < constant.APITypeDummy; i++ { if i == constant.APITypeAIProxyLibrary { continue } adaptor := relay.GetAdaptor(i) channelName := adaptor.GetChannelName() modelNames := adaptor.GetModelList() for _, modelName := range modelNames { openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, OwnedBy: channelName, }) } } for _, modelName := range ai360.ModelList { openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, OwnedBy: ai360.ChannelName, }) } for _, modelName := range moonshot.ModelList { openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, OwnedBy: moonshot.ChannelName, }) } for _, modelName := range lingyiwanwu.ModelList { openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, OwnedBy: lingyiwanwu.ChannelName, }) } for _, modelName := range minimax.ModelList { openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, OwnedBy: minimax.ChannelName, }) } for modelName, _ := range constant.MidjourneyModel2Action { openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, OwnedBy: "midjourney", }) } openAIModelsMap = make(map[string]dto.OpenAIModels) for _, aiModel := range openAIModels { openAIModelsMap[aiModel.Id] = aiModel } channelId2Models = make(map[int][]string) for i := 1; i <= constant.ChannelTypeDummy; i++ { apiType, success := common.ChannelType2APIType(i) if !success || apiType == constant.APITypeAIProxyLibrary { continue } meta := &relaycommon.RelayInfo{ChannelType: i} adaptor := relay.GetAdaptor(apiType) adaptor.Init(meta) channelId2Models[i] = adaptor.GetModelList() } openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string { return m.Id }) } func ListModels(c *gin.Context) { userOpenAiModels := make([]dto.OpenAIModels, 0) modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled) if modelLimitEnable { s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit) var tokenModelLimit map[string]bool if ok { tokenModelLimit = s.(map[string]bool) } else { tokenModelLimit = map[string]bool{} } for allowModel, _ := range tokenModelLimit { if oaiModel, ok := openAIModelsMap[allowModel]; ok { oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel) userOpenAiModels = append(userOpenAiModels, oaiModel) } else { userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{ Id: allowModel, Object: "model", Created: 1626777600, OwnedBy: "custom", SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel), }) } } } else { userId := c.GetInt("id") userGroup, err := model.GetUserGroup(userId, false) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "get user group failed", }) return } group := userGroup tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup) if tokenGroup != "" { group = tokenGroup } var models []string if tokenGroup == "auto" { for _, autoGroup := range setting.AutoGroups { groupModels := model.GetGroupEnabledModels(autoGroup) for _, g := range groupModels { if !common.StringsContains(models, g) { models = append(models, g) } } } } else { models = model.GetGroupEnabledModels(group) } for _, modelName := range models { if oaiModel, ok := openAIModelsMap[modelName]; ok { oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName) userOpenAiModels = append(userOpenAiModels, oaiModel) } else { userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, OwnedBy: "custom", SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName), }) } } } c.JSON(200, gin.H{ "success": true, "data": userOpenAiModels, }) } func ChannelListModels(c *gin.Context) { c.JSON(200, gin.H{ "success": true, "data": openAIModels, }) } func DashboardListModels(c *gin.Context) { c.JSON(200, gin.H{ "success": true, "data": channelId2Models, }) } func EnabledListModels(c *gin.Context) { c.JSON(200, gin.H{ "success": true, "data": model.GetEnabledModels(), }) } func RetrieveModel(c *gin.Context) { modelId := c.Param("model") if aiModel, ok := openAIModelsMap[modelId]; ok { c.JSON(200, aiModel) } else { openAIError := dto.OpenAIError{ Message: fmt.Sprintf("The model '%s' does not exist", modelId), Type: "invalid_request_error", Param: "model", Code: "model_not_found", } c.JSON(200, gin.H{ "error": openAIError, }) } }