model.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. package controller
  import (
  	"fmt"
  	"net/http"
  	"sort"

  	"github.com/gin-gonic/gin"
  	"one-api/common"
  	"one-api/constant"
  	"one-api/dto"
  	"one-api/model"
  	"one-api/relay"
  	"one-api/relay/channel/ai360"
  	"one-api/relay/channel/lingyiwanwu"
  	"one-api/relay/channel/moonshot"
  	relaycommon "one-api/relay/common"
  	relayconstant "one-api/relay/constant"
  )
  // OpenAIModelPermission is one element of the "permission" array attached to
  // every model entry returned by the model-listing handlers. Its fields and
  // JSON tags are intended to mirror OpenAI's model-permission object; all
  // values in this file are produced by getPermission.
  // https://platform.openai.com/docs/api-reference/models/list
  type OpenAIModelPermission struct {
  	Id                 string  `json:"id"`
  	Object             string  `json:"object"`
  	Created            int     `json:"created"` // presumably a Unix timestamp in seconds — confirm
  	AllowCreateEngine  bool    `json:"allow_create_engine"`
  	AllowSampling      bool    `json:"allow_sampling"`
  	AllowLogprobs      bool    `json:"allow_logprobs"`
  	AllowSearchIndices bool    `json:"allow_search_indices"`
  	AllowView          bool    `json:"allow_view"`
  	AllowFineTuning    bool    `json:"allow_fine_tuning"`
  	Organization       string  `json:"organization"`
  	Group              *string `json:"group"` // pointer so a nil value encodes as JSON null
  	IsBlocking         bool    `json:"is_blocking"`
  }
  // OpenAIModels describes a single model (despite the plural name) in the
  // responses of the model-listing handlers. Field order determines the key
  // order of the encoded JSON object.
  type OpenAIModels struct {
  	Id         string                  `json:"id"` // model name; also the key of openAIModelsMap
  	Object     string                  `json:"object"` // always "model" for entries built in this file
  	Created    int                     `json:"created"`
  	OwnedBy    string                  `json:"owned_by"` // channel/provider name the model is attributed to
  	Permission []OpenAIModelPermission `json:"permission"`
  	Root       string                  `json:"root"` // set to the same value as Id in this file
  	Parent     *string                 `json:"parent"` // always nil here, so it encodes as null
  }
  41. var openAIModels []OpenAIModels
  42. var openAIModelsMap map[string]OpenAIModels
  43. var channelId2Models map[int][]string
  44. func getPermission() []OpenAIModelPermission {
  45. var permission []OpenAIModelPermission
  46. permission = append(permission, OpenAIModelPermission{
  47. Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
  48. Object: "model_permission",
  49. Created: 1626777600,
  50. AllowCreateEngine: true,
  51. AllowSampling: true,
  52. AllowLogprobs: true,
  53. AllowSearchIndices: false,
  54. AllowView: true,
  55. AllowFineTuning: false,
  56. Organization: "*",
  57. Group: nil,
  58. IsBlocking: false,
  59. })
  60. return permission
  61. }
  62. func init() {
  63. // https://platform.openai.com/docs/models/model-endpoint-compatibility
  64. permission := getPermission()
  65. for i := 0; i < relayconstant.APITypeDummy; i++ {
  66. if i == relayconstant.APITypeAIProxyLibrary {
  67. continue
  68. }
  69. adaptor := relay.GetAdaptor(i)
  70. channelName := adaptor.GetChannelName()
  71. modelNames := adaptor.GetModelList()
  72. for _, modelName := range modelNames {
  73. openAIModels = append(openAIModels, OpenAIModels{
  74. Id: modelName,
  75. Object: "model",
  76. Created: 1626777600,
  77. OwnedBy: channelName,
  78. Permission: permission,
  79. Root: modelName,
  80. Parent: nil,
  81. })
  82. }
  83. }
  84. for _, modelName := range ai360.ModelList {
  85. openAIModels = append(openAIModels, OpenAIModels{
  86. Id: modelName,
  87. Object: "model",
  88. Created: 1626777600,
  89. OwnedBy: ai360.ChannelName,
  90. Permission: permission,
  91. Root: modelName,
  92. Parent: nil,
  93. })
  94. }
  95. for _, modelName := range moonshot.ModelList {
  96. openAIModels = append(openAIModels, OpenAIModels{
  97. Id: modelName,
  98. Object: "model",
  99. Created: 1626777600,
  100. OwnedBy: "moonshot",
  101. Permission: permission,
  102. Root: modelName,
  103. Parent: nil,
  104. })
  105. }
  106. for _, modelName := range lingyiwanwu.ModelList {
  107. openAIModels = append(openAIModels, OpenAIModels{
  108. Id: modelName,
  109. Object: "model",
  110. Created: 1626777600,
  111. OwnedBy: "lingyiwanwu",
  112. Permission: permission,
  113. Root: modelName,
  114. Parent: nil,
  115. })
  116. }
  117. for modelName, _ := range constant.MidjourneyModel2Action {
  118. openAIModels = append(openAIModels, OpenAIModels{
  119. Id: modelName,
  120. Object: "model",
  121. Created: 1626777600,
  122. OwnedBy: "midjourney",
  123. Permission: permission,
  124. Root: modelName,
  125. Parent: nil,
  126. })
  127. }
  128. openAIModelsMap = make(map[string]OpenAIModels)
  129. for _, model := range openAIModels {
  130. openAIModelsMap[model.Id] = model
  131. }
  132. channelId2Models = make(map[int][]string)
  133. for i := 1; i <= common.ChannelTypeDummy; i++ {
  134. apiType := relayconstant.ChannelType2APIType(i)
  135. if apiType == -1 || apiType == relayconstant.APITypeAIProxyLibrary {
  136. continue
  137. }
  138. meta := &relaycommon.RelayInfo{ChannelType: i}
  139. adaptor := relay.GetAdaptor(apiType)
  140. adaptor.Init(meta, dto.GeneralOpenAIRequest{})
  141. channelId2Models[i] = adaptor.GetModelList()
  142. }
  143. }
  144. func ListModels(c *gin.Context) {
  145. userId := c.GetInt("id")
  146. user, err := model.GetUserById(userId, true)
  147. if err != nil {
  148. c.JSON(http.StatusOK, gin.H{
  149. "success": false,
  150. "message": err.Error(),
  151. })
  152. return
  153. }
  154. models := model.GetGroupModels(user.Group)
  155. userOpenAiModels := make([]OpenAIModels, 0)
  156. permission := getPermission()
  157. for _, s := range models {
  158. if _, ok := openAIModelsMap[s]; ok {
  159. userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
  160. } else {
  161. userOpenAiModels = append(userOpenAiModels, OpenAIModels{
  162. Id: s,
  163. Object: "model",
  164. Created: 1626777600,
  165. OwnedBy: "openai",
  166. Permission: permission,
  167. Root: s,
  168. Parent: nil,
  169. })
  170. }
  171. }
  172. c.JSON(200, gin.H{
  173. "success": true,
  174. "data": userOpenAiModels,
  175. })
  176. }
  177. func ChannelListModels(c *gin.Context) {
  178. c.JSON(200, gin.H{
  179. "success": true,
  180. "data": openAIModels,
  181. })
  182. }
  183. func DashboardListModels(c *gin.Context) {
  184. c.JSON(200, gin.H{
  185. "success": true,
  186. "data": channelId2Models,
  187. })
  188. }
  189. func RetrieveModel(c *gin.Context) {
  190. modelId := c.Param("model")
  191. if model, ok := openAIModelsMap[modelId]; ok {
  192. c.JSON(200, model)
  193. } else {
  194. openAIError := dto.OpenAIError{
  195. Message: fmt.Sprintf("The model '%s' does not exist", modelId),
  196. Type: "invalid_request_error",
  197. Param: "model",
  198. Code: "model_not_found",
  199. }
  200. c.JSON(200, gin.H{
  201. "error": openAIError,
  202. })
  203. }
  204. }