group.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. package model
  2. import (
  3. "errors"
  4. "strings"
  5. "time"
  6. "github.com/labring/aiproxy/core/common"
  7. log "github.com/sirupsen/logrus"
  8. "gorm.io/gorm"
  9. "gorm.io/gorm/clause"
  10. )
  // ErrGroupNotFound is the resource name this file passes to HandleNotFound
// and HandleUpdateResult when a group query or update targets a group that
// does not match.
const ErrGroupNotFound = "group"

// Values stored in Group.Status. The column's GORM default is 1, i.e.
// GroupStatusEnabled.
const (
	GroupStatusEnabled = iota + 1 // 1
	GroupStatusDisabled           // 2
	GroupStatusInternal           // 3
)
  19. type Group struct {
  20. CreatedAt time.Time `json:"created_at"`
  21. ID string `json:"id" gorm:"primaryKey"`
  22. Tokens []Token `json:"-" gorm:"foreignKey:GroupID"`
  23. GroupModelConfigs []GroupModelConfig `json:"-" gorm:"foreignKey:GroupID"`
  24. PublicMCPReusingParams []PublicMCPReusingParam `json:"-" gorm:"foreignKey:GroupID"`
  25. GroupMCPs []GroupMCP `json:"-" gorm:"foreignKey:GroupID"`
  26. Status int `json:"status" gorm:"default:1;index"`
  27. RPMRatio float64 `json:"rpm_ratio,omitempty" gorm:"index"`
  28. TPMRatio float64 `json:"tpm_ratio,omitempty" gorm:"index"`
  29. UsedAmount float64 `json:"used_amount" gorm:"index"`
  30. RequestCount int `json:"request_count" gorm:"index"`
  31. AvailableSets []string `json:"available_sets,omitempty" gorm:"serializer:fastjson;type:text"`
  32. BalanceAlertEnabled bool `gorm:"default:false" json:"balance_alert_enabled"`
  33. BalanceAlertThreshold float64 `gorm:"default:0" json:"balance_alert_threshold"`
  34. }
  35. func (g *Group) BeforeDelete(tx *gorm.DB) (err error) {
  36. err = tx.Model(&Token{}).Where("group_id = ?", g.ID).Delete(&Token{}).Error
  37. if err != nil {
  38. return err
  39. }
  40. err = tx.Model(&PublicMCPReusingParam{}).
  41. Where("group_id = ?", g.ID).
  42. Delete(&PublicMCPReusingParam{}).
  43. Error
  44. if err != nil {
  45. return err
  46. }
  47. err = tx.Model(&GroupMCP{}).Where("group_id = ?", g.ID).Delete(&GroupMCP{}).Error
  48. if err != nil {
  49. return err
  50. }
  51. return tx.Model(&GroupModelConfig{}).
  52. Where("group_id = ?", g.ID).
  53. Delete(&GroupModelConfig{}).
  54. Error
  55. }
  // getGroupOrder turns a client-supplied sort spec, "<column>" or
// "<column>-<direction>", into an ORDER BY fragment. Only a fixed set of
// column names is accepted, so raw user input never reaches the SQL; any
// other column yields "id desc". The direction is ascending only for the
// exact suffix "asc" and descending for anything else, including no suffix.
func getGroupOrder(order string) string {
	column, direction, _ := strings.Cut(order, "-")

	switch column {
	case "id", "request_count", "status", "created_at", "used_amount":
		// allowed column; fall through to direction handling below
	default:
		return "id desc"
	}

	if direction == "asc" {
		return column + " asc"
	}
	return column + " desc"
}
  70. func GetGroups(
  71. page, perPage int,
  72. order string,
  73. onlyDisabled bool,
  74. ) (groups []*Group, total int64, err error) {
  75. tx := DB.Model(&Group{})
  76. if onlyDisabled {
  77. tx = tx.Where("status = ?", GroupStatusDisabled)
  78. }
  79. err = tx.Count(&total).Error
  80. if err != nil {
  81. return nil, 0, err
  82. }
  83. if total <= 0 {
  84. return nil, 0, nil
  85. }
  86. limit, offset := toLimitOffset(page, perPage)
  87. err = tx.
  88. Order(getGroupOrder(order)).
  89. Limit(limit).
  90. Offset(offset).
  91. Find(&groups).
  92. Error
  93. return groups, total, err
  94. }
  95. func GetGroupByID(id string, preloadGroupModelConfigs bool) (*Group, error) {
  96. if id == "" {
  97. return nil, errors.New("group id is empty")
  98. }
  99. group := Group{}
  100. tx := DB.Where("id = ?", id)
  101. if preloadGroupModelConfigs {
  102. tx = tx.Preload("GroupModelConfigs")
  103. }
  104. err := tx.First(&group).Error
  105. return &group, HandleNotFound(err, ErrGroupNotFound)
  106. }
  107. func DeleteGroupByID(id string) (err error) {
  108. if id == "" {
  109. return errors.New("group id is empty")
  110. }
  111. defer func() {
  112. if err == nil {
  113. if err := CacheDeleteGroup(id); err != nil {
  114. log.Error("cache delete group failed: " + err.Error())
  115. }
  116. if _, err := DeleteGroupLogs(id); err != nil {
  117. log.Error("delete group logs failed: " + err.Error())
  118. }
  119. }
  120. }()
  121. result := DB.Delete(&Group{ID: id})
  122. return HandleUpdateResult(result, ErrGroupNotFound)
  123. }
  124. func DeleteGroupsByIDs(ids []string) (err error) {
  125. if len(ids) == 0 {
  126. return nil
  127. }
  128. groups := make([]Group, len(ids))
  129. defer func() {
  130. if err == nil {
  131. for _, group := range groups {
  132. if err := CacheDeleteGroup(group.ID); err != nil {
  133. log.Error("cache delete group failed: " + err.Error())
  134. }
  135. if _, err := DeleteGroupLogs(group.ID); err != nil {
  136. log.Error("delete group logs failed: " + err.Error())
  137. }
  138. }
  139. }
  140. }()
  141. return DB.Transaction(func(tx *gorm.DB) error {
  142. return tx.
  143. Clauses(clause.Returning{
  144. Columns: []clause.Column{
  145. {Name: "id"},
  146. },
  147. }).
  148. Where("id IN (?)", ids).
  149. Delete(&groups).
  150. Error
  151. })
  152. }
  153. func UpdateGroup(id string, group *Group) (err error) {
  154. if id == "" {
  155. return errors.New("group id is empty")
  156. }
  157. defer func() {
  158. if err == nil {
  159. if err := CacheDeleteGroup(id); err != nil {
  160. log.Error("cache delete group failed: " + err.Error())
  161. }
  162. }
  163. }()
  164. selects := []string{
  165. "rpm_ratio",
  166. "tpm_ratio",
  167. "available_sets",
  168. "balance_alert_enabled",
  169. "balance_alert_threshold",
  170. }
  171. if group.Status != 0 {
  172. selects = append(selects, "status")
  173. }
  174. result := DB.
  175. Clauses(clause.Returning{}).
  176. Where("id = ?", id).
  177. Select(selects).
  178. Updates(group)
  179. return HandleUpdateResult(result, ErrGroupNotFound)
  180. }
  181. func UpdateGroupUsedAmountAndRequestCount(id string, amount float64, count int) (err error) {
  182. group := &Group{}
  183. defer func() {
  184. if amount > 0 && err == nil {
  185. if err := CacheUpdateGroupUsedAmountOnlyIncrease(group.ID, group.UsedAmount); err != nil {
  186. log.Error("update group used amount in cache failed: " + err.Error())
  187. }
  188. }
  189. }()
  190. result := DB.
  191. Model(group).
  192. Clauses(clause.Returning{
  193. Columns: []clause.Column{
  194. {Name: "used_amount"},
  195. },
  196. }).
  197. Where("id = ?", id).
  198. Updates(map[string]any{
  199. "used_amount": gorm.Expr("used_amount + ?", amount),
  200. "request_count": gorm.Expr("request_count + ?", count),
  201. })
  202. return HandleUpdateResult(result, ErrGroupNotFound)
  203. }
  204. func UpdateGroupRPMRatio(id string, rpmRatio float64) (err error) {
  205. defer func() {
  206. if err == nil {
  207. if err := CacheUpdateGroupRPMRatio(id, rpmRatio); err != nil {
  208. log.Error("cache update group rpm failed: " + err.Error())
  209. }
  210. }
  211. }()
  212. result := DB.Model(&Group{}).Where("id = ?", id).Update("rpm_ratio", rpmRatio)
  213. return HandleUpdateResult(result, ErrGroupNotFound)
  214. }
  215. func UpdateGroupTPMRatio(id string, tpmRatio float64) (err error) {
  216. defer func() {
  217. if err == nil {
  218. if err := CacheUpdateGroupTPMRatio(id, tpmRatio); err != nil {
  219. log.Error("cache update group tpm ratio failed: " + err.Error())
  220. }
  221. }
  222. }()
  223. result := DB.Model(&Group{}).Where("id = ?", id).Update("tpm_ratio", tpmRatio)
  224. return HandleUpdateResult(result, ErrGroupNotFound)
  225. }
  226. func UpdateGroupStatus(id string, status int) (err error) {
  227. defer func() {
  228. if err == nil {
  229. if err := CacheUpdateGroupStatus(id, status); err != nil {
  230. log.Error("cache update group status failed: " + err.Error())
  231. }
  232. }
  233. }()
  234. result := DB.Model(&Group{}).Where("id = ?", id).Update("status", status)
  235. return HandleUpdateResult(result, ErrGroupNotFound)
  236. }
  237. func UpdateGroupsStatus(ids []string, status int) (rowsAffected int64, err error) {
  238. defer func() {
  239. if err == nil {
  240. for _, id := range ids {
  241. if err := CacheUpdateGroupStatus(id, status); err != nil {
  242. log.Error("cache update group status failed: " + err.Error())
  243. }
  244. }
  245. }
  246. }()
  247. result := DB.Model(&Group{}).
  248. Where("id IN (?) AND status != ?", ids, status).
  249. Update("status", status)
  250. return result.RowsAffected, result.Error
  251. }
  252. func SearchGroup(
  253. keyword string,
  254. page, perPage int,
  255. order string,
  256. status int,
  257. ) (groups []*Group, total int64, err error) {
  258. tx := DB.Model(&Group{})
  259. if status != 0 {
  260. tx = tx.Where("status = ?", status)
  261. }
  262. if common.UsingPostgreSQL {
  263. tx = tx.Where("id ILIKE ? OR available_sets ILIKE ?", "%"+keyword+"%", "%"+keyword+"%")
  264. } else {
  265. tx = tx.Where("id LIKE ? OR available_sets LIKE ?", "%"+keyword+"%", "%"+keyword+"%")
  266. }
  267. err = tx.Count(&total).Error
  268. if err != nil {
  269. return nil, 0, err
  270. }
  271. if total <= 0 {
  272. return nil, 0, nil
  273. }
  274. limit, offset := toLimitOffset(page, perPage)
  275. err = tx.
  276. Order(getGroupOrder(order)).
  277. Limit(limit).
  278. Offset(offset).
  279. Find(&groups).
  280. Error
  281. return groups, total, err
  282. }
  283. func CreateGroup(group *Group) error {
  284. return DB.Create(group).Error
  285. }