model.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. package controller
  2. import (
  3. "net/http"
  4. "slices"
  5. "sort"
  6. "strconv"
  7. "github.com/bytedance/sonic"
  8. "github.com/gin-gonic/gin"
  9. "github.com/labring/aiproxy/core/common/config"
  10. "github.com/labring/aiproxy/core/middleware"
  11. "github.com/labring/aiproxy/core/model"
  12. "github.com/labring/aiproxy/core/relay/adaptors"
  13. log "github.com/sirupsen/logrus"
  14. )
  15. // https://platform.openai.com/docs/api-reference/models/list
  16. type OpenAIModelPermission struct {
  17. Group *string `json:"group"`
  18. ID string `json:"id"`
  19. Object string `json:"object"`
  20. Organization string `json:"organization"`
  21. Created int `json:"created"`
  22. AllowCreateEngine bool `json:"allow_create_engine"`
  23. AllowSampling bool `json:"allow_sampling"`
  24. AllowLogprobs bool `json:"allow_logprobs"`
  25. AllowSearchIndices bool `json:"allow_search_indices"`
  26. AllowView bool `json:"allow_view"`
  27. AllowFineTuning bool `json:"allow_fine_tuning"`
  28. IsBlocking bool `json:"is_blocking"`
  29. }
  30. type OpenAIModels struct {
  31. Parent *string `json:"parent"`
  32. ID string `json:"id"`
  33. Object string `json:"object"`
  34. OwnedBy string `json:"owned_by"`
  35. Root string `json:"root"`
  36. Permission []OpenAIModelPermission `json:"permission"`
  37. Created int `json:"created"`
  38. }
  39. type BuiltinModelConfig model.ModelConfig
  40. func (c *BuiltinModelConfig) MarshalJSON() ([]byte, error) {
  41. type Alias BuiltinModelConfig
  42. return sonic.Marshal(&struct {
  43. *Alias
  44. CreatedAt int64 `json:"created_at,omitempty"`
  45. UpdatedAt int64 `json:"updated_at,omitempty"`
  46. }{
  47. Alias: (*Alias)(c),
  48. })
  49. }
  50. func SortBuiltinModelConfigsFunc(i, j BuiltinModelConfig) int {
  51. return model.SortModelConfigsFunc((model.ModelConfig)(i), (model.ModelConfig)(j))
  52. }
  53. var (
  54. builtinModels []BuiltinModelConfig
  55. builtinModelsMap map[string]*OpenAIModels
  56. builtinChannelType2Models map[model.ChannelType][]BuiltinModelConfig
  57. )
  58. var permission = []OpenAIModelPermission{
  59. {
  60. ID: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
  61. Object: "model_permission",
  62. Created: 1626777600,
  63. AllowCreateEngine: true,
  64. AllowSampling: true,
  65. AllowLogprobs: true,
  66. AllowSearchIndices: false,
  67. AllowView: true,
  68. AllowFineTuning: false,
  69. Organization: "*",
  70. Group: nil,
  71. IsBlocking: false,
  72. },
  73. }
  74. func init() {
  75. builtinChannelType2Models = make(map[model.ChannelType][]BuiltinModelConfig)
  76. builtinModelsMap = make(map[string]*OpenAIModels)
  77. // https://platform.openai.com/docs/models/model-endpoint-compatibility
  78. for i, adaptor := range adaptors.ChannelAdaptor {
  79. modelNames := adaptor.Metadata().Models
  80. builtinChannelType2Models[i] = make([]BuiltinModelConfig, len(modelNames))
  81. for idx, _model := range modelNames {
  82. if _model.Owner == "" {
  83. _model.Owner = model.ModelOwner(i.String())
  84. }
  85. if v, ok := builtinModelsMap[_model.Model]; !ok {
  86. builtinModelsMap[_model.Model] = &OpenAIModels{
  87. ID: _model.Model,
  88. Object: "model",
  89. Created: 1626777600,
  90. OwnedBy: string(_model.Owner),
  91. Permission: permission,
  92. Root: _model.Model,
  93. Parent: nil,
  94. }
  95. builtinModels = append(builtinModels, (BuiltinModelConfig)(_model))
  96. } else if v.OwnedBy != string(_model.Owner) {
  97. log.Fatalf("model %s owner mismatch, expect %s, actual %s", _model.Model, string(_model.Owner), v.OwnedBy)
  98. }
  99. builtinChannelType2Models[i][idx] = (BuiltinModelConfig)(_model)
  100. }
  101. }
  102. for _, models := range builtinChannelType2Models {
  103. sort.Slice(models, func(i, j int) bool {
  104. return models[i].Model < models[j].Model
  105. })
  106. slices.SortStableFunc(models, SortBuiltinModelConfigsFunc)
  107. }
  108. slices.SortStableFunc(builtinModels, SortBuiltinModelConfigsFunc)
  109. }
  110. // BuiltinModels godoc
  111. //
  112. // @Summary Get builtin models
  113. // @Description Returns a list of builtin models
  114. // @Tags model
  115. // @Produce json
  116. // @Security ApiKeyAuth
  117. // @Success 200 {object} middleware.APIResponse{data=[]BuiltinModelConfig}
  118. // @Router /api/models/builtin [get]
  119. func BuiltinModels(c *gin.Context) {
  120. middleware.SuccessResponse(c, builtinModels)
  121. }
  122. // ChannelBuiltinModels godoc
  123. //
  124. // @Summary Get channel builtin models
  125. // @Description Returns a list of channel builtin models
  126. // @Tags model
  127. // @Produce json
  128. // @Security ApiKeyAuth
  129. // @Success 200 {object} middleware.APIResponse{data=map[int][]BuiltinModelConfig}
  130. // @Router /api/models/builtin/channel [get]
  131. func ChannelBuiltinModels(c *gin.Context) {
  132. middleware.SuccessResponse(c, builtinChannelType2Models)
  133. }
  134. // ChannelBuiltinModelsByType godoc
  135. //
  136. // @Summary Get channel builtin models by type
  137. // @Description Returns a list of channel builtin models by type
  138. // @Tags model
  139. // @Produce json
  140. // @Security ApiKeyAuth
  141. // @Param type path model.ChannelType true "Channel type"
  142. // @Success 200 {object} middleware.APIResponse{data=[]BuiltinModelConfig}
  143. // @Router /api/models/builtin/channel/{type} [get]
  144. func ChannelBuiltinModelsByType(c *gin.Context) {
  145. channelType := c.Param("type")
  146. if channelType == "" {
  147. middleware.ErrorResponse(c, http.StatusBadRequest, "type is required")
  148. return
  149. }
  150. channelTypeInt, err := strconv.Atoi(channelType)
  151. if err != nil {
  152. middleware.ErrorResponse(c, http.StatusBadRequest, "invalid type")
  153. return
  154. }
  155. middleware.SuccessResponse(c, builtinChannelType2Models[model.ChannelType(channelTypeInt)])
  156. }
  157. // ChannelDefaultModelsAndMapping godoc
  158. //
  159. // @Summary Get channel default models and mapping
  160. // @Description Returns a list of channel default models and mapping
  161. // @Tags model
  162. // @Produce json
  163. // @Security ApiKeyAuth
  164. // @Success 200 {object} middleware.APIResponse{data=map[string]any{models=[]string,mapping=map[string]string}}
  165. // @Router /api/models/default [get]
  166. func ChannelDefaultModelsAndMapping(c *gin.Context) {
  167. middleware.SuccessResponse(c, gin.H{
  168. "models": config.GetDefaultChannelModels(),
  169. "mapping": config.GetDefaultChannelModelMapping(),
  170. })
  171. }
  172. // ChannelDefaultModelsAndMappingByType godoc
  173. //
  174. // @Summary Get channel default models and mapping by type
  175. // @Description Returns a list of channel default models and mapping by type
  176. // @Tags model
  177. // @Produce json
  178. // @Security ApiKeyAuth
  179. // @Param type path string true "Channel type"
  180. // @Success 200 {object} middleware.APIResponse{data=map[string]any{models=[]string,mapping=map[string]string}}
  181. // @Router /api/models/default/{type} [get]
  182. func ChannelDefaultModelsAndMappingByType(c *gin.Context) {
  183. channelType := c.Param("type")
  184. if channelType == "" {
  185. middleware.ErrorResponse(c, http.StatusBadRequest, "type is required")
  186. return
  187. }
  188. channelTypeInt, err := strconv.Atoi(channelType)
  189. if err != nil {
  190. middleware.ErrorResponse(c, http.StatusBadRequest, "invalid type")
  191. return
  192. }
  193. middleware.SuccessResponse(c, gin.H{
  194. "models": config.GetDefaultChannelModels()[channelTypeInt],
  195. "mapping": config.GetDefaultChannelModelMapping()[channelTypeInt],
  196. })
  197. }
  198. // EnabledModels godoc
  199. //
  200. // @Summary Get enabled models
  201. // @Description Returns a list of enabled models
  202. // @Tags model
  203. // @Produce json
  204. // @Security ApiKeyAuth
  205. // @Success 200 {object} middleware.APIResponse{data=map[string][]model.ModelConfig}
  206. // @Router /api/models/enabled [get]
  207. func EnabledModels(c *gin.Context) {
  208. middleware.SuccessResponse(c, model.LoadModelCaches().EnabledModelConfigsBySet)
  209. }
  210. // EnabledModelsSet godoc
  211. //
  212. // @Summary Get enabled models by set
  213. // @Description Returns a list of enabled models by set
  214. // @Tags model
  215. // @Produce json
  216. // @Security ApiKeyAuth
  217. // @Param set path string true "Models set"
  218. // @Success 200 {object} middleware.APIResponse{data=[]model.ModelConfig}
  219. // @Router /api/models/enabled/{set} [get]
  220. func EnabledModelsSet(c *gin.Context) {
  221. set := c.Param("set")
  222. if set == "" {
  223. middleware.ErrorResponse(c, http.StatusBadRequest, "set is required")
  224. return
  225. }
  226. middleware.SuccessResponse(c, model.LoadModelCaches().EnabledModelConfigsBySet[set])
  227. }
  228. type EnabledModelChannel struct {
  229. ID int `json:"id"`
  230. Type model.ChannelType `json:"type"`
  231. Name string `json:"name"`
  232. }
  233. func newEnabledModelChannel(ch *model.Channel) EnabledModelChannel {
  234. return EnabledModelChannel{
  235. ID: ch.ID,
  236. Type: ch.Type,
  237. Name: ch.Name,
  238. }
  239. }
  240. // EnabledModelSets godoc
  241. //
  242. // @Summary Get enabled models and channels sets
  243. // @Description Returns a list of enabled models and channels sets
  244. // @Tags model
  245. // @Produce json
  246. // @Security ApiKeyAuth
  247. // @Success 200 {object} middleware.APIResponse{data=map[string]map[string][]EnabledModelChannel}
  248. // @Router /api/models/sets [get]
  249. func EnabledModelSets(c *gin.Context) {
  250. raw := model.LoadModelCaches().EnabledModel2ChannelsBySet
  251. result := make(map[string]map[string][]EnabledModelChannel)
  252. // First iterate through sets to get all models
  253. for _, modelChannels := range raw {
  254. for model := range modelChannels {
  255. if _, exists := result[model]; !exists {
  256. result[model] = make(map[string][]EnabledModelChannel)
  257. }
  258. }
  259. }
  260. // Then populate the channels for each model and set
  261. for set, modelChannels := range raw {
  262. for model, channels := range modelChannels {
  263. chs := make([]EnabledModelChannel, len(channels))
  264. for i, channel := range channels {
  265. chs[i] = newEnabledModelChannel(channel)
  266. }
  267. result[model][set] = chs
  268. }
  269. }
  270. middleware.SuccessResponse(c, result)
  271. }