model.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. package controller
  import (
	"fmt"
	"net/http"
	"sort"

	"one-api/common"
	"one-api/constant"
	"one-api/dto"
	"one-api/model"
	"one-api/relay"
	"one-api/relay/channel/ai360"
	"one-api/relay/channel/lingyiwanwu"
	"one-api/relay/channel/minimax"
	"one-api/relay/channel/moonshot"
	relaycommon "one-api/relay/common"
	"one-api/setting"

	"github.com/gin-gonic/gin"
	"github.com/samber/lo"
)
  19. // https://platform.openai.com/docs/api-reference/models/list
  20. var openAIModels []dto.OpenAIModels
  21. var openAIModelsMap map[string]dto.OpenAIModels
  22. var channelId2Models map[int][]string
  23. func init() {
  24. // https://platform.openai.com/docs/models/model-endpoint-compatibility
  25. for i := 0; i < constant.APITypeDummy; i++ {
  26. if i == constant.APITypeAIProxyLibrary {
  27. continue
  28. }
  29. adaptor := relay.GetAdaptor(i)
  30. channelName := adaptor.GetChannelName()
  31. modelNames := adaptor.GetModelList()
  32. for _, modelName := range modelNames {
  33. openAIModels = append(openAIModels, dto.OpenAIModels{
  34. Id: modelName,
  35. Object: "model",
  36. Created: 1626777600,
  37. OwnedBy: channelName,
  38. })
  39. }
  40. }
  41. for _, modelName := range ai360.ModelList {
  42. openAIModels = append(openAIModels, dto.OpenAIModels{
  43. Id: modelName,
  44. Object: "model",
  45. Created: 1626777600,
  46. OwnedBy: ai360.ChannelName,
  47. })
  48. }
  49. for _, modelName := range moonshot.ModelList {
  50. openAIModels = append(openAIModels, dto.OpenAIModels{
  51. Id: modelName,
  52. Object: "model",
  53. Created: 1626777600,
  54. OwnedBy: moonshot.ChannelName,
  55. })
  56. }
  57. for _, modelName := range lingyiwanwu.ModelList {
  58. openAIModels = append(openAIModels, dto.OpenAIModels{
  59. Id: modelName,
  60. Object: "model",
  61. Created: 1626777600,
  62. OwnedBy: lingyiwanwu.ChannelName,
  63. })
  64. }
  65. for _, modelName := range minimax.ModelList {
  66. openAIModels = append(openAIModels, dto.OpenAIModels{
  67. Id: modelName,
  68. Object: "model",
  69. Created: 1626777600,
  70. OwnedBy: minimax.ChannelName,
  71. })
  72. }
  73. for modelName, _ := range constant.MidjourneyModel2Action {
  74. openAIModels = append(openAIModels, dto.OpenAIModels{
  75. Id: modelName,
  76. Object: "model",
  77. Created: 1626777600,
  78. OwnedBy: "midjourney",
  79. })
  80. }
  81. openAIModelsMap = make(map[string]dto.OpenAIModels)
  82. for _, aiModel := range openAIModels {
  83. openAIModelsMap[aiModel.Id] = aiModel
  84. }
  85. channelId2Models = make(map[int][]string)
  86. for i := 1; i <= constant.ChannelTypeDummy; i++ {
  87. apiType, success := common.ChannelType2APIType(i)
  88. if !success || apiType == constant.APITypeAIProxyLibrary {
  89. continue
  90. }
  91. meta := &relaycommon.RelayInfo{ChannelType: i}
  92. adaptor := relay.GetAdaptor(apiType)
  93. adaptor.Init(meta)
  94. channelId2Models[i] = adaptor.GetModelList()
  95. }
  96. openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
  97. return m.Id
  98. })
  99. }
  100. func ListModels(c *gin.Context) {
  101. userOpenAiModels := make([]dto.OpenAIModels, 0)
  102. modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
  103. if modelLimitEnable {
  104. s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
  105. var tokenModelLimit map[string]bool
  106. if ok {
  107. tokenModelLimit = s.(map[string]bool)
  108. } else {
  109. tokenModelLimit = map[string]bool{}
  110. }
  111. for allowModel, _ := range tokenModelLimit {
  112. if oaiModel, ok := openAIModelsMap[allowModel]; ok {
  113. oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
  114. userOpenAiModels = append(userOpenAiModels, oaiModel)
  115. } else {
  116. userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
  117. Id: allowModel,
  118. Object: "model",
  119. Created: 1626777600,
  120. OwnedBy: "custom",
  121. SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
  122. })
  123. }
  124. }
  125. } else {
  126. userId := c.GetInt("id")
  127. userGroup, err := model.GetUserGroup(userId, false)
  128. if err != nil {
  129. c.JSON(http.StatusOK, gin.H{
  130. "success": false,
  131. "message": "get user group failed",
  132. })
  133. return
  134. }
  135. group := userGroup
  136. tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
  137. if tokenGroup != "" {
  138. group = tokenGroup
  139. }
  140. var models []string
  141. if tokenGroup == "auto" {
  142. for _, autoGroup := range setting.AutoGroups {
  143. groupModels := model.GetGroupEnabledModels(autoGroup)
  144. for _, g := range groupModels {
  145. if !common.StringsContains(models, g) {
  146. models = append(models, g)
  147. }
  148. }
  149. }
  150. } else {
  151. models = model.GetGroupEnabledModels(group)
  152. }
  153. for _, modelName := range models {
  154. if oaiModel, ok := openAIModelsMap[modelName]; ok {
  155. oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
  156. userOpenAiModels = append(userOpenAiModels, oaiModel)
  157. } else {
  158. userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
  159. Id: modelName,
  160. Object: "model",
  161. Created: 1626777600,
  162. OwnedBy: "custom",
  163. SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
  164. })
  165. }
  166. }
  167. }
  168. c.JSON(200, gin.H{
  169. "success": true,
  170. "data": userOpenAiModels,
  171. })
  172. }
  173. func ChannelListModels(c *gin.Context) {
  174. c.JSON(200, gin.H{
  175. "success": true,
  176. "data": openAIModels,
  177. })
  178. }
  179. func DashboardListModels(c *gin.Context) {
  180. c.JSON(200, gin.H{
  181. "success": true,
  182. "data": channelId2Models,
  183. })
  184. }
  185. func EnabledListModels(c *gin.Context) {
  186. c.JSON(200, gin.H{
  187. "success": true,
  188. "data": model.GetEnabledModels(),
  189. })
  190. }
  191. func RetrieveModel(c *gin.Context) {
  192. modelId := c.Param("model")
  193. if aiModel, ok := openAIModelsMap[modelId]; ok {
  194. c.JSON(200, aiModel)
  195. } else {
  196. openAIError := dto.OpenAIError{
  197. Message: fmt.Sprintf("The model '%s' does not exist", modelId),
  198. Type: "invalid_request_error",
  199. Param: "model",
  200. Code: "model_not_found",
  201. }
  202. c.JSON(200, gin.H{
  203. "error": openAIError,
  204. })
  205. }
  206. }