channel.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
  1. package model
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "slices"
  7. "strings"
  8. "time"
  9. "github.com/bytedance/sonic"
  10. "github.com/bytedance/sonic/ast"
  11. "github.com/labring/aiproxy/core/common"
  12. "github.com/labring/aiproxy/core/common/config"
  13. "github.com/labring/aiproxy/core/monitor"
  14. "github.com/labring/aiproxy/core/relay/mode"
  15. "gorm.io/gorm"
  16. "gorm.io/gorm/clause"
  17. )
  18. const (
  19. ErrChannelNotFound = "channel"
  20. )
  21. const (
  22. ChannelStatusUnknown = 0
  23. ChannelStatusEnabled = 1
  24. ChannelStatusDisabled = 2
  25. )
  26. const (
  27. ChannelDefaultSet = "default"
  28. )
  29. type ChannelConfig struct {
  30. Spec json.RawMessage `json:"spec"`
  31. }
  32. // validate spec json is map[string]any
  33. func (c *ChannelConfig) UnmarshalJSON(data []byte) error {
  34. type Alias ChannelConfig
  35. alias := (*Alias)(c)
  36. if err := sonic.Unmarshal(data, alias); err != nil {
  37. return err
  38. }
  39. if len(alias.Spec) > 0 {
  40. var spec map[string]any
  41. if err := sonic.Unmarshal(alias.Spec, &spec); err != nil {
  42. return fmt.Errorf("invalid spec json: %w", err)
  43. }
  44. }
  45. return nil
  46. }
  47. func (c *ChannelConfig) SpecConfig(obj any) error {
  48. if c == nil || len(c.Spec) == 0 {
  49. return nil
  50. }
  51. return sonic.Unmarshal(c.Spec, obj)
  52. }
  53. func (c *ChannelConfig) Get(key ...any) (ast.Node, error) {
  54. if c == nil || len(c.Spec) == 0 {
  55. return ast.Node{}, ast.ErrNotExist
  56. }
  57. return sonic.Get(c.Spec, key...)
  58. }
// Channel is the GORM model for one configured channel. Slice, map and
// config fields are persisted as text through the "fastjson" serializer, and
// MarshalJSON emits the three timestamp fields as Unix milliseconds.
type Channel struct {
	// DeletedAt is GORM's soft-delete column; it is never serialized to JSON.
	DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
	// CreatedAt is the row creation time.
	CreatedAt time.Time `gorm:"index" json:"created_at"`
	// LastTestErrorAt is written by UpdateModelTest on a failed test and reset
	// to NULL by ClearLastTestErrorAt (or by a later successful test).
	LastTestErrorAt time.Time ` json:"last_test_error_at"`
	// ChannelTests are the ChannelTest rows whose ChannelID references ID.
	ChannelTests []*ChannelTest `gorm:"foreignKey:ChannelID;references:ID" json:"channel_tests,omitempty"`
	// BalanceUpdatedAt is set to the current time by UpdateBalance.
	BalanceUpdatedAt time.Time ` json:"balance_updated_at"`
	// ModelMapping presumably maps requested model names to upstream model
	// names — TODO confirm against the relay code.
	ModelMapping map[string]string `gorm:"serializer:fastjson;type:text" json:"model_mapping"`
	// Key is presumably the upstream credential — TODO confirm.
	Key string `gorm:"type:text;index:,length:191" json:"key"`
	// Name is only overwritten by UpdateChannel when it is non-empty.
	Name string `gorm:"size:64;index" json:"name"`
	BaseURL string `gorm:"size:128;index" json:"base_url"`
	// Models must all have a model config (unless model config is disabled);
	// this is checked by BatchInsertChannels and UpdateChannel.
	Models []string `gorm:"serializer:fastjson;type:text" json:"models"`
	// Balance is the value last stored by UpdateBalance.
	Balance float64 ` json:"balance"`
	ID int `gorm:"primaryKey" json:"id"`
	// UsedAmount, RequestCount and RetryCount are incremented atomically in
	// SQL by UpdateChannelUsedAmount.
	UsedAmount float64 `gorm:"index" json:"used_amount"`
	RequestCount int `gorm:"index" json:"request_count"`
	RetryCount int `gorm:"index" json:"retry_count"`
	// Status holds one of the ChannelStatus* codes; the column default is 1.
	Status int `gorm:"default:1;index" json:"status"`
	// Type is only overwritten by UpdateChannel when it is non-zero.
	Type ChannelType `gorm:"default:0;index" json:"type"`
	// Priority of 0 is reported as DefaultPriority by GetPriority.
	Priority int32 ` json:"priority"`
	EnabledAutoBalanceCheck bool ` json:"enabled_auto_balance_check"`
	// BalanceThreshold is clamped to 0 by GetBalanceThreshold when negative.
	BalanceThreshold float64 ` json:"balance_threshold"`
	// Config is optional channel-specific configuration.
	Config *ChannelConfig `gorm:"serializer:fastjson;type:text" json:"config,omitempty"`
	// Sets lists the sets this channel belongs to; empty means the default
	// set (see GetSets).
	Sets []string `gorm:"serializer:fastjson;type:text" json:"sets,omitempty"`
}
  83. func (c *Channel) GetSets() []string {
  84. if len(c.Sets) == 0 {
  85. return []string{ChannelDefaultSet}
  86. }
  87. return c.Sets
  88. }
  89. func (c *Channel) BeforeDelete(tx *gorm.DB) (err error) {
  90. return tx.Model(&ChannelTest{}).Where("channel_id = ?", c.ID).Delete(&ChannelTest{}).Error
  91. }
  92. func (c *Channel) GetBalanceThreshold() float64 {
  93. if c.BalanceThreshold < 0 {
  94. return 0
  95. }
  96. return c.BalanceThreshold
  97. }
  98. const (
  99. DefaultPriority = 10
  100. )
  101. func (c *Channel) GetPriority() int32 {
  102. if c.Priority == 0 {
  103. return DefaultPriority
  104. }
  105. return c.Priority
  106. }
  107. func GetModelConfigWithModels(models []string) ([]string, []string, error) {
  108. if len(models) == 0 || config.DisableModelConfig {
  109. return models, nil, nil
  110. }
  111. where := DB.Model(&ModelConfig{}).Where("model IN ?", models)
  112. var count int64
  113. if err := where.Count(&count).Error; err != nil {
  114. return nil, nil, err
  115. }
  116. if count == 0 {
  117. return nil, models, nil
  118. }
  119. if count == int64(len(models)) {
  120. return models, nil, nil
  121. }
  122. var foundModels []string
  123. if err := where.Pluck("model", &foundModels).Error; err != nil {
  124. return nil, nil, err
  125. }
  126. if len(foundModels) == len(models) {
  127. return models, nil, nil
  128. }
  129. foundModelsMap := make(map[string]struct{}, len(foundModels))
  130. for _, model := range foundModels {
  131. foundModelsMap[model] = struct{}{}
  132. }
  133. if len(models)-len(foundModels) > 0 {
  134. missingModels := make([]string, 0, len(models)-len(foundModels))
  135. for _, model := range models {
  136. if _, exists := foundModelsMap[model]; !exists {
  137. missingModels = append(missingModels, model)
  138. }
  139. }
  140. return foundModels, missingModels, nil
  141. }
  142. return foundModels, nil, nil
  143. }
  144. func CheckModelConfigExist(models []string) error {
  145. _, missingModels, err := GetModelConfigWithModels(models)
  146. if err != nil {
  147. return err
  148. }
  149. if len(missingModels) > 0 {
  150. slices.Sort(missingModels)
  151. return fmt.Errorf("model config not found: %v", missingModels)
  152. }
  153. return nil
  154. }
  155. func (c *Channel) MarshalJSON() ([]byte, error) {
  156. type Alias Channel
  157. return sonic.Marshal(&struct {
  158. *Alias
  159. CreatedAt int64 `json:"created_at"`
  160. BalanceUpdatedAt int64 `json:"balance_updated_at"`
  161. LastTestErrorAt int64 `json:"last_test_error_at"`
  162. }{
  163. Alias: (*Alias)(c),
  164. CreatedAt: c.CreatedAt.UnixMilli(),
  165. BalanceUpdatedAt: c.BalanceUpdatedAt.UnixMilli(),
  166. LastTestErrorAt: c.LastTestErrorAt.UnixMilli(),
  167. })
  168. }
  169. //nolint:goconst
  170. func getChannelOrder(order string) string {
  171. prefix, suffix, _ := strings.Cut(order, "-")
  172. switch prefix {
  173. case "name",
  174. "type",
  175. "created_at",
  176. "status",
  177. "test_at",
  178. "balance_updated_at",
  179. "used_amount",
  180. "request_count",
  181. "priority",
  182. "id":
  183. switch suffix {
  184. case "asc":
  185. return prefix + " asc"
  186. default:
  187. return prefix + " desc"
  188. }
  189. default:
  190. return "id desc"
  191. }
  192. }
  193. func GetAllChannels() (channels []*Channel, err error) {
  194. tx := DB.Model(&Channel{})
  195. err = tx.Order("id desc").Find(&channels).Error
  196. return channels, err
  197. }
  198. func GetChannels(
  199. page, perPage, id int,
  200. name, key string,
  201. channelType int,
  202. baseURL, order string,
  203. ) (channels []*Channel, total int64, err error) {
  204. tx := DB.Model(&Channel{})
  205. if id != 0 {
  206. tx = tx.Where("id = ?", id)
  207. }
  208. if name != "" {
  209. tx = tx.Where("name = ?", name)
  210. }
  211. if key != "" {
  212. tx = tx.Where("key = ?", key)
  213. }
  214. if channelType != 0 {
  215. tx = tx.Where("type = ?", channelType)
  216. }
  217. if baseURL != "" {
  218. tx = tx.Where("base_url = ?", baseURL)
  219. }
  220. err = tx.Count(&total).Error
  221. if err != nil {
  222. return nil, 0, err
  223. }
  224. if total <= 0 {
  225. return nil, 0, nil
  226. }
  227. limit, offset := toLimitOffset(page, perPage)
  228. err = tx.Order(getChannelOrder(order)).Limit(limit).Offset(offset).Find(&channels).Error
  229. return channels, total, err
  230. }
  231. func SearchChannels(
  232. keyword string,
  233. page, perPage, id int,
  234. name, key string,
  235. channelType int,
  236. baseURL, order string,
  237. ) (channels []*Channel, total int64, err error) {
  238. tx := DB.Model(&Channel{})
  239. // Handle exact match conditions for non-zero values
  240. if id != 0 {
  241. tx = tx.Where("id = ?", id)
  242. }
  243. if name != "" {
  244. tx = tx.Where("name = ?", name)
  245. }
  246. if key != "" {
  247. tx = tx.Where("key = ?", key)
  248. }
  249. if channelType != 0 {
  250. tx = tx.Where("type = ?", channelType)
  251. }
  252. if baseURL != "" {
  253. tx = tx.Where("base_url = ?", baseURL)
  254. }
  255. // Handle keyword search for zero value fields
  256. if keyword != "" {
  257. var (
  258. conditions []string
  259. values []any
  260. )
  261. keywordInt := String2Int(keyword)
  262. if keywordInt != 0 {
  263. if id == 0 {
  264. conditions = append(conditions, "id = ?")
  265. values = append(values, keywordInt)
  266. }
  267. }
  268. if name == "" {
  269. if !common.UsingSQLite {
  270. conditions = append(conditions, "name ILIKE ?")
  271. } else {
  272. conditions = append(conditions, "name LIKE ?")
  273. }
  274. values = append(values, "%"+keyword+"%")
  275. }
  276. if key == "" {
  277. if !common.UsingSQLite {
  278. conditions = append(conditions, "key ILIKE ?")
  279. } else {
  280. conditions = append(conditions, "key LIKE ?")
  281. }
  282. values = append(values, "%"+keyword+"%")
  283. }
  284. if baseURL == "" {
  285. if !common.UsingSQLite {
  286. conditions = append(conditions, "base_url ILIKE ?")
  287. } else {
  288. conditions = append(conditions, "base_url LIKE ?")
  289. }
  290. values = append(values, "%"+keyword+"%")
  291. }
  292. if !common.UsingSQLite {
  293. conditions = append(conditions, "models ILIKE ?")
  294. } else {
  295. conditions = append(conditions, "models LIKE ?")
  296. }
  297. values = append(values, "%"+keyword+"%")
  298. if !common.UsingSQLite {
  299. conditions = append(conditions, "sets ILIKE ?")
  300. } else {
  301. conditions = append(conditions, "sets LIKE ?")
  302. }
  303. values = append(values, "%"+keyword+"%")
  304. if len(conditions) > 0 {
  305. tx = tx.Where(fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), values...)
  306. }
  307. }
  308. err = tx.Count(&total).Error
  309. if err != nil {
  310. return nil, 0, err
  311. }
  312. if total <= 0 {
  313. return nil, 0, nil
  314. }
  315. limit, offset := toLimitOffset(page, perPage)
  316. err = tx.Order(getChannelOrder(order)).Limit(limit).Offset(offset).Find(&channels).Error
  317. return channels, total, err
  318. }
  319. func GetChannelByID(id int) (*Channel, error) {
  320. channel := Channel{ID: id}
  321. err := DB.First(&channel, "id = ?", id).Error
  322. return &channel, HandleNotFound(err, ErrChannelNotFound)
  323. }
  324. func BatchInsertChannels(channels []*Channel) (err error) {
  325. defer func() {
  326. if err == nil {
  327. _ = InitModelConfigAndChannelCache()
  328. }
  329. }()
  330. for _, channel := range channels {
  331. if err := CheckModelConfigExist(channel.Models); err != nil {
  332. return err
  333. }
  334. }
  335. return DB.Transaction(func(tx *gorm.DB) error {
  336. return tx.Create(&channels).Error
  337. })
  338. }
  339. func UpdateChannel(channel *Channel) (err error) {
  340. defer func() {
  341. if err == nil {
  342. _ = InitModelConfigAndChannelCache()
  343. _ = monitor.ClearChannelAllModelErrors(context.Background(), channel.ID)
  344. }
  345. }()
  346. if err := CheckModelConfigExist(channel.Models); err != nil {
  347. return err
  348. }
  349. selects := []string{
  350. "model_mapping",
  351. "key",
  352. "base_url",
  353. "models",
  354. "priority",
  355. "config",
  356. "enabled_auto_balance_check",
  357. "balance_threshold",
  358. "sets",
  359. }
  360. if channel.Type != 0 {
  361. selects = append(selects, "type")
  362. }
  363. if channel.Name != "" {
  364. selects = append(selects, "name")
  365. }
  366. result := DB.
  367. Select(selects).
  368. Clauses(clause.Returning{}).
  369. Where("id = ?", channel.ID).
  370. Updates(channel)
  371. return HandleUpdateResult(result, ErrChannelNotFound)
  372. }
  373. func ClearLastTestErrorAt(id int) error {
  374. result := DB.Model(&Channel{}).
  375. Where("id = ?", id).
  376. Update("last_test_error_at", gorm.Expr("NULL"))
  377. return HandleUpdateResult(result, ErrChannelNotFound)
  378. }
  379. func (c *Channel) UpdateModelTest(
  380. testAt time.Time,
  381. model, actualModel string,
  382. mode mode.Mode,
  383. took float64,
  384. success bool,
  385. response string,
  386. code int,
  387. ) (*ChannelTest, error) {
  388. var ct *ChannelTest
  389. err := DB.Transaction(func(tx *gorm.DB) error {
  390. if !success {
  391. result := tx.Model(&Channel{}).
  392. Where("id = ?", c.ID).
  393. Update("last_test_error_at", testAt)
  394. if err := HandleUpdateResult(result, ErrChannelNotFound); err != nil {
  395. return err
  396. }
  397. } else if !c.LastTestErrorAt.IsZero() && time.Since(c.LastTestErrorAt) > time.Hour {
  398. result := tx.Model(&Channel{}).Where("id = ?", c.ID).Update("last_test_error_at", gorm.Expr("NULL"))
  399. if err := HandleUpdateResult(result, ErrChannelNotFound); err != nil {
  400. return err
  401. }
  402. }
  403. ct = &ChannelTest{
  404. ChannelID: c.ID,
  405. ChannelType: c.Type,
  406. ChannelName: c.Name,
  407. Model: model,
  408. ActualModel: actualModel,
  409. Mode: mode,
  410. TestAt: testAt,
  411. Took: took,
  412. Success: success,
  413. Response: response,
  414. Code: code,
  415. }
  416. result := tx.Save(ct)
  417. return HandleUpdateResult(result, ErrChannelNotFound)
  418. })
  419. if err != nil {
  420. return nil, err
  421. }
  422. return ct, nil
  423. }
  424. func (c *Channel) UpdateBalance(balance float64) error {
  425. result := DB.Model(&Channel{}).
  426. Select("balance_updated_at", "balance").
  427. Where("id = ?", c.ID).
  428. Updates(Channel{
  429. BalanceUpdatedAt: time.Now(),
  430. Balance: balance,
  431. })
  432. return HandleUpdateResult(result, ErrChannelNotFound)
  433. }
  434. func DeleteChannelByID(id int) (err error) {
  435. defer func() {
  436. if err == nil {
  437. _ = InitModelConfigAndChannelCache()
  438. _ = monitor.ClearChannelAllModelErrors(context.Background(), id)
  439. }
  440. }()
  441. result := DB.Delete(&Channel{ID: id})
  442. return HandleUpdateResult(result, ErrChannelNotFound)
  443. }
  444. func DeleteChannelsByIDs(ids []int) (err error) {
  445. defer func() {
  446. if err == nil {
  447. _ = InitModelConfigAndChannelCache()
  448. for _, id := range ids {
  449. _ = monitor.ClearChannelAllModelErrors(context.Background(), id)
  450. }
  451. }
  452. }()
  453. return DB.Transaction(func(tx *gorm.DB) error {
  454. return tx.
  455. Where("id IN (?)", ids).
  456. Delete(&Channel{}).
  457. Error
  458. })
  459. }
  460. func UpdateChannelStatusByID(id, status int) error {
  461. result := DB.Model(&Channel{}).
  462. Where("id = ?", id).
  463. Update("status", status)
  464. return HandleUpdateResult(result, ErrChannelNotFound)
  465. }
  466. func UpdateChannelUsedAmount(id int, amount float64, requestCount, retryCount int) error {
  467. result := DB.Model(&Channel{}).
  468. Where("id = ?", id).
  469. Updates(map[string]any{
  470. "used_amount": gorm.Expr("used_amount + ?", amount),
  471. "request_count": gorm.Expr("request_count + ?", requestCount),
  472. "retry_count": gorm.Expr("retry_count + ?", retryCount),
  473. })
  474. return HandleUpdateResult(result, ErrChannelNotFound)
  475. }