model_meta.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. package controller
  2. import (
  3. "encoding/json"
  4. "sort"
  5. "strconv"
  6. "strings"
  7. "github.com/QuantumNous/new-api/common"
  8. "github.com/QuantumNous/new-api/constant"
  9. "github.com/QuantumNous/new-api/model"
  10. "github.com/gin-gonic/gin"
  11. )
  12. // parseModelID 解析并验证模型ID
  13. func parseModelID(c *gin.Context) (int, error) {
  14. idStr := c.Param("id")
  15. return strconv.Atoi(idStr)
  16. }
  17. // checkModelNameDuplicate 检查模型名称是否重复,返回错误消息
  18. func checkModelNameDuplicate(id int, name string) string {
  19. if name == "" {
  20. return "模型名称不能为空"
  21. }
  22. if dup, err := model.IsModelNameDuplicated(id, name); err != nil {
  23. return err.Error()
  24. } else if dup {
  25. return "模型名称已存在"
  26. }
  27. return ""
  28. }
  29. // GetAllModelsMeta 获取模型列表(分页)
  30. func GetAllModelsMeta(c *gin.Context) {
  31. pageInfo := common.GetPageQuery(c)
  32. status := c.Query("status")
  33. syncOfficial := c.Query("sync_official")
  34. modelsMeta, total, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), status, syncOfficial)
  35. if err != nil {
  36. common.ApiError(c, err)
  37. return
  38. }
  39. respondWithModels(c, modelsMeta, total, pageInfo, status, syncOfficial)
  40. }
  41. // SearchModelsMeta 搜索模型列表
  42. func SearchModelsMeta(c *gin.Context) {
  43. keyword := c.Query("keyword")
  44. vendor := c.Query("vendor")
  45. status := c.Query("status")
  46. syncOfficial := c.Query("sync_official")
  47. pageInfo := common.GetPageQuery(c)
  48. modelsMeta, total, err := model.SearchModels(keyword, vendor, status, syncOfficial, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
  49. if err != nil {
  50. common.ApiError(c, err)
  51. return
  52. }
  53. respondWithModels(c, modelsMeta, total, pageInfo, status, syncOfficial)
  54. }
  55. // respondWithModels 统一处理模型列表响应
  56. func respondWithModels(c *gin.Context, models []*model.Model, total int64, pageInfo *common.PageInfo, status, syncOfficial string) {
  57. enrichModels(models)
  58. vendorCounts, _ := model.GetVendorModelCounts(status, syncOfficial)
  59. pageInfo.SetTotal(int(total))
  60. pageInfo.SetItems(models)
  61. common.ApiSuccess(c, gin.H{
  62. "items": models,
  63. "total": total,
  64. "page": pageInfo.GetPage(),
  65. "page_size": pageInfo.GetPageSize(),
  66. "vendor_counts": vendorCounts,
  67. })
  68. }
  69. // GetModelMeta 根据 ID 获取单条模型信息
  70. func GetModelMeta(c *gin.Context) {
  71. id, err := parseModelID(c)
  72. if err != nil {
  73. common.ApiError(c, err)
  74. return
  75. }
  76. var m model.Model
  77. if err := model.DB.First(&m, id).Error; err != nil {
  78. common.ApiError(c, err)
  79. return
  80. }
  81. enrichModels([]*model.Model{&m})
  82. common.ApiSuccess(c, &m)
  83. }
  84. // CreateModelMeta 新建模型
  85. func CreateModelMeta(c *gin.Context) {
  86. var m model.Model
  87. if err := c.ShouldBindJSON(&m); err != nil {
  88. common.ApiError(c, err)
  89. return
  90. }
  91. // 验证模型名称
  92. if errMsg := checkModelNameDuplicate(0, m.ModelName); errMsg != "" {
  93. common.ApiErrorMsg(c, errMsg)
  94. return
  95. }
  96. if err := m.Insert(); err != nil {
  97. common.ApiError(c, err)
  98. return
  99. }
  100. model.RefreshPricing()
  101. common.ApiSuccess(c, &m)
  102. }
  103. // UpdateModelMeta 更新模型
  104. func UpdateModelMeta(c *gin.Context) {
  105. statusOnly := c.Query("status_only") == "true"
  106. var m model.Model
  107. if err := c.ShouldBindJSON(&m); err != nil {
  108. common.ApiError(c, err)
  109. return
  110. }
  111. if m.Id == 0 {
  112. common.ApiErrorMsg(c, "缺少模型 ID")
  113. return
  114. }
  115. if statusOnly {
  116. // 只更新状态,防止误清空其他字段
  117. if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
  118. common.ApiError(c, err)
  119. return
  120. }
  121. } else {
  122. // 验证模型名称
  123. if errMsg := checkModelNameDuplicate(m.Id, m.ModelName); errMsg != "" {
  124. common.ApiErrorMsg(c, errMsg)
  125. return
  126. }
  127. if err := m.Update(); err != nil {
  128. common.ApiError(c, err)
  129. return
  130. }
  131. }
  132. model.RefreshPricing()
  133. common.ApiSuccess(c, &m)
  134. }
  135. // DeleteModelMeta 删除模型
  136. func DeleteModelMeta(c *gin.Context) {
  137. id, err := parseModelID(c)
  138. if err != nil {
  139. common.ApiError(c, err)
  140. return
  141. }
  142. if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
  143. common.ApiError(c, err)
  144. return
  145. }
  146. model.RefreshPricing()
  147. common.ApiSuccess(c, nil)
  148. }
  149. // enrichModels 批量填充附加信息:端点、渠道、分组、计费类型,避免 N+1 查询
  150. func enrichModels(models []*model.Model) {
  151. if len(models) == 0 {
  152. return
  153. }
  154. exactNames, exactIdx, ruleIndices := classifyModels(models)
  155. // 处理精确匹配模型
  156. enrichExactModels(models, exactNames, exactIdx)
  157. // 处理规则匹配模型
  158. if len(ruleIndices) > 0 {
  159. enrichRuleModels(models, ruleIndices)
  160. }
  161. }
  162. // classifyModels 将模型分类为精确匹配和规则匹配
  163. func classifyModels(models []*model.Model) ([]string, map[string][]int, []int) {
  164. exactNames := make([]string, 0)
  165. exactIdx := make(map[string][]int)
  166. ruleIndices := make([]int, 0)
  167. for i, m := range models {
  168. if m == nil {
  169. continue
  170. }
  171. if m.NameRule == model.NameRuleExact {
  172. exactNames = append(exactNames, m.ModelName)
  173. exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i)
  174. } else {
  175. ruleIndices = append(ruleIndices, i)
  176. }
  177. }
  178. return exactNames, exactIdx, ruleIndices
  179. }
  180. // enrichExactModels 填充精确匹配模型的信息
  181. func enrichExactModels(models []*model.Model, exactNames []string, exactIdx map[string][]int) {
  182. if len(exactNames) == 0 {
  183. return
  184. }
  185. channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames)
  186. for name, indices := range exactIdx {
  187. chs := channelsByModel[name]
  188. for _, idx := range indices {
  189. mm := models[idx]
  190. if mm.Endpoints == "" {
  191. eps := model.GetModelSupportEndpointTypes(mm.ModelName)
  192. if b, err := json.Marshal(eps); err == nil {
  193. mm.Endpoints = string(b)
  194. }
  195. }
  196. mm.BoundChannels = chs
  197. mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName)
  198. mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName)
  199. }
  200. }
  201. }
  202. // enrichRuleModels 填充规则匹配模型的信息
  203. func enrichRuleModels(models []*model.Model, ruleIndices []int) {
  204. pricings := model.GetPricing()
  205. // 收集匹配信息
  206. matchedNamesByIdx, endpointSetByIdx, groupSetByIdx, quotaSetByIdx := collectRuleMatches(models, ruleIndices, pricings)
  207. // 批量查询渠道信息
  208. allMatched := extractAllMatchedNames(matchedNamesByIdx)
  209. matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched)
  210. // 回填模型信息
  211. fillRuleModelData(models, ruleIndices, matchedNamesByIdx, endpointSetByIdx, groupSetByIdx, quotaSetByIdx, matchedChannelsByModel)
  212. }
  213. // collectRuleMatches 收集规则模型的匹配信息
  214. func collectRuleMatches(models []*model.Model, ruleIndices []int, pricings []model.Pricing) (
  215. map[int][]string,
  216. map[int]map[constant.EndpointType]struct{},
  217. map[int]map[string]struct{},
  218. map[int]map[int]struct{},
  219. ) {
  220. matchedNamesByIdx := make(map[int][]string)
  221. endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{})
  222. groupSetByIdx := make(map[int]map[string]struct{})
  223. quotaSetByIdx := make(map[int]map[int]struct{})
  224. for _, p := range pricings {
  225. for _, idx := range ruleIndices {
  226. mm := models[idx]
  227. if !matchNameRule(p.ModelName, mm.ModelName, mm.NameRule) {
  228. continue
  229. }
  230. matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName)
  231. if endpointSetByIdx[idx] == nil {
  232. endpointSetByIdx[idx] = make(map[constant.EndpointType]struct{})
  233. }
  234. for _, et := range p.SupportedEndpointTypes {
  235. endpointSetByIdx[idx][et] = struct{}{}
  236. }
  237. if groupSetByIdx[idx] == nil {
  238. groupSetByIdx[idx] = make(map[string]struct{})
  239. }
  240. for _, g := range p.EnableGroup {
  241. groupSetByIdx[idx][g] = struct{}{}
  242. }
  243. if quotaSetByIdx[idx] == nil {
  244. quotaSetByIdx[idx] = make(map[int]struct{})
  245. }
  246. quotaSetByIdx[idx][p.QuotaType] = struct{}{}
  247. }
  248. }
  249. return matchedNamesByIdx, endpointSetByIdx, groupSetByIdx, quotaSetByIdx
  250. }
  251. // matchNameRule 根据规则匹配模型名称
  252. func matchNameRule(pricingModel, modelName string, nameRule int) bool {
  253. switch nameRule {
  254. case model.NameRulePrefix:
  255. return strings.HasPrefix(pricingModel, modelName)
  256. case model.NameRuleSuffix:
  257. return strings.HasSuffix(pricingModel, modelName)
  258. case model.NameRuleContains:
  259. return strings.Contains(pricingModel, modelName)
  260. default:
  261. return false
  262. }
  263. }
  // extractAllMatchedNames returns the distinct model names found across all
  // entries of matchedNamesByIdx. The result is sorted so that repeated calls
  // with the same input yield the same slice (map iteration order alone is
  // random). It is never nil; an empty input yields an empty slice.
  func extractAllMatchedNames(matchedNamesByIdx map[int][]string) []string {
  	seen := make(map[string]struct{})
  	for _, names := range matchedNamesByIdx {
  		for _, name := range names {
  			seen[name] = struct{}{}
  		}
  	}

  	unique := make([]string, 0, len(seen))
  	for name := range seen {
  		unique = append(unique, name)
  	}
  	sort.Strings(unique)
  	return unique
  }
  278. // fillRuleModelData 回填规则模型的数据
  279. func fillRuleModelData(
  280. models []*model.Model,
  281. ruleIndices []int,
  282. matchedNamesByIdx map[int][]string,
  283. endpointSetByIdx map[int]map[constant.EndpointType]struct{},
  284. groupSetByIdx map[int]map[string]struct{},
  285. quotaSetByIdx map[int]map[int]struct{},
  286. matchedChannelsByModel map[string][]model.BoundChannel,
  287. ) {
  288. for _, idx := range ruleIndices {
  289. mm := models[idx]
  290. // 填充端点
  291. if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" {
  292. eps := make([]constant.EndpointType, 0, len(es))
  293. for et := range es {
  294. eps = append(eps, et)
  295. }
  296. if b, err := json.Marshal(eps); err == nil {
  297. mm.Endpoints = string(b)
  298. }
  299. }
  300. // 填充分组
  301. if gs, ok := groupSetByIdx[idx]; ok {
  302. groups := make([]string, 0, len(gs))
  303. for g := range gs {
  304. groups = append(groups, g)
  305. }
  306. mm.EnableGroups = groups
  307. }
  308. // 填充配额类型
  309. if qs, ok := quotaSetByIdx[idx]; ok {
  310. arr := make([]int, 0, len(qs))
  311. for k := range qs {
  312. arr = append(arr, k)
  313. }
  314. sort.Ints(arr)
  315. mm.QuotaTypes = arr
  316. }
  317. // 填充渠道
  318. names := matchedNamesByIdx[idx]
  319. channelSet := make(map[string]model.BoundChannel)
  320. for _, n := range names {
  321. for _, ch := range matchedChannelsByModel[n] {
  322. key := ch.Name + "_" + strconv.Itoa(ch.Type)
  323. channelSet[key] = ch
  324. }
  325. }
  326. if len(channelSet) > 0 {
  327. chs := make([]model.BoundChannel, 0, len(channelSet))
  328. for _, ch := range channelSet {
  329. chs = append(chs, ch)
  330. }
  331. mm.BoundChannels = chs
  332. }
  333. // 填充匹配信息
  334. mm.MatchedModels = names
  335. mm.MatchedCount = len(names)
  336. }
  337. }