model_meta.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. package controller
  2. import (
  3. "encoding/json"
  4. "sort"
  5. "strconv"
  6. "strings"
  7. "one-api/common"
  8. "one-api/constant"
  9. "one-api/model"
  10. "github.com/gin-gonic/gin"
  11. )
  12. // GetAllModelsMeta 获取模型列表(分页)
  13. func GetAllModelsMeta(c *gin.Context) {
  14. pageInfo := common.GetPageQuery(c)
  15. modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
  16. if err != nil {
  17. common.ApiError(c, err)
  18. return
  19. }
  20. // 批量填充附加字段,提升列表接口性能
  21. enrichModels(modelsMeta)
  22. var total int64
  23. model.DB.Model(&model.Model{}).Count(&total)
  24. // 统计供应商计数(全部数据,不受分页影响)
  25. vendorCounts, _ := model.GetVendorModelCounts()
  26. pageInfo.SetTotal(int(total))
  27. pageInfo.SetItems(modelsMeta)
  28. common.ApiSuccess(c, gin.H{
  29. "items": modelsMeta,
  30. "total": total,
  31. "page": pageInfo.GetPage(),
  32. "page_size": pageInfo.GetPageSize(),
  33. "vendor_counts": vendorCounts,
  34. })
  35. }
  36. // SearchModelsMeta 搜索模型列表
  37. func SearchModelsMeta(c *gin.Context) {
  38. keyword := c.Query("keyword")
  39. vendor := c.Query("vendor")
  40. pageInfo := common.GetPageQuery(c)
  41. modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
  42. if err != nil {
  43. common.ApiError(c, err)
  44. return
  45. }
  46. // 批量填充附加字段,提升列表接口性能
  47. enrichModels(modelsMeta)
  48. pageInfo.SetTotal(int(total))
  49. pageInfo.SetItems(modelsMeta)
  50. common.ApiSuccess(c, pageInfo)
  51. }
  52. // GetModelMeta 根据 ID 获取单条模型信息
  53. func GetModelMeta(c *gin.Context) {
  54. idStr := c.Param("id")
  55. id, err := strconv.Atoi(idStr)
  56. if err != nil {
  57. common.ApiError(c, err)
  58. return
  59. }
  60. var m model.Model
  61. if err := model.DB.First(&m, id).Error; err != nil {
  62. common.ApiError(c, err)
  63. return
  64. }
  65. enrichModels([]*model.Model{&m})
  66. common.ApiSuccess(c, &m)
  67. }
  68. // CreateModelMeta 新建模型
  69. func CreateModelMeta(c *gin.Context) {
  70. var m model.Model
  71. if err := c.ShouldBindJSON(&m); err != nil {
  72. common.ApiError(c, err)
  73. return
  74. }
  75. if m.ModelName == "" {
  76. common.ApiErrorMsg(c, "模型名称不能为空")
  77. return
  78. }
  79. // 名称冲突检查
  80. if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil {
  81. common.ApiError(c, err)
  82. return
  83. } else if dup {
  84. common.ApiErrorMsg(c, "模型名称已存在")
  85. return
  86. }
  87. if err := m.Insert(); err != nil {
  88. common.ApiError(c, err)
  89. return
  90. }
  91. model.RefreshPricing()
  92. common.ApiSuccess(c, &m)
  93. }
  94. // UpdateModelMeta 更新模型
  95. func UpdateModelMeta(c *gin.Context) {
  96. statusOnly := c.Query("status_only") == "true"
  97. var m model.Model
  98. if err := c.ShouldBindJSON(&m); err != nil {
  99. common.ApiError(c, err)
  100. return
  101. }
  102. if m.Id == 0 {
  103. common.ApiErrorMsg(c, "缺少模型 ID")
  104. return
  105. }
  106. if statusOnly {
  107. // 只更新状态,防止误清空其他字段
  108. if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
  109. common.ApiError(c, err)
  110. return
  111. }
  112. } else {
  113. // 名称冲突检查
  114. if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil {
  115. common.ApiError(c, err)
  116. return
  117. } else if dup {
  118. common.ApiErrorMsg(c, "模型名称已存在")
  119. return
  120. }
  121. if err := m.Update(); err != nil {
  122. common.ApiError(c, err)
  123. return
  124. }
  125. }
  126. model.RefreshPricing()
  127. common.ApiSuccess(c, &m)
  128. }
  129. // DeleteModelMeta 删除模型
  130. func DeleteModelMeta(c *gin.Context) {
  131. idStr := c.Param("id")
  132. id, err := strconv.Atoi(idStr)
  133. if err != nil {
  134. common.ApiError(c, err)
  135. return
  136. }
  137. if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
  138. common.ApiError(c, err)
  139. return
  140. }
  141. model.RefreshPricing()
  142. common.ApiSuccess(c, nil)
  143. }
  144. // enrichModels 批量填充附加信息:端点、渠道、分组、计费类型,避免 N+1 查询
  145. func enrichModels(models []*model.Model) {
  146. if len(models) == 0 {
  147. return
  148. }
  149. // 1) 拆分精确与规则匹配
  150. exactNames := make([]string, 0)
  151. exactIdx := make(map[string][]int) // modelName -> indices in models
  152. ruleIndices := make([]int, 0)
  153. for i, m := range models {
  154. if m == nil {
  155. continue
  156. }
  157. if m.NameRule == model.NameRuleExact {
  158. exactNames = append(exactNames, m.ModelName)
  159. exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i)
  160. } else {
  161. ruleIndices = append(ruleIndices, i)
  162. }
  163. }
  164. // 2) 批量查询精确模型的绑定渠道
  165. channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames)
  166. // 3) 精确模型:端点从缓存、渠道批量映射、分组/计费类型从缓存
  167. for name, indices := range exactIdx {
  168. chs := channelsByModel[name]
  169. for _, idx := range indices {
  170. mm := models[idx]
  171. if mm.Endpoints == "" {
  172. eps := model.GetModelSupportEndpointTypes(mm.ModelName)
  173. if b, err := json.Marshal(eps); err == nil {
  174. mm.Endpoints = string(b)
  175. }
  176. }
  177. mm.BoundChannels = chs
  178. mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName)
  179. mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName)
  180. }
  181. }
  182. if len(ruleIndices) == 0 {
  183. return
  184. }
  185. // 4) 一次性读取定价缓存,内存匹配所有规则模型
  186. pricings := model.GetPricing()
  187. // 为全部规则模型收集匹配名集合、端点并集、分组并集、配额集合
  188. matchedNamesByIdx := make(map[int][]string)
  189. endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{})
  190. groupSetByIdx := make(map[int]map[string]struct{})
  191. quotaSetByIdx := make(map[int]map[int]struct{})
  192. for _, p := range pricings {
  193. for _, idx := range ruleIndices {
  194. mm := models[idx]
  195. var matched bool
  196. switch mm.NameRule {
  197. case model.NameRulePrefix:
  198. matched = strings.HasPrefix(p.ModelName, mm.ModelName)
  199. case model.NameRuleSuffix:
  200. matched = strings.HasSuffix(p.ModelName, mm.ModelName)
  201. case model.NameRuleContains:
  202. matched = strings.Contains(p.ModelName, mm.ModelName)
  203. }
  204. if !matched {
  205. continue
  206. }
  207. matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName)
  208. es := endpointSetByIdx[idx]
  209. if es == nil {
  210. es = make(map[constant.EndpointType]struct{})
  211. endpointSetByIdx[idx] = es
  212. }
  213. for _, et := range p.SupportedEndpointTypes {
  214. es[et] = struct{}{}
  215. }
  216. gs := groupSetByIdx[idx]
  217. if gs == nil {
  218. gs = make(map[string]struct{})
  219. groupSetByIdx[idx] = gs
  220. }
  221. for _, g := range p.EnableGroup {
  222. gs[g] = struct{}{}
  223. }
  224. qs := quotaSetByIdx[idx]
  225. if qs == nil {
  226. qs = make(map[int]struct{})
  227. quotaSetByIdx[idx] = qs
  228. }
  229. qs[p.QuotaType] = struct{}{}
  230. }
  231. }
  232. // 5) 汇总所有匹配到的模型名称,批量查询一次渠道
  233. allMatchedSet := make(map[string]struct{})
  234. for _, names := range matchedNamesByIdx {
  235. for _, n := range names {
  236. allMatchedSet[n] = struct{}{}
  237. }
  238. }
  239. allMatched := make([]string, 0, len(allMatchedSet))
  240. for n := range allMatchedSet {
  241. allMatched = append(allMatched, n)
  242. }
  243. matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched)
  244. // 6) 回填每个规则模型的并集信息
  245. for _, idx := range ruleIndices {
  246. mm := models[idx]
  247. // 端点并集 -> 序列化
  248. if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" {
  249. eps := make([]constant.EndpointType, 0, len(es))
  250. for et := range es {
  251. eps = append(eps, et)
  252. }
  253. if b, err := json.Marshal(eps); err == nil {
  254. mm.Endpoints = string(b)
  255. }
  256. }
  257. // 分组并集
  258. if gs, ok := groupSetByIdx[idx]; ok {
  259. groups := make([]string, 0, len(gs))
  260. for g := range gs {
  261. groups = append(groups, g)
  262. }
  263. mm.EnableGroups = groups
  264. }
  265. // 配额类型集合(保持去重并排序)
  266. if qs, ok := quotaSetByIdx[idx]; ok {
  267. arr := make([]int, 0, len(qs))
  268. for k := range qs {
  269. arr = append(arr, k)
  270. }
  271. sort.Ints(arr)
  272. mm.QuotaTypes = arr
  273. }
  274. // 渠道并集
  275. names := matchedNamesByIdx[idx]
  276. channelSet := make(map[string]model.BoundChannel)
  277. for _, n := range names {
  278. for _, ch := range matchedChannelsByModel[n] {
  279. key := ch.Name + "_" + strconv.Itoa(ch.Type)
  280. channelSet[key] = ch
  281. }
  282. }
  283. if len(channelSet) > 0 {
  284. chs := make([]model.BoundChannel, 0, len(channelSet))
  285. for _, ch := range channelSet {
  286. chs = append(chs, ch)
  287. }
  288. mm.BoundChannels = chs
  289. }
  290. // 匹配信息
  291. mm.MatchedModels = names
  292. mm.MatchedCount = len(names)
  293. }
  294. }