// main.go
  1. package model
  2. import (
  3. "fmt"
  4. "os"
  5. "path/filepath"
  6. "strings"
  7. "time"
  8. "github.com/glebarez/sqlite"
  9. "github.com/labring/aiproxy/core/common"
  10. "github.com/labring/aiproxy/core/common/config"
  11. "github.com/labring/aiproxy/core/common/env"
  12. // import fastjson serializer
  13. _ "github.com/labring/aiproxy/core/common/fastJSONSerializer"
  14. "github.com/labring/aiproxy/core/common/notify"
  15. log "github.com/sirupsen/logrus"
  16. "gorm.io/driver/mysql"
  17. "gorm.io/driver/postgres"
  18. "gorm.io/gorm"
  19. gormLogger "gorm.io/gorm/logger"
  20. )
var (
	// DB is the primary database handle, initialized by InitDB.
	DB *gorm.DB
	// LogDB is the handle used for log tables; it aliases DB unless a
	// separate LOG_SQL_DSN is configured (see InitLogDB).
	LogDB *gorm.DB
)
  25. func chooseDB(envName string) (*gorm.DB, error) {
  26. dsn := os.Getenv(envName)
  27. switch {
  28. case strings.HasPrefix(dsn, "postgres"):
  29. // Use PostgreSQL
  30. log.Info("using PostgreSQL as database")
  31. return OpenPostgreSQL(dsn)
  32. default:
  33. // Use SQLite
  34. absPath, err := filepath.Abs(common.SQLitePath)
  35. if err != nil {
  36. return nil, fmt.Errorf("failed to get absolute path of SQLite database: %w", err)
  37. }
  38. log.Info("SQL_DSN not set, using SQLite as database: ", absPath)
  39. common.UsingSQLite = true
  40. return OpenSQLite(absPath)
  41. }
  42. }
  43. func newDBLogger() gormLogger.Interface {
  44. var logLevel gormLogger.LogLevel
  45. if config.DebugSQLEnabled {
  46. logLevel = gormLogger.Info
  47. } else {
  48. logLevel = gormLogger.Warn
  49. }
  50. return gormLogger.New(
  51. log.StandardLogger(),
  52. gormLogger.Config{
  53. SlowThreshold: time.Second,
  54. LogLevel: logLevel,
  55. IgnoreRecordNotFoundError: true,
  56. ParameterizedQueries: !config.DebugSQLEnabled,
  57. Colorful: common.NeedColor(),
  58. },
  59. )
  60. }
  61. func OpenPostgreSQL(dsn string) (*gorm.DB, error) {
  62. return gorm.Open(postgres.New(postgres.Config{
  63. DSN: dsn,
  64. PreferSimpleProtocol: true, // disables implicit prepared statement usage
  65. }), &gorm.Config{
  66. PrepareStmt: true, // precompile SQL
  67. TranslateError: true,
  68. Logger: newDBLogger(),
  69. DisableForeignKeyConstraintWhenMigrating: false,
  70. IgnoreRelationshipsWhenMigrating: false,
  71. })
  72. }
  73. func OpenMySQL(dsn string) (*gorm.DB, error) {
  74. return gorm.Open(mysql.New(mysql.Config{
  75. DSN: strings.TrimPrefix(dsn, "mysql://"),
  76. }), &gorm.Config{
  77. PrepareStmt: true, // precompile SQL
  78. TranslateError: true,
  79. Logger: newDBLogger(),
  80. DisableForeignKeyConstraintWhenMigrating: false,
  81. IgnoreRelationshipsWhenMigrating: false,
  82. })
  83. }
  84. func OpenSQLite(sqlitePath string) (*gorm.DB, error) {
  85. baseDir := filepath.Dir(sqlitePath)
  86. if err := os.MkdirAll(baseDir, 0o755); err != nil {
  87. return nil, fmt.Errorf("failed to create base directory: %w", err)
  88. }
  89. dsn := fmt.Sprintf("%s?_busy_timeout=%d", sqlitePath, common.SQLiteBusyTimeout)
  90. return gorm.Open(sqlite.Open(dsn), &gorm.Config{
  91. PrepareStmt: true, // precompile SQL
  92. TranslateError: true,
  93. Logger: newDBLogger(),
  94. DisableForeignKeyConstraintWhenMigrating: false,
  95. IgnoreRelationshipsWhenMigrating: false,
  96. })
  97. }
  98. func InitDB() error {
  99. var err error
  100. DB, err = chooseDB("SQL_DSN")
  101. if err != nil {
  102. return fmt.Errorf("failed to initialize database: %w", err)
  103. }
  104. setDBConns(DB)
  105. if config.DisableAutoMigrateDB {
  106. return nil
  107. }
  108. log.Info("database migration started")
  109. if err = migrateDB(); err != nil {
  110. log.Fatal("failed to migrate database: " + err.Error())
  111. return fmt.Errorf("failed to migrate database: %w", err)
  112. }
  113. log.Info("database migrated")
  114. return nil
  115. }
  116. func migrateDB() error {
  117. err := DB.AutoMigrate(
  118. &Channel{},
  119. &ChannelTest{},
  120. &Token{},
  121. &PublicMCP{},
  122. &GroupModelConfig{},
  123. &PublicMCPReusingParam{},
  124. &GroupMCP{},
  125. &Group{},
  126. &Option{},
  127. &ModelConfig{},
  128. )
  129. if err != nil {
  130. return err
  131. }
  132. return nil
  133. }
  134. func InitLogDB(batchSize int) error {
  135. if os.Getenv("LOG_SQL_DSN") == "" {
  136. LogDB = DB
  137. } else {
  138. log.Info("using log database for table logs")
  139. var err error
  140. LogDB, err = chooseDB("LOG_SQL_DSN")
  141. if err != nil {
  142. return fmt.Errorf("failed to initialize log database: %w", err)
  143. }
  144. setDBConns(LogDB)
  145. }
  146. if config.DisableAutoMigrateDB {
  147. return nil
  148. }
  149. log.Info("log database migration started")
  150. err := migrateLogDB(batchSize)
  151. if err != nil {
  152. // ignore migrate log error when use double database
  153. if LogDB == DB {
  154. return fmt.Errorf("failed to migrate log database: %w", err)
  155. }
  156. log.Errorf("failed to migrate log database: %v", err)
  157. log.Warn("log database migration with backend started")
  158. go migrateLogDBBackend(batchSize)
  159. } else {
  160. log.Info("log database migrated")
  161. }
  162. return nil
  163. }
  164. func migrateLogDBBackend(batchSize int) {
  165. ticker := time.NewTicker(time.Minute)
  166. defer ticker.Stop()
  167. for range ticker.C {
  168. err := migrateLogDB(batchSize)
  169. if err == nil {
  170. return
  171. }
  172. log.Errorf("failed to migrate log database: %v", err)
  173. ticker.Reset(time.Minute)
  174. }
  175. }
  176. func migrateLogDB(batchSize int) error {
  177. // Pre-migration cleanup to remove expired data
  178. err := preMigrationCleanup(batchSize)
  179. if err != nil {
  180. log.Warn("failed to perform pre-migration cleanup: ", err.Error())
  181. }
  182. err = LogDB.AutoMigrate(
  183. &Log{},
  184. &RequestDetail{},
  185. &RetryLog{},
  186. &GroupSummary{},
  187. &Summary{},
  188. &ConsumeError{},
  189. &StoreV2{},
  190. &SummaryMinute{},
  191. &GroupSummaryMinute{},
  192. )
  193. if err != nil {
  194. return err
  195. }
  196. go func() {
  197. err := CreateLogIndexes(LogDB)
  198. if err != nil {
  199. notify.ErrorThrottle(
  200. "createLogIndexes",
  201. time.Minute*10,
  202. "failed to create log indexes",
  203. err.Error(),
  204. )
  205. }
  206. err = CreateSummaryIndexs(LogDB)
  207. if err != nil {
  208. notify.ErrorThrottle(
  209. "createSummaryIndexs",
  210. time.Minute*10,
  211. "failed to create summary indexs",
  212. err.Error(),
  213. )
  214. }
  215. err = CreateGroupSummaryIndexs(LogDB)
  216. if err != nil {
  217. notify.ErrorThrottle(
  218. "createGroupSummaryIndexs",
  219. time.Minute*10,
  220. "failed to create group summary indexs",
  221. err.Error(),
  222. )
  223. }
  224. err = CreateSummaryMinuteIndexs(LogDB)
  225. if err != nil {
  226. notify.ErrorThrottle(
  227. "createSummaryMinuteIndexs",
  228. time.Minute*10,
  229. "failed to create summary minute indexs",
  230. err.Error(),
  231. )
  232. }
  233. err = CreateGroupSummaryMinuteIndexs(LogDB)
  234. if err != nil {
  235. notify.ErrorThrottle(
  236. "createSummaryMinuteIndexs",
  237. time.Minute*10,
  238. "failed to create group summary minute indexs",
  239. err.Error(),
  240. )
  241. }
  242. }()
  243. return nil
  244. }
  245. func setDBConns(db *gorm.DB) {
  246. if config.DebugSQLEnabled {
  247. db = db.Debug()
  248. }
  249. sqlDB, err := db.DB()
  250. if err != nil {
  251. log.Fatal("failed to connect database: " + err.Error())
  252. return
  253. }
  254. sqlDB.SetMaxIdleConns(int(env.Int64("SQL_MAX_IDLE_CONNS", 100)))
  255. sqlDB.SetMaxOpenConns(int(env.Int64("SQL_MAX_OPEN_CONNS", 1000)))
  256. sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int64("SQL_MAX_LIFETIME", 60)))
  257. }
  258. func closeDB(db *gorm.DB) error {
  259. sqlDB, err := db.DB()
  260. if err != nil {
  261. return err
  262. }
  263. err = sqlDB.Close()
  264. return err
  265. }
  266. func CloseDB() error {
  267. if LogDB != DB {
  268. err := closeDB(LogDB)
  269. if err != nil {
  270. return err
  271. }
  272. }
  273. return closeDB(DB)
  274. }
  275. func ignoreNoSuchTable(err error) bool {
  276. message := err.Error()
  277. return strings.Contains(message, "no such table") ||
  278. strings.Contains(message, "does not exist")
  279. }
  280. // preMigrationCleanup cleans up expired logs and request details before migration
  281. // to reduce database size and improve migration performance
  282. func preMigrationCleanup(batchSize int) error {
  283. log.Info("starting pre-migration cleanup of expired data")
  284. // Clean up logs
  285. err := preMigrationCleanupLogs(batchSize)
  286. if err != nil {
  287. if ignoreNoSuchTable(err) {
  288. return nil
  289. }
  290. return fmt.Errorf("failed to cleanup logs: %w", err)
  291. }
  292. // Clean up retry logs
  293. err = preMigrationCleanupRetryLogs(batchSize)
  294. if err != nil {
  295. if ignoreNoSuchTable(err) {
  296. return nil
  297. }
  298. return fmt.Errorf("failed to cleanup retry logs: %w", err)
  299. }
  300. // Clean up request details
  301. err = preMigrationCleanupRequestDetails(batchSize)
  302. if err != nil {
  303. if ignoreNoSuchTable(err) {
  304. return nil
  305. }
  306. return fmt.Errorf("failed to cleanup request details: %w", err)
  307. }
  308. log.Info("pre-migration cleanup completed")
  309. return nil
  310. }
// preMigrationCleanupLogs deletes Log rows whose created_at is older than the
// configured retention window (config.GetLogStorageHours) using ID-based
// batch deletion: each pass selects at most batchSize IDs, then deletes
// exactly those rows, so no single DELETE statement is unbounded.
// A retention of 0 disables cleanup entirely.
func preMigrationCleanupLogs(batchSize int) error {
	logStorageHours := config.GetLogStorageHours()
	if logStorageHours == 0 {
		// Retention disabled: keep everything.
		return nil
	}

	if batchSize <= 0 {
		batchSize = defaultCleanLogBatchSize
	}

	cutoffTime := time.Now().Add(-time.Duration(logStorageHours) * time.Hour)

	// First, get the IDs to delete; the slice is reused across passes.
	ids := make([]int, 0, batchSize)
	for {
		ids = ids[:0]

		err := LogDB.Model(&Log{}).
			Select("id").
			Where("created_at < ?", cutoffTime).
			Limit(batchSize).
			Find(&ids).Error
		if err != nil {
			return err
		}

		// If no IDs found, we're done.
		if len(ids) == 0 {
			break
		}

		// Delete by IDs; skipping the default transaction keeps each
		// batch delete a single statement.
		err = LogDB.Where("id IN (?)", ids).
			Session(&gorm.Session{SkipDefaultTransaction: true}).
			Delete(&Log{}).Error
		if err != nil {
			return err
		}

		log.Infof("deleted %d expired log records", len(ids))

		// If we got less than batchSize, the table is exhausted.
		if len(ids) < batchSize {
			break
		}
	}

	return nil
}
  352. // preMigrationCleanupRetryLogs cleans up expired logs using ID-based batch deletion
  353. func preMigrationCleanupRetryLogs(batchSize int) error {
  354. logStorageHours := config.GetRetryLogStorageHours()
  355. if logStorageHours == 0 {
  356. logStorageHours = config.GetLogStorageHours()
  357. }
  358. if logStorageHours == 0 {
  359. return nil
  360. }
  361. if batchSize <= 0 {
  362. batchSize = defaultCleanLogBatchSize
  363. }
  364. cutoffTime := time.Now().Add(-time.Duration(logStorageHours) * time.Hour)
  365. // First, get the IDs to delete
  366. ids := make([]int, 0, batchSize)
  367. for {
  368. ids = ids[:0]
  369. err := LogDB.Model(&RetryLog{}).
  370. Select("id").
  371. Where("created_at < ?", cutoffTime).
  372. Limit(batchSize).
  373. Find(&ids).Error
  374. if err != nil {
  375. return err
  376. }
  377. // If no IDs found, we're done
  378. if len(ids) == 0 {
  379. break
  380. }
  381. // Delete by IDs
  382. err = LogDB.Where("id IN (?)", ids).
  383. Session(&gorm.Session{SkipDefaultTransaction: true}).
  384. Delete(&Log{}).Error
  385. if err != nil {
  386. return err
  387. }
  388. log.Infof("deleted %d expired retry log records", len(ids))
  389. // If we got less than batchSize, we're done
  390. if len(ids) < batchSize {
  391. break
  392. }
  393. }
  394. return nil
  395. }
  396. // preMigrationCleanupRequestDetails cleans up expired request details using ID-based batch deletion
  397. func preMigrationCleanupRequestDetails(batchSize int) error {
  398. detailStorageHours := config.GetLogDetailStorageHours()
  399. if detailStorageHours == 0 {
  400. detailStorageHours = config.GetLogStorageHours()
  401. }
  402. if detailStorageHours == 0 {
  403. return nil
  404. }
  405. if batchSize <= 0 {
  406. batchSize = defaultCleanLogBatchSize
  407. }
  408. cutoffTime := time.Now().Add(-time.Duration(detailStorageHours) * time.Hour)
  409. // First, get the IDs to delete
  410. ids := make([]int, 0, batchSize)
  411. for {
  412. ids = ids[:0]
  413. err := LogDB.Model(&RequestDetail{}).
  414. Select("id").
  415. Where("created_at < ?", cutoffTime).
  416. Limit(batchSize).
  417. Find(&ids).Error
  418. if err != nil {
  419. return err
  420. }
  421. // If no IDs found, we're done
  422. if len(ids) == 0 {
  423. break
  424. }
  425. // Delete by IDs
  426. err = LogDB.Where("id IN (?)", ids).
  427. Session(&gorm.Session{SkipDefaultTransaction: true}).
  428. Delete(&RequestDetail{}).Error
  429. if err != nil {
  430. return err
  431. }
  432. log.Infof("deleted %d expired request detail records", len(ids))
  433. // If we got less than batchSize, we're done
  434. if len(ids) < batchSize {
  435. break
  436. }
  437. }
  438. return nil
  439. }