model_meta.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. package model
  2. import (
  3. "strconv"
  4. "github.com/QuantumNous/new-api/common"
  5. "gorm.io/gorm"
  6. )
  7. const (
  8. NameRuleExact = iota
  9. NameRulePrefix
  10. NameRuleContains
  11. NameRuleSuffix
  12. )
// BoundChannel describes one channel bound to a model, as exposed through
// Model.BoundChannels. GetBoundChannelsByModelsMap fills these from the
// channels table's name and type columns.
type BoundChannel struct {
	// Name is the channel's name (channels.name).
	Name string `json:"name"`
	// Type is the channel's integer type code (channels.type).
	Type int `json:"type"`
}
// Model is the persisted metadata record for one model entry (the models
// table).
//
// Fields tagged `gorm:"-"` are not stored in the table. They are only carried
// in JSON responses and are filled in by code outside this type (TODO confirm
// the call sites).
type Model struct {
	Id int `json:"id"`
	// ModelName together with DeletedAt forms the unique index
	// uk_model_name_delete_at. IsModelNameDuplicated performs an explicit
	// duplicate check as well.
	ModelName   string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name_delete_at,priority:1"`
	Description string `json:"description,omitempty" gorm:"type:text"`
	Icon        string `json:"icon,omitempty" gorm:"type:varchar(128)"`
	Tags        string `json:"tags,omitempty" gorm:"type:varchar(255)"`
	VendorID    int    `json:"vendor_id,omitempty" gorm:"index"`
	Endpoints   string `json:"endpoints,omitempty" gorm:"type:text"`
	// Status and SyncOfficial have a schema default of 1. Insert rewrites
	// them right after Create so that an explicit 0 is still persisted.
	Status       int `json:"status" gorm:"default:1"`
	SyncOfficial int `json:"sync_official" gorm:"default:1"`
	// NOTE(review): `gorm:"bigint"` has no key (compare `type:bigint`), so
	// GORM probably ignores it and derives the column type from int64.
	// Confirm before changing it, because altering the tag can affect
	// AutoMigrate.
	CreatedTime int64 `json:"created_time" gorm:"bigint"`
	UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
	// DeletedAt enables GORM soft delete. It is also the second column of
	// uk_model_name_delete_at, which lets a soft-deleted name coexist with a
	// live one.
	DeletedAt     gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name_delete_at,priority:2"`
	BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
	EnableGroups  []string       `json:"enable_groups,omitempty" gorm:"-"`
	QuotaTypes    []int          `json:"quota_types,omitempty" gorm:"-"`
	// NameRule is persisted. It presumably holds one of the NameRule*
	// constants, selecting how ModelName is matched (TODO confirm).
	NameRule      int      `json:"name_rule" gorm:"default:0"`
	MatchedModels []string `json:"matched_models,omitempty" gorm:"-"`
	MatchedCount  int      `json:"matched_count,omitempty" gorm:"-"`
}
  37. func (mi *Model) Insert() error {
  38. now := common.GetTimestamp()
  39. mi.CreatedTime = now
  40. mi.UpdatedTime = now
  41. // 保存原始值(因为 Create 后可能被 GORM 的 default 标签覆盖为 1)
  42. originalStatus := mi.Status
  43. originalSyncOfficial := mi.SyncOfficial
  44. // 先创建记录(GORM 会对零值字段应用默认值)
  45. if err := DB.Create(mi).Error; err != nil {
  46. return err
  47. }
  48. // 使用保存的原始值进行更新,确保零值能正确保存
  49. return DB.Model(&Model{}).Where("id = ?", mi.Id).Updates(map[string]interface{}{
  50. "status": originalStatus,
  51. "sync_official": originalSyncOfficial,
  52. }).Error
  53. }
  54. func IsModelNameDuplicated(id int, name string) (bool, error) {
  55. if name == "" {
  56. return false, nil
  57. }
  58. var cnt int64
  59. err := DB.Model(&Model{}).Where("model_name = ? AND id <> ?", name, id).Count(&cnt).Error
  60. return cnt > 0, err
  61. }
  62. func (mi *Model) Update() error {
  63. mi.UpdatedTime = common.GetTimestamp()
  64. // 使用 Select 强制更新所有字段,包括零值
  65. return DB.Model(&Model{}).Where("id = ?", mi.Id).
  66. Select("model_name", "description", "icon", "tags", "vendor_id", "endpoints", "status", "sync_official", "name_rule", "updated_time").
  67. Updates(mi).Error
  68. }
  69. func (mi *Model) Delete() error {
  70. return DB.Delete(mi).Error
  71. }
  72. func GetVendorModelCounts() (map[int64]int64, error) {
  73. var stats []struct {
  74. VendorID int64
  75. Count int64
  76. }
  77. if err := DB.Model(&Model{}).
  78. Select("vendor_id as vendor_id, count(*) as count").
  79. Group("vendor_id").
  80. Scan(&stats).Error; err != nil {
  81. return nil, err
  82. }
  83. m := make(map[int64]int64, len(stats))
  84. for _, s := range stats {
  85. m[s.VendorID] = s.Count
  86. }
  87. return m, nil
  88. }
  89. func GetAllModels(offset int, limit int) ([]*Model, error) {
  90. var models []*Model
  91. err := DB.Order("id DESC").Offset(offset).Limit(limit).Find(&models).Error
  92. return models, err
  93. }
  94. func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel, error) {
  95. result := make(map[string][]BoundChannel)
  96. if len(modelNames) == 0 {
  97. return result, nil
  98. }
  99. type row struct {
  100. Model string
  101. Name string
  102. Type int
  103. }
  104. var rows []row
  105. err := DB.Table("channels").
  106. Select("abilities.model as model, channels.name as name, channels.type as type").
  107. Joins("JOIN abilities ON abilities.channel_id = channels.id").
  108. Where("abilities.model IN ? AND abilities.enabled = ?", modelNames, true).
  109. Distinct().
  110. Scan(&rows).Error
  111. if err != nil {
  112. return nil, err
  113. }
  114. for _, r := range rows {
  115. result[r.Model] = append(result[r.Model], BoundChannel{Name: r.Name, Type: r.Type})
  116. }
  117. return result, nil
  118. }
  119. func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
  120. var models []*Model
  121. db := DB.Model(&Model{})
  122. if keyword != "" {
  123. like := "%" + keyword + "%"
  124. db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like)
  125. }
  126. if vendor != "" {
  127. if vid, err := strconv.Atoi(vendor); err == nil {
  128. db = db.Where("models.vendor_id = ?", vid)
  129. } else {
  130. db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%")
  131. }
  132. }
  133. var total int64
  134. if err := db.Count(&total).Error; err != nil {
  135. return nil, 0, err
  136. }
  137. if err := db.Order("models.id DESC").Offset(offset).Limit(limit).Find(&models).Error; err != nil {
  138. return nil, 0, err
  139. }
  140. return models, total, nil
  141. }