group.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. package model
  2. import (
  3. "errors"
  4. "strings"
  5. "time"
  6. "github.com/labring/aiproxy/core/common"
  7. log "github.com/sirupsen/logrus"
  8. "gorm.io/gorm"
  9. "gorm.io/gorm/clause"
  10. )
  // ErrGroupNotFound is the label the group functions in this file pass as the
  // second argument to HandleNotFound and HandleUpdateResult.
  const ErrGroupNotFound = "group"

  // Values stored in Group.Status.
  const (
  	// GroupStatusEnabled equals the column default declared on Group.Status.
  	GroupStatusEnabled = 1
  	// GroupStatusDisabled is the status GetGroups filters on when onlyDisabled
  	// is set.
  	GroupStatusDisabled = 2
  	// GroupStatusInternal is the third status value.
  	// NOTE(review): its meaning is not visible in this file — confirm with callers.
  	GroupStatusInternal = 3
  )
  19. type Group struct {
  20. CreatedAt time.Time `json:"created_at"`
  21. ID string `json:"id" gorm:"size:64;primaryKey"`
  22. Tokens []Token `json:"-" gorm:"foreignKey:GroupID"`
  23. GroupModelConfigs []GroupModelConfig `json:"-" gorm:"foreignKey:GroupID"`
  24. PublicMCPReusingParams []PublicMCPReusingParam `json:"-" gorm:"foreignKey:GroupID"`
  25. GroupMCPs []GroupMCP `json:"-" gorm:"foreignKey:GroupID"`
  26. Status int `json:"status" gorm:"default:1;index"`
  27. RPMRatio float64 `json:"rpm_ratio,omitempty" gorm:"index"`
  28. TPMRatio float64 `json:"tpm_ratio,omitempty" gorm:"index"`
  29. UsedAmount float64 `json:"used_amount" gorm:"index"`
  30. RequestCount int `json:"request_count" gorm:"index"`
  31. AvailableSets []string `json:"available_sets,omitempty" gorm:"serializer:fastjson;type:text"`
  32. BalanceAlertEnabled bool `gorm:"default:false" json:"balance_alert_enabled"`
  33. BalanceAlertThreshold float64 `gorm:"default:0" json:"balance_alert_threshold"`
  34. }
  35. func (g *Group) BeforeSave(_ *gorm.DB) error {
  36. if len(g.ID) > 64 {
  37. return errors.New("group id length too long")
  38. }
  39. return nil
  40. }
  41. func (g *Group) BeforeDelete(tx *gorm.DB) (err error) {
  42. err = tx.Model(&Token{}).Where("group_id = ?", g.ID).Delete(&Token{}).Error
  43. if err != nil {
  44. return err
  45. }
  46. err = tx.Model(&PublicMCPReusingParam{}).
  47. Where("group_id = ?", g.ID).
  48. Delete(&PublicMCPReusingParam{}).
  49. Error
  50. if err != nil {
  51. return err
  52. }
  53. err = tx.Model(&GroupMCP{}).Where("group_id = ?", g.ID).Delete(&GroupMCP{}).Error
  54. if err != nil {
  55. return err
  56. }
  57. return tx.Model(&GroupModelConfig{}).
  58. Where("group_id = ?", g.ID).
  59. Delete(&GroupModelConfig{}).
  60. Error
  61. }
  // getGroupOrder converts an order spec of the form "<column>" or
  // "<column>-<direction>" into an ORDER BY expression. Only a fixed whitelist
  // of columns is accepted; anything else yields "id desc". The direction is
  // ascending only when it is exactly "asc", otherwise descending.
  func getGroupOrder(order string) string {
  	column, direction, _ := strings.Cut(order, "-")

  	switch column {
  	case "id", "request_count", "status", "created_at", "used_amount":
  		// Whitelisted column; fall through to the direction handling below.
  	default:
  		return "id desc"
  	}

  	if direction == "asc" {
  		return column + " asc"
  	}
  	return column + " desc"
  }
  76. func GetGroups(
  77. page, perPage int,
  78. order string,
  79. onlyDisabled bool,
  80. ) (groups []*Group, total int64, err error) {
  81. tx := DB.Model(&Group{})
  82. if onlyDisabled {
  83. tx = tx.Where("status = ?", GroupStatusDisabled)
  84. }
  85. err = tx.Count(&total).Error
  86. if err != nil {
  87. return nil, 0, err
  88. }
  89. if total <= 0 {
  90. return nil, 0, nil
  91. }
  92. limit, offset := toLimitOffset(page, perPage)
  93. err = tx.
  94. Order(getGroupOrder(order)).
  95. Limit(limit).
  96. Offset(offset).
  97. Find(&groups).
  98. Error
  99. return groups, total, err
  100. }
  101. func GetGroupByID(id string, preloadGroupModelConfigs bool) (*Group, error) {
  102. if id == "" {
  103. return nil, errors.New("group id is empty")
  104. }
  105. group := Group{}
  106. tx := DB.Where("id = ?", id)
  107. if preloadGroupModelConfigs {
  108. tx = tx.Preload("GroupModelConfigs")
  109. }
  110. err := tx.First(&group).Error
  111. return &group, HandleNotFound(err, ErrGroupNotFound)
  112. }
  113. func DeleteGroupByID(id string) (err error) {
  114. if id == "" {
  115. return errors.New("group id is empty")
  116. }
  117. defer func() {
  118. if err == nil {
  119. if err := CacheDeleteGroup(id); err != nil {
  120. log.Error("cache delete group failed: " + err.Error())
  121. }
  122. if _, err := DeleteGroupLogs(id); err != nil {
  123. log.Error("delete group logs failed: " + err.Error())
  124. }
  125. }
  126. }()
  127. result := DB.Delete(&Group{ID: id})
  128. return HandleUpdateResult(result, ErrGroupNotFound)
  129. }
  130. func DeleteGroupsByIDs(ids []string) (err error) {
  131. if len(ids) == 0 {
  132. return nil
  133. }
  134. groups := make([]Group, len(ids))
  135. defer func() {
  136. if err == nil {
  137. for _, group := range groups {
  138. if err := CacheDeleteGroup(group.ID); err != nil {
  139. log.Error("cache delete group failed: " + err.Error())
  140. }
  141. if _, err := DeleteGroupLogs(group.ID); err != nil {
  142. log.Error("delete group logs failed: " + err.Error())
  143. }
  144. }
  145. }
  146. }()
  147. return DB.Transaction(func(tx *gorm.DB) error {
  148. return tx.
  149. Clauses(clause.Returning{
  150. Columns: []clause.Column{
  151. {Name: "id"},
  152. },
  153. }).
  154. Where("id IN (?)", ids).
  155. Delete(&groups).
  156. Error
  157. })
  158. }
  // UpdateGroupRequest is the partial-update payload consumed by UpdateGroup.
  // Pointer fields left nil are not written; Status is written only when it is
  // non-zero.
  type UpdateGroupRequest struct {
  	// Status is the new status; 0 leaves the stored status untouched.
  	Status int `json:"status"`
  	// RPMRatio, when set, replaces the group's rpm_ratio.
  	RPMRatio *float64 `json:"rpm_ratio,omitempty"`
  	// TPMRatio, when set, replaces the group's tpm_ratio.
  	TPMRatio *float64 `json:"tpm_ratio,omitempty"`
  	// AvailableSets, when set, replaces the group's available_sets.
  	AvailableSets *[]string `json:"available_sets,omitempty"`
  	// BalanceAlertEnabled, when set, replaces balance_alert_enabled.
  	BalanceAlertEnabled *bool `json:"balance_alert_enabled"`
  	// BalanceAlertThreshold, when set, replaces balance_alert_threshold.
  	BalanceAlertThreshold *float64 `json:"balance_alert_threshold"`
  }
  167. func UpdateGroup(id string, update UpdateGroupRequest) (group *Group, err error) {
  168. if id == "" {
  169. return nil, errors.New("group id is empty")
  170. }
  171. group = &Group{
  172. ID: id,
  173. Status: update.Status,
  174. }
  175. defer func() {
  176. if err == nil {
  177. if err := CacheDeleteGroup(id); err != nil {
  178. log.Error("cache delete group failed: " + err.Error())
  179. }
  180. }
  181. }()
  182. selects := []string{}
  183. if update.RPMRatio != nil {
  184. group.RPMRatio = *update.RPMRatio
  185. selects = append(selects, "rpm_ratio")
  186. }
  187. if update.TPMRatio != nil {
  188. group.TPMRatio = *update.TPMRatio
  189. selects = append(selects, "tpm_ratio")
  190. }
  191. if update.AvailableSets != nil {
  192. group.AvailableSets = *update.AvailableSets
  193. selects = append(selects, "available_sets")
  194. }
  195. if update.BalanceAlertEnabled != nil {
  196. group.BalanceAlertEnabled = *update.BalanceAlertEnabled
  197. selects = append(selects, "balance_alert_enabled")
  198. }
  199. if update.BalanceAlertThreshold != nil {
  200. group.BalanceAlertThreshold = *update.BalanceAlertThreshold
  201. selects = append(selects, "balance_alert_threshold")
  202. }
  203. if group.Status != 0 {
  204. selects = append(selects, "status")
  205. }
  206. result := DB.
  207. Clauses(clause.Returning{}).
  208. Where("id = ?", id).
  209. Select(selects).
  210. Updates(group)
  211. return group, HandleUpdateResult(result, ErrGroupNotFound)
  212. }
  213. func UpdateGroupUsedAmountAndRequestCount(id string, amount float64, count int) (err error) {
  214. group := &Group{}
  215. defer func() {
  216. if amount > 0 && err == nil {
  217. if err := CacheUpdateGroupUsedAmountOnlyIncrease(group.ID, group.UsedAmount); err != nil {
  218. log.Error("update group used amount in cache failed: " + err.Error())
  219. }
  220. }
  221. }()
  222. result := DB.
  223. Model(group).
  224. Clauses(clause.Returning{
  225. Columns: []clause.Column{
  226. {Name: "used_amount"},
  227. },
  228. }).
  229. Where("id = ?", id).
  230. Updates(map[string]any{
  231. "used_amount": gorm.Expr("used_amount + ?", amount),
  232. "request_count": gorm.Expr("request_count + ?", count),
  233. })
  234. return HandleUpdateResult(result, ErrGroupNotFound)
  235. }
  236. func UpdateGroupRPMRatio(id string, rpmRatio float64) (err error) {
  237. defer func() {
  238. if err == nil {
  239. if err := CacheUpdateGroupRPMRatio(id, rpmRatio); err != nil {
  240. log.Error("cache update group rpm failed: " + err.Error())
  241. }
  242. }
  243. }()
  244. result := DB.Model(&Group{}).Where("id = ?", id).Update("rpm_ratio", rpmRatio)
  245. return HandleUpdateResult(result, ErrGroupNotFound)
  246. }
  247. func UpdateGroupTPMRatio(id string, tpmRatio float64) (err error) {
  248. defer func() {
  249. if err == nil {
  250. if err := CacheUpdateGroupTPMRatio(id, tpmRatio); err != nil {
  251. log.Error("cache update group tpm ratio failed: " + err.Error())
  252. }
  253. }
  254. }()
  255. result := DB.Model(&Group{}).Where("id = ?", id).Update("tpm_ratio", tpmRatio)
  256. return HandleUpdateResult(result, ErrGroupNotFound)
  257. }
  258. func UpdateGroupStatus(id string, status int) (err error) {
  259. defer func() {
  260. if err == nil {
  261. if err := CacheUpdateGroupStatus(id, status); err != nil {
  262. log.Error("cache update group status failed: " + err.Error())
  263. }
  264. }
  265. }()
  266. result := DB.Model(&Group{}).Where("id = ?", id).Update("status", status)
  267. return HandleUpdateResult(result, ErrGroupNotFound)
  268. }
  269. func UpdateGroupsStatus(ids []string, status int) (rowsAffected int64, err error) {
  270. defer func() {
  271. if err == nil {
  272. for _, id := range ids {
  273. if err := CacheUpdateGroupStatus(id, status); err != nil {
  274. log.Error("cache update group status failed: " + err.Error())
  275. }
  276. }
  277. }
  278. }()
  279. result := DB.Model(&Group{}).
  280. Where("id IN (?) AND status != ?", ids, status).
  281. Update("status", status)
  282. return result.RowsAffected, result.Error
  283. }
  284. func SearchGroup(
  285. keyword string,
  286. page, perPage int,
  287. order string,
  288. status int,
  289. ) (groups []*Group, total int64, err error) {
  290. tx := DB.Model(&Group{})
  291. if status != 0 {
  292. tx = tx.Where("status = ?", status)
  293. }
  294. if !common.UsingSQLite {
  295. tx = tx.Where("id ILIKE ? OR available_sets ILIKE ?", "%"+keyword+"%", "%"+keyword+"%")
  296. } else {
  297. tx = tx.Where("id LIKE ? OR available_sets LIKE ?", "%"+keyword+"%", "%"+keyword+"%")
  298. }
  299. err = tx.Count(&total).Error
  300. if err != nil {
  301. return nil, 0, err
  302. }
  303. if total <= 0 {
  304. return nil, 0, nil
  305. }
  306. limit, offset := toLimitOffset(page, perPage)
  307. err = tx.
  308. Order(getGroupOrder(order)).
  309. Limit(limit).
  310. Offset(offset).
  311. Find(&groups).
  312. Error
  313. return groups, total, err
  314. }
  315. func CreateGroup(group *Group) error {
  316. return DB.Create(group).Error
  317. }