sqlite.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. package dataprovider
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "path/filepath"
  6. "strings"
  7. "github.com/drakkan/sftpgo/logger"
  8. "github.com/drakkan/sftpgo/utils"
  9. )
const (
	// sqliteUsersTableSQL creates the initial users table. The "{{users}}"
	// placeholder is replaced with config.UsersTable before the statement is
	// executed (see initializeDatabase).
	sqliteUsersTableSQL = `CREATE TABLE "{{users}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "username" varchar(255)
NOT NULL UNIQUE, "password" varchar(255) NULL, "public_keys" text NULL, "home_dir" varchar(255) NOT NULL, "uid" integer NOT NULL,
"gid" integer NOT NULL, "max_sessions" integer NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL,
"permissions" text NOT NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL,
"last_quota_update" bigint NOT NULL, "upload_bandwidth" integer NOT NULL, "download_bandwidth" integer NOT NULL,
"expiration_date" bigint NOT NULL, "last_login" bigint NOT NULL, "status" integer NOT NULL, "filters" text NULL,
"filesystem" text NULL);`
	// sqliteSchemaTableSQL creates the "schema_version" table, which has an
	// auto-incremented "id" and an integer "version" column.
	sqliteSchemaTableSQL = `CREATE TABLE "schema_version" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "version" integer NOT NULL);`
	// sqliteUsersV2SQL is the version 1 -> 2 migration: it adds the nullable
	// "virtual_folders" text column to the users table.
	sqliteUsersV2SQL = `ALTER TABLE "{{users}}" ADD COLUMN "virtual_folders" text NULL;`
	// sqliteUsersV3SQL is the version 2 -> 3 migration. It rebuilds the users
	// table: a "new__users" table is created in which "password" is declared
	// as text (it is varchar(255) in sqliteUsersTableSQL), all rows are copied
	// over, the old table is dropped and the new one is renamed to take its
	// place. "{{users}}" occurs several times, so every occurrence must be
	// replaced.
	sqliteUsersV3SQL = `CREATE TABLE "new__users" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "username" varchar(255) NOT NULL UNIQUE,
"password" text NULL, "public_keys" text NULL, "home_dir" varchar(255) NOT NULL, "uid" integer NOT NULL,
"gid" integer NOT NULL, "max_sessions" integer NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL,
"permissions" text NOT NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL,
"upload_bandwidth" integer NOT NULL, "download_bandwidth" integer NOT NULL, "expiration_date" bigint NOT NULL, "last_login" bigint NOT NULL,
"status" integer NOT NULL, "filters" text NULL, "filesystem" text NULL, "virtual_folders" text NULL);
INSERT INTO "new__users" ("id", "username", "public_keys", "home_dir", "uid", "gid", "max_sessions", "quota_size", "quota_files",
"permissions", "used_quota_size", "used_quota_files", "last_quota_update", "upload_bandwidth", "download_bandwidth", "expiration_date",
"last_login", "status", "filters", "filesystem", "virtual_folders", "password") SELECT "id", "username", "public_keys", "home_dir",
"uid", "gid", "max_sessions", "quota_size", "quota_files", "permissions", "used_quota_size", "used_quota_files", "last_quota_update",
"upload_bandwidth", "download_bandwidth", "expiration_date", "last_login", "status", "filters", "filesystem", "virtual_folders",
"password" FROM "{{users}}";
DROP TABLE "{{users}}";
ALTER TABLE "new__users" RENAME TO "{{users}}";`
)
// SQLiteProvider is the auth provider backed by a SQLite database.
type SQLiteProvider struct {
	// dbHandle is the database handle the provider's methods pass to the
	// shared sqlCommon* helpers.
	dbHandle *sql.DB
}
  39. func initializeSQLiteProvider(basePath string) error {
  40. var err error
  41. var connectionString string
  42. logSender = fmt.Sprintf("dataprovider_%v", SQLiteDataProviderName)
  43. if len(config.ConnectionString) == 0 {
  44. dbPath := config.Name
  45. if !utils.IsFileInputValid(dbPath) {
  46. return fmt.Errorf("Invalid database path: %#v", dbPath)
  47. }
  48. if !filepath.IsAbs(dbPath) {
  49. dbPath = filepath.Join(basePath, dbPath)
  50. }
  51. connectionString = fmt.Sprintf("file:%v?cache=shared", dbPath)
  52. } else {
  53. connectionString = config.ConnectionString
  54. }
  55. dbHandle, err := sql.Open("sqlite3", connectionString)
  56. if err == nil {
  57. providerLog(logger.LevelDebug, "sqlite database handle created, connection string: %#v", connectionString)
  58. dbHandle.SetMaxOpenConns(1)
  59. provider = SQLiteProvider{dbHandle: dbHandle}
  60. } else {
  61. providerLog(logger.LevelWarn, "error creating sqlite database handler, connection string: %#v, error: %v",
  62. connectionString, err)
  63. }
  64. return err
  65. }
// checkAvailability delegates to sqlCommonCheckAvailability using the
// provider's database handle and returns its error.
func (p SQLiteProvider) checkAvailability() error {
	return sqlCommonCheckAvailability(p.dbHandle)
}
// validateUserAndPass delegates to sqlCommonValidateUserAndPass, passing the
// given username and password together with the provider's database handle,
// and returns its result unchanged.
func (p SQLiteProvider) validateUserAndPass(username string, password string) (User, error) {
	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
}
// validateUserAndPubKey delegates to sqlCommonValidateUserAndPubKey, passing
// the given username and public key together with the provider's database
// handle, and returns its three results unchanged.
func (p SQLiteProvider) validateUserAndPubKey(username string, publicKey []byte) (User, string, error) {
	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
}
// getUserByID delegates to sqlCommonGetUserByID with the given ID and the
// provider's database handle and returns its result unchanged.
func (p SQLiteProvider) getUserByID(ID int64) (User, error) {
	return sqlCommonGetUserByID(ID, p.dbHandle)
}
// updateQuota delegates to sqlCommonUpdateQuota, passing username, filesAdd,
// sizeAdd and reset through unchanged together with the provider's database
// handle.
func (p SQLiteProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
	return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
}
// updateLastLogin delegates to sqlCommonUpdateLastLogin for the given
// username using the provider's database handle.
func (p SQLiteProvider) updateLastLogin(username string) error {
	return sqlCommonUpdateLastLogin(username, p.dbHandle)
}
// getUsedQuota delegates to sqlCommonGetUsedQuota for the given username
// using the provider's database handle and returns its results unchanged.
func (p SQLiteProvider) getUsedQuota(username string) (int, int64, error) {
	return sqlCommonGetUsedQuota(username, p.dbHandle)
}
// userExists delegates to sqlCommonCheckUserExists for the given username
// using the provider's database handle and returns its result unchanged.
func (p SQLiteProvider) userExists(username string) (User, error) {
	return sqlCommonCheckUserExists(username, p.dbHandle)
}
// addUser delegates to sqlCommonAddUser with the given user and the
// provider's database handle.
func (p SQLiteProvider) addUser(user User) error {
	return sqlCommonAddUser(user, p.dbHandle)
}
// updateUser delegates to sqlCommonUpdateUser with the given user and the
// provider's database handle.
func (p SQLiteProvider) updateUser(user User) error {
	return sqlCommonUpdateUser(user, p.dbHandle)
}
// deleteUser delegates to sqlCommonDeleteUser with the given user and the
// provider's database handle.
func (p SQLiteProvider) deleteUser(user User) error {
	return sqlCommonDeleteUser(user, p.dbHandle)
}
// dumpUsers delegates to sqlCommonDumpUsers using the provider's database
// handle and returns the resulting users and error unchanged.
func (p SQLiteProvider) dumpUsers() ([]User, error) {
	return sqlCommonDumpUsers(p.dbHandle)
}
// getUsers delegates to sqlCommonGetUsers, passing limit, offset, order and
// username through unchanged together with the provider's database handle.
func (p SQLiteProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
	return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle)
}
// close closes the provider's database handle and returns the error reported
// by sql.DB.Close.
func (p SQLiteProvider) close() error {
	return p.dbHandle.Close()
}
// reloadConfig does nothing for the SQLite provider and always returns nil.
func (p SQLiteProvider) reloadConfig() error {
	return nil
}
  111. // initializeDatabase creates the initial database structure
  112. func (p SQLiteProvider) initializeDatabase() error {
  113. sqlUsers := strings.Replace(sqliteUsersTableSQL, "{{users}}", config.UsersTable, 1)
  114. sql := sqlUsers + " " + sqliteSchemaTableSQL + " " + initialDBVersionSQL
  115. _, err := p.dbHandle.Exec(sql)
  116. return err
  117. }
  118. func (p SQLiteProvider) migrateDatabase() error {
  119. dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle)
  120. if err != nil {
  121. return err
  122. }
  123. if dbVersion.Version == sqlDatabaseVersion {
  124. providerLog(logger.LevelDebug, "sql database is updated, current version: %v", dbVersion.Version)
  125. return nil
  126. }
  127. switch dbVersion.Version {
  128. case 1:
  129. err = updateSQLiteDatabaseFrom1To2(p.dbHandle)
  130. if err != nil {
  131. return err
  132. }
  133. return updateSQLiteDatabaseFrom2To3(p.dbHandle)
  134. case 2:
  135. return updateSQLiteDatabaseFrom2To3(p.dbHandle)
  136. default:
  137. return fmt.Errorf("Database version not handled: %v", dbVersion.Version)
  138. }
  139. }
  140. func updateSQLiteDatabaseFrom1To2(dbHandle *sql.DB) error {
  141. providerLog(logger.LevelInfo, "updating database version: 1 -> 2")
  142. sql := strings.Replace(sqliteUsersV2SQL, "{{users}}", config.UsersTable, 1)
  143. _, err := dbHandle.Exec(sql)
  144. if err != nil {
  145. return err
  146. }
  147. return sqlCommonUpdateDatabaseVersion(dbHandle, 2)
  148. }
  149. func updateSQLiteDatabaseFrom2To3(dbHandle *sql.DB) error {
  150. providerLog(logger.LevelInfo, "updating database version: 2 -> 3")
  151. sql := strings.ReplaceAll(sqliteUsersV3SQL, "{{users}}", config.UsersTable)
  152. _, err := dbHandle.Exec(sql)
  153. if err != nil {
  154. return err
  155. }
  156. return sqlCommonUpdateDatabaseVersion(dbHandle, 3)
  157. }