2
0

prefill_group.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. package model
  2. import (
  3. "database/sql/driver"
  4. "encoding/json"
  5. "github.com/QuantumNous/new-api/common"
  6. "gorm.io/gorm"
  7. )
  8. // PrefillGroup 用于存储可复用的“组”信息,例如模型组、标签组、端点组等。
  9. // Name 字段保持唯一,用于在前端下拉框中展示。
  10. // Type 字段用于区分组的类别,可选值如:model、tag、endpoint。
  11. // Items 字段使用 JSON 数组保存对应类型的字符串集合,示例:
  12. // ["gpt-4o", "gpt-3.5-turbo"]
  13. // 设计遵循 3NF,避免冗余,提供灵活扩展能力。
  14. // JSONValue 基于 json.RawMessage 实现,支持从数据库的 []byte 和 string 两种类型读取
  15. type JSONValue json.RawMessage
  16. // Value 实现 driver.Valuer 接口,用于数据库写入
  17. func (j JSONValue) Value() (driver.Value, error) {
  18. if j == nil {
  19. return nil, nil
  20. }
  21. return []byte(j), nil
  22. }
  23. // Scan 实现 sql.Scanner 接口,兼容不同驱动返回的类型
  24. func (j *JSONValue) Scan(value interface{}) error {
  25. switch v := value.(type) {
  26. case nil:
  27. *j = nil
  28. return nil
  29. case []byte:
  30. // 拷贝底层字节,避免保留底层缓冲区
  31. b := make([]byte, len(v))
  32. copy(b, v)
  33. *j = JSONValue(b)
  34. return nil
  35. case string:
  36. *j = JSONValue([]byte(v))
  37. return nil
  38. default:
  39. // 其他类型尝试序列化为 JSON
  40. b, err := json.Marshal(v)
  41. if err != nil {
  42. return err
  43. }
  44. *j = JSONValue(b)
  45. return nil
  46. }
  47. }
  48. // MarshalJSON 确保在对外编码时与 json.RawMessage 行为一致
  49. func (j JSONValue) MarshalJSON() ([]byte, error) {
  50. if j == nil {
  51. return []byte("null"), nil
  52. }
  53. return j, nil
  54. }
  55. // UnmarshalJSON 确保在对外解码时与 json.RawMessage 行为一致
  56. func (j *JSONValue) UnmarshalJSON(data []byte) error {
  57. if data == nil {
  58. *j = nil
  59. return nil
  60. }
  61. b := make([]byte, len(data))
  62. copy(b, data)
  63. *j = JSONValue(b)
  64. return nil
  65. }
  66. type PrefillGroup struct {
  67. Id int `json:"id"`
  68. Name string `json:"name" gorm:"size:64;not null;uniqueIndex:uk_prefill_name,where:deleted_at IS NULL"`
  69. Type string `json:"type" gorm:"size:32;index;not null"`
  70. Items JSONValue `json:"items" gorm:"type:json"`
  71. Description string `json:"description,omitempty" gorm:"type:varchar(255)"`
  72. CreatedTime int64 `json:"created_time" gorm:"bigint"`
  73. UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
  74. DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
  75. }
  76. // Insert 新建组
  77. func (g *PrefillGroup) Insert() error {
  78. now := common.GetTimestamp()
  79. g.CreatedTime = now
  80. g.UpdatedTime = now
  81. return DB.Create(g).Error
  82. }
  83. // IsPrefillGroupNameDuplicated 检查组名称是否重复(排除自身 ID)
  84. func IsPrefillGroupNameDuplicated(id int, name string) (bool, error) {
  85. if name == "" {
  86. return false, nil
  87. }
  88. var cnt int64
  89. err := DB.Model(&PrefillGroup{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error
  90. return cnt > 0, err
  91. }
  92. // Update 更新组
  93. func (g *PrefillGroup) Update() error {
  94. g.UpdatedTime = common.GetTimestamp()
  95. return DB.Save(g).Error
  96. }
  97. // DeleteByID 根据 ID 删除组
  98. func DeletePrefillGroupByID(id int) error {
  99. return DB.Delete(&PrefillGroup{}, id).Error
  100. }
  101. // GetAllPrefillGroups 获取全部组,可按类型过滤(为空则返回全部)
  102. func GetAllPrefillGroups(groupType string) ([]*PrefillGroup, error) {
  103. var groups []*PrefillGroup
  104. query := DB.Model(&PrefillGroup{})
  105. if groupType != "" {
  106. query = query.Where("type = ?", groupType)
  107. }
  108. if err := query.Order("updated_time DESC").Find(&groups).Error; err != nil {
  109. return nil, err
  110. }
  111. return groups, nil
  112. }