main.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. package model
import (
	"errors"
	"fmt"
	"log"
	"os"
	"strings"
	"sync"
	"time"

	"one-api/common"

	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)
  14. var DB *gorm.DB
  15. func createRootAccountIfNeed() error {
  16. var user User
  17. //if user.Status != common.UserStatusEnabled {
  18. if err := DB.First(&user).Error; err != nil {
  19. common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
  20. hashedPassword, err := common.Password2Hash("123456")
  21. if err != nil {
  22. return err
  23. }
  24. rootUser := User{
  25. Username: "root",
  26. Password: hashedPassword,
  27. Role: common.RoleRootUser,
  28. Status: common.UserStatusEnabled,
  29. DisplayName: "Root User",
  30. AccessToken: common.GetUUID(),
  31. Quota: 100000000,
  32. }
  33. DB.Create(&rootUser)
  34. }
  35. return nil
  36. }
  37. func chooseDB() (*gorm.DB, error) {
  38. if os.Getenv("SQL_DSN") != "" {
  39. dsn := os.Getenv("SQL_DSN")
  40. if strings.HasPrefix(dsn, "postgres://") {
  41. // Use PostgreSQL
  42. common.SysLog("using PostgreSQL as database")
  43. common.UsingPostgreSQL = true
  44. return gorm.Open(postgres.New(postgres.Config{
  45. DSN: dsn,
  46. PreferSimpleProtocol: true, // disables implicit prepared statement usage
  47. }), &gorm.Config{
  48. PrepareStmt: true, // precompile SQL
  49. })
  50. }
  51. // Use MySQL
  52. common.SysLog("using MySQL as database")
  53. // check parseTime
  54. if !strings.Contains(dsn, "parseTime") {
  55. if strings.Contains(dsn, "?") {
  56. dsn += "&parseTime=true"
  57. } else {
  58. dsn += "?parseTime=true"
  59. }
  60. }
  61. common.UsingMySQL = true
  62. return gorm.Open(mysql.Open(dsn), &gorm.Config{
  63. PrepareStmt: true, // precompile SQL
  64. })
  65. }
  66. // Use SQLite
  67. common.SysLog("SQL_DSN not set, using SQLite as database")
  68. common.UsingSQLite = true
  69. return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
  70. PrepareStmt: true, // precompile SQL
  71. })
  72. }
  73. func InitDB() (err error) {
  74. db, err := chooseDB()
  75. if err == nil {
  76. if common.DebugEnabled {
  77. db = db.Debug()
  78. }
  79. DB = db
  80. sqlDB, err := DB.DB()
  81. if err != nil {
  82. return err
  83. }
  84. sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
  85. sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
  86. sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
  87. if !common.IsMasterNode {
  88. return nil
  89. }
  90. //if common.UsingMySQL {
  91. // _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
  92. // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
  93. // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
  94. // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
  95. //}
  96. common.SysLog("database migration started")
  97. err = db.AutoMigrate(&Channel{})
  98. if err != nil {
  99. return err
  100. }
  101. err = db.AutoMigrate(&Token{})
  102. if err != nil {
  103. return err
  104. }
  105. err = db.AutoMigrate(&User{})
  106. if err != nil {
  107. return err
  108. }
  109. err = db.AutoMigrate(&Option{})
  110. if err != nil {
  111. return err
  112. }
  113. err = db.AutoMigrate(&Redemption{})
  114. if err != nil {
  115. return err
  116. }
  117. err = db.AutoMigrate(&Ability{})
  118. if err != nil {
  119. return err
  120. }
  121. err = db.AutoMigrate(&Log{})
  122. if err != nil {
  123. return err
  124. }
  125. err = db.AutoMigrate(&Midjourney{})
  126. if err != nil {
  127. return err
  128. }
  129. err = db.AutoMigrate(&TopUp{})
  130. if err != nil {
  131. return err
  132. }
  133. err = db.AutoMigrate(&QuotaData{})
  134. if err != nil {
  135. return err
  136. }
  137. err = db.AutoMigrate(&Task{})
  138. if err != nil {
  139. return err
  140. }
  141. common.SysLog("database migrated")
  142. err = createRootAccountIfNeed()
  143. return err
  144. } else {
  145. common.FatalLog(err)
  146. }
  147. return err
  148. }
  149. func CloseDB() error {
  150. sqlDB, err := DB.DB()
  151. if err != nil {
  152. return err
  153. }
  154. err = sqlDB.Close()
  155. return err
  156. }
  157. var (
  158. lastPingTime time.Time
  159. pingMutex sync.Mutex
  160. )
  161. func PingDB() error {
  162. pingMutex.Lock()
  163. defer pingMutex.Unlock()
  164. if time.Since(lastPingTime) < time.Second*10 {
  165. return nil
  166. }
  167. sqlDB, err := DB.DB()
  168. if err != nil {
  169. log.Printf("Error getting sql.DB from GORM: %v", err)
  170. return err
  171. }
  172. err = sqlDB.Ping()
  173. if err != nil {
  174. log.Printf("Error pinging DB: %v", err)
  175. return err
  176. }
  177. lastPingTime = time.Now()
  178. common.SysLog("Database pinged successfully")
  179. return nil
  180. }