model.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. package controller
  2. import (
  3. "fmt"
  4. "github.com/samber/lo"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/constant"
  8. "one-api/dto"
  9. "one-api/model"
  10. "one-api/relay"
  11. "one-api/relay/channel/ai360"
  12. "one-api/relay/channel/lingyiwanwu"
  13. "one-api/relay/channel/minimax"
  14. "one-api/relay/channel/moonshot"
  15. relaycommon "one-api/relay/common"
  16. relayconstant "one-api/relay/constant"
  17. "one-api/setting"
  18. "github.com/gin-gonic/gin"
  19. )
  20. // https://platform.openai.com/docs/api-reference/models/list
  21. var openAIModels []dto.OpenAIModels
  22. var openAIModelsMap map[string]dto.OpenAIModels
  23. var channelId2Models map[int][]string
  24. func getPermission() []dto.OpenAIModelPermission {
  25. var permission []dto.OpenAIModelPermission
  26. permission = append(permission, dto.OpenAIModelPermission{
  27. Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
  28. Object: "model_permission",
  29. Created: 1626777600,
  30. AllowCreateEngine: true,
  31. AllowSampling: true,
  32. AllowLogprobs: true,
  33. AllowSearchIndices: false,
  34. AllowView: true,
  35. AllowFineTuning: false,
  36. Organization: "*",
  37. Group: nil,
  38. IsBlocking: false,
  39. })
  40. return permission
  41. }
  42. func init() {
  43. // https://platform.openai.com/docs/models/model-endpoint-compatibility
  44. permission := getPermission()
  45. for i := 0; i < relayconstant.APITypeDummy; i++ {
  46. if i == relayconstant.APITypeAIProxyLibrary {
  47. continue
  48. }
  49. adaptor := relay.GetAdaptor(i)
  50. channelName := adaptor.GetChannelName()
  51. modelNames := adaptor.GetModelList()
  52. for _, modelName := range modelNames {
  53. openAIModels = append(openAIModels, dto.OpenAIModels{
  54. Id: modelName,
  55. Object: "model",
  56. Created: 1626777600,
  57. OwnedBy: channelName,
  58. Permission: permission,
  59. Root: modelName,
  60. Parent: nil,
  61. })
  62. }
  63. }
  64. for _, modelName := range ai360.ModelList {
  65. openAIModels = append(openAIModels, dto.OpenAIModels{
  66. Id: modelName,
  67. Object: "model",
  68. Created: 1626777600,
  69. OwnedBy: ai360.ChannelName,
  70. Permission: permission,
  71. Root: modelName,
  72. Parent: nil,
  73. })
  74. }
  75. for _, modelName := range moonshot.ModelList {
  76. openAIModels = append(openAIModels, dto.OpenAIModels{
  77. Id: modelName,
  78. Object: "model",
  79. Created: 1626777600,
  80. OwnedBy: moonshot.ChannelName,
  81. Permission: permission,
  82. Root: modelName,
  83. Parent: nil,
  84. })
  85. }
  86. for _, modelName := range lingyiwanwu.ModelList {
  87. openAIModels = append(openAIModels, dto.OpenAIModels{
  88. Id: modelName,
  89. Object: "model",
  90. Created: 1626777600,
  91. OwnedBy: lingyiwanwu.ChannelName,
  92. Permission: permission,
  93. Root: modelName,
  94. Parent: nil,
  95. })
  96. }
  97. for _, modelName := range minimax.ModelList {
  98. openAIModels = append(openAIModels, dto.OpenAIModels{
  99. Id: modelName,
  100. Object: "model",
  101. Created: 1626777600,
  102. OwnedBy: minimax.ChannelName,
  103. Permission: permission,
  104. Root: modelName,
  105. Parent: nil,
  106. })
  107. }
  108. for modelName, _ := range constant.MidjourneyModel2Action {
  109. openAIModels = append(openAIModels, dto.OpenAIModels{
  110. Id: modelName,
  111. Object: "model",
  112. Created: 1626777600,
  113. OwnedBy: "midjourney",
  114. Permission: permission,
  115. Root: modelName,
  116. Parent: nil,
  117. })
  118. }
  119. openAIModelsMap = make(map[string]dto.OpenAIModels)
  120. for _, aiModel := range openAIModels {
  121. openAIModelsMap[aiModel.Id] = aiModel
  122. }
  123. channelId2Models = make(map[int][]string)
  124. for i := 1; i <= common.ChannelTypeDummy; i++ {
  125. apiType, success := relayconstant.ChannelType2APIType(i)
  126. if !success || apiType == relayconstant.APITypeAIProxyLibrary {
  127. continue
  128. }
  129. meta := &relaycommon.RelayInfo{ChannelType: i}
  130. adaptor := relay.GetAdaptor(apiType)
  131. adaptor.Init(meta)
  132. channelId2Models[i] = adaptor.GetModelList()
  133. }
  134. openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
  135. return m.Id
  136. })
  137. }
  138. func ListModels(c *gin.Context) {
  139. userOpenAiModels := make([]dto.OpenAIModels, 0)
  140. permission := getPermission()
  141. modelLimitEnable := c.GetBool("token_model_limit_enabled")
  142. if modelLimitEnable {
  143. s, ok := c.Get("token_model_limit")
  144. var tokenModelLimit map[string]bool
  145. if ok {
  146. tokenModelLimit = s.(map[string]bool)
  147. } else {
  148. tokenModelLimit = map[string]bool{}
  149. }
  150. for allowModel, _ := range tokenModelLimit {
  151. if _, ok := openAIModelsMap[allowModel]; ok {
  152. userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
  153. } else {
  154. userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
  155. Id: allowModel,
  156. Object: "model",
  157. Created: 1626777600,
  158. OwnedBy: "custom",
  159. Permission: permission,
  160. Root: allowModel,
  161. Parent: nil,
  162. })
  163. }
  164. }
  165. } else {
  166. userId := c.GetInt("id")
  167. userGroup, err := model.GetUserGroup(userId, true)
  168. if err != nil {
  169. c.JSON(http.StatusOK, gin.H{
  170. "success": false,
  171. "message": "get user group failed",
  172. })
  173. return
  174. }
  175. group := userGroup
  176. tokenGroup := c.GetString("token_group")
  177. if tokenGroup != "" {
  178. group = tokenGroup
  179. }
  180. var models []string
  181. if tokenGroup == "auto" {
  182. for _, autoGroup := range setting.AutoGroups {
  183. groupModels := model.GetGroupModels(autoGroup)
  184. for _, g := range groupModels {
  185. if !common.StringsContains(models, g) {
  186. models = append(models, g)
  187. }
  188. }
  189. }
  190. } else {
  191. models = model.GetGroupModels(group)
  192. }
  193. for _, s := range models {
  194. if _, ok := openAIModelsMap[s]; ok {
  195. userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
  196. } else {
  197. userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
  198. Id: s,
  199. Object: "model",
  200. Created: 1626777600,
  201. OwnedBy: "custom",
  202. Permission: permission,
  203. Root: s,
  204. Parent: nil,
  205. })
  206. }
  207. }
  208. }
  209. c.JSON(200, gin.H{
  210. "success": true,
  211. "data": userOpenAiModels,
  212. })
  213. }
  214. func ChannelListModels(c *gin.Context) {
  215. c.JSON(200, gin.H{
  216. "success": true,
  217. "data": openAIModels,
  218. })
  219. }
  220. func DashboardListModels(c *gin.Context) {
  221. c.JSON(200, gin.H{
  222. "success": true,
  223. "data": channelId2Models,
  224. })
  225. }
  226. func EnabledListModels(c *gin.Context) {
  227. c.JSON(200, gin.H{
  228. "success": true,
  229. "data": model.GetEnabledModels(),
  230. })
  231. }
  232. func RetrieveModel(c *gin.Context) {
  233. modelId := c.Param("model")
  234. if aiModel, ok := openAIModelsMap[modelId]; ok {
  235. c.JSON(200, aiModel)
  236. } else {
  237. openAIError := dto.OpenAIError{
  238. Message: fmt.Sprintf("The model '%s' does not exist", modelId),
  239. Type: "invalid_request_error",
  240. Param: "model",
  241. Code: "model_not_found",
  242. }
  243. c.JSON(200, gin.H{
  244. "error": openAIError,
  245. })
  246. }
  247. }