model.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. package controller
  2. import (
  3. "fmt"
  4. "github.com/gin-gonic/gin"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/constant"
  8. "one-api/dto"
  9. "one-api/model"
  10. "one-api/relay"
  11. "one-api/relay/channel/ai360"
  12. "one-api/relay/channel/lingyiwanwu"
  13. "one-api/relay/channel/moonshot"
  14. relaycommon "one-api/relay/common"
  15. relayconstant "one-api/relay/constant"
  16. )
  17. // https://platform.openai.com/docs/api-reference/models/list
  18. var openAIModels []dto.OpenAIModels
  19. var openAIModelsMap map[string]dto.OpenAIModels
  20. var channelId2Models map[int][]string
  21. func getPermission() []dto.OpenAIModelPermission {
  22. var permission []dto.OpenAIModelPermission
  23. permission = append(permission, dto.OpenAIModelPermission{
  24. Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
  25. Object: "model_permission",
  26. Created: 1626777600,
  27. AllowCreateEngine: true,
  28. AllowSampling: true,
  29. AllowLogprobs: true,
  30. AllowSearchIndices: false,
  31. AllowView: true,
  32. AllowFineTuning: false,
  33. Organization: "*",
  34. Group: nil,
  35. IsBlocking: false,
  36. })
  37. return permission
  38. }
  39. func init() {
  40. // https://platform.openai.com/docs/models/model-endpoint-compatibility
  41. permission := getPermission()
  42. for i := 0; i < relayconstant.APITypeDummy; i++ {
  43. if i == relayconstant.APITypeAIProxyLibrary {
  44. continue
  45. }
  46. adaptor := relay.GetAdaptor(i)
  47. channelName := adaptor.GetChannelName()
  48. modelNames := adaptor.GetModelList()
  49. for _, modelName := range modelNames {
  50. openAIModels = append(openAIModels, dto.OpenAIModels{
  51. Id: modelName,
  52. Object: "model",
  53. Created: 1626777600,
  54. OwnedBy: channelName,
  55. Permission: permission,
  56. Root: modelName,
  57. Parent: nil,
  58. })
  59. }
  60. }
  61. for _, modelName := range ai360.ModelList {
  62. openAIModels = append(openAIModels, dto.OpenAIModels{
  63. Id: modelName,
  64. Object: "model",
  65. Created: 1626777600,
  66. OwnedBy: ai360.ChannelName,
  67. Permission: permission,
  68. Root: modelName,
  69. Parent: nil,
  70. })
  71. }
  72. for _, modelName := range moonshot.ModelList {
  73. openAIModels = append(openAIModels, dto.OpenAIModels{
  74. Id: modelName,
  75. Object: "model",
  76. Created: 1626777600,
  77. OwnedBy: "moonshot",
  78. Permission: permission,
  79. Root: modelName,
  80. Parent: nil,
  81. })
  82. }
  83. for _, modelName := range lingyiwanwu.ModelList {
  84. openAIModels = append(openAIModels, dto.OpenAIModels{
  85. Id: modelName,
  86. Object: "model",
  87. Created: 1626777600,
  88. OwnedBy: "lingyiwanwu",
  89. Permission: permission,
  90. Root: modelName,
  91. Parent: nil,
  92. })
  93. }
  94. for modelName, _ := range constant.MidjourneyModel2Action {
  95. openAIModels = append(openAIModels, dto.OpenAIModels{
  96. Id: modelName,
  97. Object: "model",
  98. Created: 1626777600,
  99. OwnedBy: "midjourney",
  100. Permission: permission,
  101. Root: modelName,
  102. Parent: nil,
  103. })
  104. }
  105. openAIModelsMap = make(map[string]dto.OpenAIModels)
  106. for _, aiModel := range openAIModels {
  107. openAIModelsMap[aiModel.Id] = aiModel
  108. }
  109. channelId2Models = make(map[int][]string)
  110. for i := 1; i <= common.ChannelTypeDummy; i++ {
  111. apiType, success := relayconstant.ChannelType2APIType(i)
  112. if !success || apiType == relayconstant.APITypeAIProxyLibrary {
  113. continue
  114. }
  115. meta := &relaycommon.RelayInfo{ChannelType: i}
  116. adaptor := relay.GetAdaptor(apiType)
  117. adaptor.Init(meta, dto.GeneralOpenAIRequest{})
  118. channelId2Models[i] = adaptor.GetModelList()
  119. }
  120. }
  121. func ListModels(c *gin.Context) {
  122. userId := c.GetInt("id")
  123. user, err := model.GetUserById(userId, true)
  124. if err != nil {
  125. c.JSON(http.StatusOK, gin.H{
  126. "success": false,
  127. "message": err.Error(),
  128. })
  129. return
  130. }
  131. models := model.GetGroupModels(user.Group)
  132. userOpenAiModels := make([]dto.OpenAIModels, 0)
  133. permission := getPermission()
  134. for _, s := range models {
  135. if _, ok := openAIModelsMap[s]; ok {
  136. userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
  137. } else {
  138. userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
  139. Id: s,
  140. Object: "model",
  141. Created: 1626777600,
  142. OwnedBy: "custom",
  143. Permission: permission,
  144. Root: s,
  145. Parent: nil,
  146. })
  147. }
  148. }
  149. c.JSON(200, gin.H{
  150. "success": true,
  151. "data": userOpenAiModels,
  152. })
  153. }
  154. func ChannelListModels(c *gin.Context) {
  155. c.JSON(200, gin.H{
  156. "success": true,
  157. "data": openAIModels,
  158. })
  159. }
  160. func DashboardListModels(c *gin.Context) {
  161. c.JSON(200, gin.H{
  162. "success": true,
  163. "data": channelId2Models,
  164. })
  165. }
  166. func RetrieveModel(c *gin.Context) {
  167. modelId := c.Param("model")
  168. if aiModel, ok := openAIModelsMap[modelId]; ok {
  169. c.JSON(200, aiModel)
  170. } else {
  171. openAIError := dto.OpenAIError{
  172. Message: fmt.Sprintf("The model '%s' does not exist", modelId),
  173. Type: "invalid_request_error",
  174. Param: "model",
  175. Code: "model_not_found",
  176. }
  177. c.JSON(200, gin.H{
  178. "error": openAIError,
  179. })
  180. }
  181. }
  182. func GetPricing(c *gin.Context) {
  183. userId := c.GetInt("id")
  184. group, err := model.CacheGetUserGroup(userId)
  185. groupRatio := common.GetGroupRatio("default")
  186. if err != nil {
  187. groupRatio = common.GetGroupRatio(group)
  188. }
  189. pricing := model.GetPricing(group)
  190. c.JSON(200, gin.H{
  191. "success": true,
  192. "data": pricing,
  193. "group_ratio": groupRatio,
  194. })
  195. }