channel.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. package model
  2. import (
  3. "context"
  4. "fmt"
  5. "slices"
  6. "strings"
  7. "time"
  8. "github.com/bytedance/sonic"
  9. "github.com/labring/aiproxy/core/common"
  10. "github.com/labring/aiproxy/core/common/config"
  11. "github.com/labring/aiproxy/core/monitor"
  12. "github.com/labring/aiproxy/core/relay/mode"
  13. "gorm.io/gorm"
  14. "gorm.io/gorm/clause"
  15. )
  16. const (
  17. ErrChannelNotFound = "channel"
  18. )
  19. const (
  20. ChannelStatusUnknown = 0
  21. ChannelStatusEnabled = 1
  22. ChannelStatusDisabled = 2
  23. )
  24. const (
  25. ChannelDefaultSet = "default"
  26. )
// Channel is the GORM model for a channel.
//
// It holds editable settings (Key, BaseURL, Models, ModelMapping, Configs,
// Sets, Priority, Status, Type, EnabledAutoBalanceCheck, BalanceThreshold) and
// bookkeeping columns that dedicated helpers in this file write:
//   - UsedAmount, RequestCount and RetryCount are incremented by
//     UpdateChannelUsedAmount.
//   - Balance and BalanceUpdatedAt are set by UpdateBalance.
//   - LastTestErrorAt is set or cleared by UpdateModelTest and cleared by
//     ClearLastTestErrorAt.
//
// Field notes:
//   - DeletedAt is a gorm.DeletedAt, so GORM performs soft deletes.
//   - ModelMapping, Models, Configs and Sets use the "fastjson" GORM
//     serializer into text columns. NOTE(review): the serializer is presumably
//     registered elsewhere in the package; confirm.
//   - Key has an index on its first 191 characters.
//   - The schema default for Status is 1 (ChannelStatusEnabled).
//   - Readers normalize some fields:
//     - GetPriority reads a Priority of 0 as DefaultPriority.
//     - GetBalanceThreshold reads a negative BalanceThreshold as 0.
//     - GetSets reads an empty Sets as [ChannelDefaultSet].
//   - Fields tagged yaml:"-" are excluded from YAML.
//   - JSON goes through (*Channel).MarshalJSON, which writes the time fields
//     as Unix milliseconds.
//   - ChannelTests are removed by the BeforeDelete hook.
type Channel struct {
	DeletedAt               gorm.DeletedAt    `gorm:"index" json:"-" yaml:"-"`
	CreatedAt               time.Time         `gorm:"index" json:"created_at" yaml:"-"`
	LastTestErrorAt         time.Time         ` json:"last_test_error_at" yaml:"-"`
	ChannelTests            []*ChannelTest    `gorm:"foreignKey:ChannelID;references:ID" json:"channel_tests,omitempty" yaml:"-"`
	BalanceUpdatedAt        time.Time         ` json:"balance_updated_at" yaml:"-"`
	ModelMapping            map[string]string `gorm:"serializer:fastjson;type:text" json:"model_mapping" yaml:"model_mapping,omitempty"`
	Key                     string            `gorm:"type:text;index:,length:191" json:"key" yaml:"key,omitempty"`
	Name                    string            `gorm:"size:64;index" json:"name" yaml:"name,omitempty"`
	BaseURL                 string            `gorm:"size:128;index" json:"base_url" yaml:"base_url,omitempty"`
	Models                  []string          `gorm:"serializer:fastjson;type:text" json:"models" yaml:"models,omitempty"`
	Balance                 float64           ` json:"balance" yaml:"balance,omitempty"`
	ID                      int               `gorm:"primaryKey" json:"id" yaml:"id,omitempty"`
	UsedAmount              float64           `gorm:"index" json:"used_amount" yaml:"-"`
	RequestCount            int               `gorm:"index" json:"request_count" yaml:"-"`
	RetryCount              int               `gorm:"index" json:"retry_count" yaml:"-"`
	Status                  int               `gorm:"default:1;index" json:"status" yaml:"status,omitempty"`
	Type                    ChannelType       `gorm:"default:0;index" json:"type" yaml:"type,omitempty"`
	Priority                int32             ` json:"priority" yaml:"priority,omitempty"`
	EnabledAutoBalanceCheck bool              ` json:"enabled_auto_balance_check" yaml:"enabled_auto_balance_check,omitempty"`
	BalanceThreshold        float64           ` json:"balance_threshold" yaml:"balance_threshold,omitempty"`
	Configs                 ChannelConfigs    `gorm:"serializer:fastjson;type:text" json:"configs,omitempty" yaml:"configs,omitempty"`
	Sets                    []string          `gorm:"serializer:fastjson;type:text" json:"sets,omitempty" yaml:"sets,omitempty"`
}
  51. func (c *Channel) GetSets() []string {
  52. if len(c.Sets) == 0 {
  53. return []string{ChannelDefaultSet}
  54. }
  55. return c.Sets
  56. }
  57. func (c *Channel) BeforeDelete(tx *gorm.DB) (err error) {
  58. return tx.Model(&ChannelTest{}).Where("channel_id = ?", c.ID).Delete(&ChannelTest{}).Error
  59. }
  60. func (c *Channel) GetBalanceThreshold() float64 {
  61. if c.BalanceThreshold < 0 {
  62. return 0
  63. }
  64. return c.BalanceThreshold
  65. }
  66. const (
  67. DefaultPriority = 10
  68. )
  69. func (c *Channel) GetPriority() int32 {
  70. if c.Priority == 0 {
  71. return DefaultPriority
  72. }
  73. return c.Priority
  74. }
  75. type ChannelConfigs map[string]any
  76. func (c ChannelConfigs) LoadConfig(config any) error {
  77. if len(c) == 0 {
  78. return nil
  79. }
  80. v, err := sonic.Marshal(c)
  81. if err != nil {
  82. return err
  83. }
  84. return sonic.Unmarshal(v, config)
  85. }
  86. func GetModelConfigWithModels(models []string) ([]string, []string, error) {
  87. if len(models) == 0 || config.DisableModelConfig {
  88. return models, nil, nil
  89. }
  90. where := DB.Model(&ModelConfig{}).Where("model IN ?", models)
  91. var count int64
  92. if err := where.Count(&count).Error; err != nil {
  93. return nil, nil, err
  94. }
  95. if count == 0 {
  96. return nil, models, nil
  97. }
  98. if count == int64(len(models)) {
  99. return models, nil, nil
  100. }
  101. var foundModels []string
  102. if err := where.Pluck("model", &foundModels).Error; err != nil {
  103. return nil, nil, err
  104. }
  105. if len(foundModels) == len(models) {
  106. return models, nil, nil
  107. }
  108. foundModelsMap := make(map[string]struct{}, len(foundModels))
  109. for _, model := range foundModels {
  110. foundModelsMap[model] = struct{}{}
  111. }
  112. if len(models)-len(foundModels) > 0 {
  113. missingModels := make([]string, 0, len(models)-len(foundModels))
  114. for _, model := range models {
  115. if _, exists := foundModelsMap[model]; !exists {
  116. missingModels = append(missingModels, model)
  117. }
  118. }
  119. return foundModels, missingModels, nil
  120. }
  121. return foundModels, nil, nil
  122. }
  123. func CheckModelConfigExist(models []string) error {
  124. _, missingModels, err := GetModelConfigWithModels(models)
  125. if err != nil {
  126. return err
  127. }
  128. if len(missingModels) > 0 {
  129. slices.Sort(missingModels)
  130. return fmt.Errorf("model config not found: %v", missingModels)
  131. }
  132. return nil
  133. }
  134. func (c *Channel) MarshalJSON() ([]byte, error) {
  135. type Alias Channel
  136. return sonic.Marshal(&struct {
  137. *Alias
  138. CreatedAt int64 `json:"created_at"`
  139. BalanceUpdatedAt int64 `json:"balance_updated_at"`
  140. LastTestErrorAt int64 `json:"last_test_error_at"`
  141. }{
  142. Alias: (*Alias)(c),
  143. CreatedAt: c.CreatedAt.UnixMilli(),
  144. BalanceUpdatedAt: c.BalanceUpdatedAt.UnixMilli(),
  145. LastTestErrorAt: c.LastTestErrorAt.UnixMilli(),
  146. })
  147. }
  148. //nolint:goconst
  149. func getChannelOrder(order string) string {
  150. prefix, suffix, _ := strings.Cut(order, "-")
  151. switch prefix {
  152. case "name",
  153. "type",
  154. "created_at",
  155. "status",
  156. "test_at",
  157. "balance_updated_at",
  158. "used_amount",
  159. "request_count",
  160. "priority",
  161. "id":
  162. switch suffix {
  163. case "asc":
  164. return prefix + " asc"
  165. default:
  166. return prefix + " desc"
  167. }
  168. default:
  169. return "id desc"
  170. }
  171. }
  172. func GetAllChannels() (channels []*Channel, err error) {
  173. tx := DB.Model(&Channel{})
  174. err = tx.Order("id desc").Find(&channels).Error
  175. return channels, err
  176. }
  177. func GetChannels(
  178. page, perPage, id int,
  179. name, key string,
  180. channelType int,
  181. baseURL, order string,
  182. ) (channels []*Channel, total int64, err error) {
  183. tx := DB.Model(&Channel{})
  184. if id != 0 {
  185. tx = tx.Where("id = ?", id)
  186. }
  187. if name != "" {
  188. tx = tx.Where("name = ?", name)
  189. }
  190. if key != "" {
  191. tx = tx.Where("key = ?", key)
  192. }
  193. if channelType != 0 {
  194. tx = tx.Where("type = ?", channelType)
  195. }
  196. if baseURL != "" {
  197. tx = tx.Where("base_url = ?", baseURL)
  198. }
  199. err = tx.Count(&total).Error
  200. if err != nil {
  201. return nil, 0, err
  202. }
  203. if total <= 0 {
  204. return nil, 0, nil
  205. }
  206. limit, offset := toLimitOffset(page, perPage)
  207. err = tx.Order(getChannelOrder(order)).Limit(limit).Offset(offset).Find(&channels).Error
  208. return channels, total, err
  209. }
  210. func SearchChannels(
  211. keyword string,
  212. page, perPage, id int,
  213. name, key string,
  214. channelType int,
  215. baseURL, order string,
  216. ) (channels []*Channel, total int64, err error) {
  217. tx := DB.Model(&Channel{})
  218. // Handle exact match conditions for non-zero values
  219. if id != 0 {
  220. tx = tx.Where("id = ?", id)
  221. }
  222. if name != "" {
  223. tx = tx.Where("name = ?", name)
  224. }
  225. if key != "" {
  226. tx = tx.Where("key = ?", key)
  227. }
  228. if channelType != 0 {
  229. tx = tx.Where("type = ?", channelType)
  230. }
  231. if baseURL != "" {
  232. tx = tx.Where("base_url = ?", baseURL)
  233. }
  234. // Handle keyword search for zero value fields
  235. if keyword != "" {
  236. var (
  237. conditions []string
  238. values []any
  239. )
  240. keywordInt := String2Int(keyword)
  241. if keywordInt != 0 {
  242. if id == 0 {
  243. conditions = append(conditions, "id = ?")
  244. values = append(values, keywordInt)
  245. }
  246. }
  247. if name == "" {
  248. if !common.UsingSQLite {
  249. conditions = append(conditions, "name ILIKE ?")
  250. } else {
  251. conditions = append(conditions, "name LIKE ?")
  252. }
  253. values = append(values, "%"+keyword+"%")
  254. }
  255. if key == "" {
  256. if !common.UsingSQLite {
  257. conditions = append(conditions, "key ILIKE ?")
  258. } else {
  259. conditions = append(conditions, "key LIKE ?")
  260. }
  261. values = append(values, "%"+keyword+"%")
  262. }
  263. if baseURL == "" {
  264. if !common.UsingSQLite {
  265. conditions = append(conditions, "base_url ILIKE ?")
  266. } else {
  267. conditions = append(conditions, "base_url LIKE ?")
  268. }
  269. values = append(values, "%"+keyword+"%")
  270. }
  271. if !common.UsingSQLite {
  272. conditions = append(conditions, "models ILIKE ?")
  273. } else {
  274. conditions = append(conditions, "models LIKE ?")
  275. }
  276. values = append(values, "%"+keyword+"%")
  277. if !common.UsingSQLite {
  278. conditions = append(conditions, "sets ILIKE ?")
  279. } else {
  280. conditions = append(conditions, "sets LIKE ?")
  281. }
  282. values = append(values, "%"+keyword+"%")
  283. if len(conditions) > 0 {
  284. tx = tx.Where(fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), values...)
  285. }
  286. }
  287. err = tx.Count(&total).Error
  288. if err != nil {
  289. return nil, 0, err
  290. }
  291. if total <= 0 {
  292. return nil, 0, nil
  293. }
  294. limit, offset := toLimitOffset(page, perPage)
  295. err = tx.Order(getChannelOrder(order)).Limit(limit).Offset(offset).Find(&channels).Error
  296. return channels, total, err
  297. }
  298. func GetChannelByID(id int) (*Channel, error) {
  299. channel := Channel{ID: id}
  300. err := DB.First(&channel, "id = ?", id).Error
  301. return &channel, HandleNotFound(err, ErrChannelNotFound)
  302. }
  303. func BatchInsertChannels(channels []*Channel) (err error) {
  304. defer func() {
  305. if err == nil {
  306. _ = InitModelConfigAndChannelCache()
  307. }
  308. }()
  309. for _, channel := range channels {
  310. if err := CheckModelConfigExist(channel.Models); err != nil {
  311. return err
  312. }
  313. }
  314. return DB.Transaction(func(tx *gorm.DB) error {
  315. return tx.Create(&channels).Error
  316. })
  317. }
  318. func UpdateChannel(channel *Channel) (err error) {
  319. defer func() {
  320. if err == nil {
  321. _ = InitModelConfigAndChannelCache()
  322. _ = monitor.ClearChannelAllModelErrors(context.Background(), channel.ID)
  323. }
  324. }()
  325. if err := CheckModelConfigExist(channel.Models); err != nil {
  326. return err
  327. }
  328. selects := []string{
  329. "model_mapping",
  330. "key",
  331. "base_url",
  332. "models",
  333. "priority",
  334. "config",
  335. "enabled_auto_balance_check",
  336. "balance_threshold",
  337. "sets",
  338. }
  339. if channel.Type != 0 {
  340. selects = append(selects, "type")
  341. }
  342. if channel.Name != "" {
  343. selects = append(selects, "name")
  344. }
  345. result := DB.
  346. Select(selects).
  347. Clauses(clause.Returning{}).
  348. Where("id = ?", channel.ID).
  349. Updates(channel)
  350. return HandleUpdateResult(result, ErrChannelNotFound)
  351. }
  352. func ClearLastTestErrorAt(id int) error {
  353. result := DB.Model(&Channel{}).
  354. Where("id = ?", id).
  355. Update("last_test_error_at", gorm.Expr("NULL"))
  356. return HandleUpdateResult(result, ErrChannelNotFound)
  357. }
  358. func (c *Channel) UpdateModelTest(
  359. testAt time.Time,
  360. model, actualModel string,
  361. mode mode.Mode,
  362. took float64,
  363. success bool,
  364. response string,
  365. code int,
  366. ) (*ChannelTest, error) {
  367. var ct *ChannelTest
  368. err := DB.Transaction(func(tx *gorm.DB) error {
  369. if !success {
  370. result := tx.Model(&Channel{}).
  371. Where("id = ?", c.ID).
  372. Update("last_test_error_at", testAt)
  373. if err := HandleUpdateResult(result, ErrChannelNotFound); err != nil {
  374. return err
  375. }
  376. } else if !c.LastTestErrorAt.IsZero() && time.Since(c.LastTestErrorAt) > time.Hour {
  377. result := tx.Model(&Channel{}).Where("id = ?", c.ID).Update("last_test_error_at", gorm.Expr("NULL"))
  378. if err := HandleUpdateResult(result, ErrChannelNotFound); err != nil {
  379. return err
  380. }
  381. }
  382. ct = &ChannelTest{
  383. ChannelID: c.ID,
  384. ChannelType: c.Type,
  385. ChannelName: c.Name,
  386. Model: model,
  387. ActualModel: actualModel,
  388. Mode: mode,
  389. TestAt: testAt,
  390. Took: took,
  391. Success: success,
  392. Response: response,
  393. Code: code,
  394. }
  395. result := tx.Save(ct)
  396. return HandleUpdateResult(result, ErrChannelNotFound)
  397. })
  398. if err != nil {
  399. return nil, err
  400. }
  401. return ct, nil
  402. }
  403. func (c *Channel) UpdateBalance(balance float64) error {
  404. result := DB.Model(&Channel{}).
  405. Select("balance_updated_at", "balance").
  406. Where("id = ?", c.ID).
  407. Updates(Channel{
  408. BalanceUpdatedAt: time.Now(),
  409. Balance: balance,
  410. })
  411. return HandleUpdateResult(result, ErrChannelNotFound)
  412. }
  413. func DeleteChannelByID(id int) (err error) {
  414. defer func() {
  415. if err == nil {
  416. _ = InitModelConfigAndChannelCache()
  417. _ = monitor.ClearChannelAllModelErrors(context.Background(), id)
  418. }
  419. }()
  420. result := DB.Delete(&Channel{ID: id})
  421. return HandleUpdateResult(result, ErrChannelNotFound)
  422. }
  423. func DeleteChannelsByIDs(ids []int) (err error) {
  424. defer func() {
  425. if err == nil {
  426. _ = InitModelConfigAndChannelCache()
  427. for _, id := range ids {
  428. _ = monitor.ClearChannelAllModelErrors(context.Background(), id)
  429. }
  430. }
  431. }()
  432. return DB.Transaction(func(tx *gorm.DB) error {
  433. return tx.
  434. Where("id IN (?)", ids).
  435. Delete(&Channel{}).
  436. Error
  437. })
  438. }
  439. func UpdateChannelStatusByID(id, status int) error {
  440. result := DB.Model(&Channel{}).
  441. Where("id = ?", id).
  442. Update("status", status)
  443. return HandleUpdateResult(result, ErrChannelNotFound)
  444. }
  445. func UpdateChannelUsedAmount(id int, amount float64, requestCount, retryCount int) error {
  446. result := DB.Model(&Channel{}).
  447. Where("id = ?", id).
  448. Updates(map[string]any{
  449. "used_amount": gorm.Expr("used_amount + ?", amount),
  450. "request_count": gorm.Expr("request_count + ?", requestCount),
  451. "retry_count": gorm.Expr("retry_count + ?", retryCount),
  452. })
  453. return HandleUpdateResult(result, ErrChannelNotFound)
  454. }