sqlcommon.go 33 KB


  1. package dataprovider
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "errors"
  7. "strings"
  8. "time"
  9. "github.com/drakkan/sftpgo/logger"
  10. "github.com/drakkan/sftpgo/utils"
  11. "github.com/drakkan/sftpgo/vfs"
  12. )
  13. const (
  14. sqlDatabaseVersion = 6
  15. initialDBVersionSQL = "INSERT INTO {{schema_version}} (version) VALUES (1);"
  16. defaultSQLQueryTimeout = 10 * time.Second
  17. longSQLQueryTimeout = 60 * time.Second
  18. )
  19. var errSQLFoldersAssosaction = errors.New("unable to associate virtual folders to user")
  20. type sqlQuerier interface {
  21. PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
  22. }
  23. func getUserByUsername(username string, dbHandle sqlQuerier) (User, error) {
  24. var user User
  25. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  26. defer cancel()
  27. q := getUserByUsernameQuery()
  28. stmt, err := dbHandle.PrepareContext(ctx, q)
  29. if err != nil {
  30. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  31. return user, err
  32. }
  33. defer stmt.Close()
  34. row := stmt.QueryRowContext(ctx, username)
  35. user, err = getUserFromDbRow(row, nil)
  36. if err != nil {
  37. return user, err
  38. }
  39. return getUserWithVirtualFolders(user, dbHandle)
  40. }
  41. func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHandle *sql.DB) (User, error) {
  42. var user User
  43. if len(password) == 0 {
  44. return user, errors.New("Credentials cannot be null or empty")
  45. }
  46. user, err := getUserByUsername(username, dbHandle)
  47. if err != nil {
  48. providerLog(logger.LevelWarn, "error authenticating user: %v, error: %v", username, err)
  49. return user, err
  50. }
  51. return checkUserAndPass(user, password, ip, protocol)
  52. }
  53. func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, dbHandle *sql.DB) (User, string, error) {
  54. var user User
  55. if len(pubKey) == 0 {
  56. return user, "", errors.New("Credentials cannot be null or empty")
  57. }
  58. user, err := getUserByUsername(username, dbHandle)
  59. if err != nil {
  60. providerLog(logger.LevelWarn, "error authenticating user: %v, error: %v", username, err)
  61. return user, "", err
  62. }
  63. return checkUserAndPubKey(user, pubKey)
  64. }
  65. func sqlCommonCheckAvailability(dbHandle *sql.DB) error {
  66. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  67. defer cancel()
  68. return dbHandle.PingContext(ctx)
  69. }
  70. func sqlCommonGetUserByID(ID int64, dbHandle *sql.DB) (User, error) {
  71. var user User
  72. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  73. defer cancel()
  74. q := getUserByIDQuery()
  75. stmt, err := dbHandle.PrepareContext(ctx, q)
  76. if err != nil {
  77. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  78. return user, err
  79. }
  80. defer stmt.Close()
  81. row := stmt.QueryRowContext(ctx, ID)
  82. user, err = getUserFromDbRow(row, nil)
  83. if err != nil {
  84. return user, err
  85. }
  86. return getUserWithVirtualFolders(user, dbHandle)
  87. }
  88. func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
  89. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  90. defer cancel()
  91. q := getUpdateQuotaQuery(reset)
  92. stmt, err := dbHandle.PrepareContext(ctx, q)
  93. if err != nil {
  94. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  95. return err
  96. }
  97. defer stmt.Close()
  98. _, err = stmt.ExecContext(ctx, sizeAdd, filesAdd, utils.GetTimeAsMsSinceEpoch(time.Now()), username)
  99. if err == nil {
  100. providerLog(logger.LevelDebug, "quota updated for user %#v, files increment: %v size increment: %v is reset? %v",
  101. username, filesAdd, sizeAdd, reset)
  102. } else {
  103. providerLog(logger.LevelWarn, "error updating quota for user %#v: %v", username, err)
  104. }
  105. return err
  106. }
  107. func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, error) {
  108. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  109. defer cancel()
  110. q := getQuotaQuery()
  111. stmt, err := dbHandle.PrepareContext(ctx, q)
  112. if err != nil {
  113. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  114. return 0, 0, err
  115. }
  116. defer stmt.Close()
  117. var usedFiles int
  118. var usedSize int64
  119. err = stmt.QueryRowContext(ctx, username).Scan(&usedSize, &usedFiles)
  120. if err != nil {
  121. providerLog(logger.LevelWarn, "error getting quota for user: %v, error: %v", username, err)
  122. return 0, 0, err
  123. }
  124. return usedFiles, usedSize, err
  125. }
  126. func sqlCommonUpdateLastLogin(username string, dbHandle *sql.DB) error {
  127. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  128. defer cancel()
  129. q := getUpdateLastLoginQuery()
  130. stmt, err := dbHandle.PrepareContext(ctx, q)
  131. if err != nil {
  132. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  133. return err
  134. }
  135. defer stmt.Close()
  136. _, err = stmt.ExecContext(ctx, utils.GetTimeAsMsSinceEpoch(time.Now()), username)
  137. if err == nil {
  138. providerLog(logger.LevelDebug, "last login updated for user %#v", username)
  139. } else {
  140. providerLog(logger.LevelWarn, "error updating last login for user %#v: %v", username, err)
  141. }
  142. return err
  143. }
  144. func sqlCommonCheckUserExists(username string, dbHandle *sql.DB) (User, error) {
  145. var user User
  146. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  147. defer cancel()
  148. q := getUserByUsernameQuery()
  149. stmt, err := dbHandle.PrepareContext(ctx, q)
  150. if err != nil {
  151. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  152. return user, err
  153. }
  154. defer stmt.Close()
  155. row := stmt.QueryRowContext(ctx, username)
  156. user, err = getUserFromDbRow(row, nil)
  157. if err != nil {
  158. return user, err
  159. }
  160. return getUserWithVirtualFolders(user, dbHandle)
  161. }
  162. func sqlCommonAddUser(user User, dbHandle *sql.DB) error {
  163. err := validateUser(&user)
  164. if err != nil {
  165. return err
  166. }
  167. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  168. defer cancel()
  169. tx, err := dbHandle.BeginTx(ctx, nil)
  170. if err != nil {
  171. return err
  172. }
  173. q := getAddUserQuery()
  174. stmt, err := tx.PrepareContext(ctx, q)
  175. if err != nil {
  176. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  177. sqlCommonRollbackTransaction(tx)
  178. return err
  179. }
  180. defer stmt.Close()
  181. permissions, err := user.GetPermissionsAsJSON()
  182. if err != nil {
  183. sqlCommonRollbackTransaction(tx)
  184. return err
  185. }
  186. publicKeys, err := user.GetPublicKeysAsJSON()
  187. if err != nil {
  188. sqlCommonRollbackTransaction(tx)
  189. return err
  190. }
  191. filters, err := user.GetFiltersAsJSON()
  192. if err != nil {
  193. sqlCommonRollbackTransaction(tx)
  194. return err
  195. }
  196. fsConfig, err := user.GetFsConfigAsJSON()
  197. if err != nil {
  198. sqlCommonRollbackTransaction(tx)
  199. return err
  200. }
  201. _, err = stmt.ExecContext(ctx, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
  202. user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters),
  203. string(fsConfig), user.AdditionalInfo)
  204. if err != nil {
  205. sqlCommonRollbackTransaction(tx)
  206. return err
  207. }
  208. err = generateVirtualFoldersMapping(ctx, user, tx)
  209. if err != nil {
  210. sqlCommonRollbackTransaction(tx)
  211. return err
  212. }
  213. return tx.Commit()
  214. }
  215. func sqlCommonUpdateUser(user User, dbHandle *sql.DB) error {
  216. err := validateUser(&user)
  217. if err != nil {
  218. return err
  219. }
  220. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  221. defer cancel()
  222. tx, err := dbHandle.BeginTx(ctx, nil)
  223. if err != nil {
  224. return err
  225. }
  226. q := getUpdateUserQuery()
  227. stmt, err := tx.PrepareContext(ctx, q)
  228. if err != nil {
  229. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  230. sqlCommonRollbackTransaction(tx)
  231. return err
  232. }
  233. defer stmt.Close()
  234. permissions, err := user.GetPermissionsAsJSON()
  235. if err != nil {
  236. sqlCommonRollbackTransaction(tx)
  237. return err
  238. }
  239. publicKeys, err := user.GetPublicKeysAsJSON()
  240. if err != nil {
  241. sqlCommonRollbackTransaction(tx)
  242. return err
  243. }
  244. filters, err := user.GetFiltersAsJSON()
  245. if err != nil {
  246. sqlCommonRollbackTransaction(tx)
  247. return err
  248. }
  249. fsConfig, err := user.GetFsConfigAsJSON()
  250. if err != nil {
  251. sqlCommonRollbackTransaction(tx)
  252. return err
  253. }
  254. _, err = stmt.ExecContext(ctx, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
  255. user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate,
  256. string(filters), string(fsConfig), user.AdditionalInfo, user.ID)
  257. if err != nil {
  258. sqlCommonRollbackTransaction(tx)
  259. return err
  260. }
  261. err = generateVirtualFoldersMapping(ctx, user, tx)
  262. if err != nil {
  263. sqlCommonRollbackTransaction(tx)
  264. return err
  265. }
  266. return tx.Commit()
  267. }
  268. func sqlCommonDeleteUser(user User, dbHandle *sql.DB) error {
  269. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  270. defer cancel()
  271. q := getDeleteUserQuery()
  272. stmt, err := dbHandle.PrepareContext(ctx, q)
  273. if err != nil {
  274. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  275. return err
  276. }
  277. defer stmt.Close()
  278. _, err = stmt.ExecContext(ctx, user.ID)
  279. return err
  280. }
  281. func sqlCommonDumpUsers(dbHandle sqlQuerier) ([]User, error) {
  282. users := make([]User, 0, 100)
  283. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  284. defer cancel()
  285. q := getDumpUsersQuery()
  286. stmt, err := dbHandle.PrepareContext(ctx, q)
  287. if err != nil {
  288. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  289. return nil, err
  290. }
  291. defer stmt.Close()
  292. rows, err := stmt.QueryContext(ctx)
  293. if err != nil {
  294. return users, err
  295. }
  296. defer rows.Close()
  297. for rows.Next() {
  298. u, err := getUserFromDbRow(nil, rows)
  299. if err != nil {
  300. return users, err
  301. }
  302. err = addCredentialsToUser(&u)
  303. if err != nil {
  304. return users, err
  305. }
  306. users = append(users, u)
  307. }
  308. return getUsersWithVirtualFolders(users, dbHandle)
  309. }
  310. func sqlCommonGetUsers(limit int, offset int, order string, username string, dbHandle sqlQuerier) ([]User, error) {
  311. users := make([]User, 0, limit)
  312. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  313. defer cancel()
  314. q := getUsersQuery(order, username)
  315. stmt, err := dbHandle.PrepareContext(ctx, q)
  316. if err != nil {
  317. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  318. return nil, err
  319. }
  320. defer stmt.Close()
  321. var rows *sql.Rows
  322. if len(username) > 0 {
  323. rows, err = stmt.QueryContext(ctx, username, limit, offset) //nolint:rowserrcheck // rows.Err() is checked
  324. } else {
  325. rows, err = stmt.QueryContext(ctx, limit, offset) //nolint:rowserrcheck // rows.Err() is checked
  326. }
  327. if err == nil {
  328. defer rows.Close()
  329. for rows.Next() {
  330. u, err := getUserFromDbRow(nil, rows)
  331. if err != nil {
  332. return users, err
  333. }
  334. u.HideConfidentialData()
  335. users = append(users, u)
  336. }
  337. }
  338. err = rows.Err()
  339. if err != nil {
  340. return users, err
  341. }
  342. return getUsersWithVirtualFolders(users, dbHandle)
  343. }
  344. func updateUserPermissionsFromDb(user *User, permissions string) error {
  345. var err error
  346. perms := make(map[string][]string)
  347. err = json.Unmarshal([]byte(permissions), &perms)
  348. if err == nil {
  349. user.Permissions = perms
  350. } else {
  351. // compatibility layer: until version 0.9.4 permissions were a string list
  352. var list []string
  353. err = json.Unmarshal([]byte(permissions), &list)
  354. if err != nil {
  355. return err
  356. }
  357. perms["/"] = list
  358. user.Permissions = perms
  359. }
  360. return err
  361. }
// getUserFromDbRow builds a User from one result record. Pass a non-nil row
// for a single-row query (QueryRowContext), or a non-nil rows positioned on
// the current record during iteration. row is used whenever it is non-nil.
//
// The Scan destination order below must line up with the column list of the
// user SELECT queries (presumably those built by getUserByUsernameQuery and
// its siblings — confirm when changing either side).
//
// Error handling:
//   - sql.ErrNoRows is converted into *RecordNotFoundError.
//   - a non-NULL permissions value must decode (via
//     updateUserPermissionsFromDb), otherwise that error is returned.
//   - public keys, filters and the filesystem config are applied only when
//     their JSON decodes; otherwise the field keeps its zero value.
//
// NOTE(review): err is reused by those lenient decodes and returned at the
// end. The caller therefore still receives the error from the most recent
// decode that ran: fsConfig if non-NULL, else filters if non-NULL, else nil
// after a non-NULL permissions value, else public keys. This contradicts the
// "relaxed test" comment in the body. Confirm which behavior is intended.
// Note that ignoring a corrupt fsConfig would silently fall back to the zero
// Filesystem value.
func getUserFromDbRow(row *sql.Row, rows *sql.Rows) (User, error) {
	var user User
	// Nullable columns are scanned into sql.NullString and copied onto user
	// only when non-NULL.
	var permissions sql.NullString
	var password sql.NullString
	var publicKey sql.NullString
	var filters sql.NullString
	var fsConfig sql.NullString
	var additionalInfo sql.NullString
	var err error
	if row != nil {
		err = row.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions,
			&user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
			&user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig,
			&additionalInfo)
	} else {
		// rows must be non-nil here: calling with both arguments nil panics.
		err = rows.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions,
			&user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
			&user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig,
			&additionalInfo)
	}
	if err != nil {
		if err == sql.ErrNoRows {
			return user, &RecordNotFoundError{err: err.Error()}
		}
		return user, err
	}
	if password.Valid {
		user.Password = password.String
	}
	// we can have an empty string or an invalid json in a nullable column,
	// so we do a relaxed test if the field is optional, for example we
	// populate public keys only if unmarshal does not return an error
	if publicKey.Valid {
		var list []string
		err = json.Unmarshal([]byte(publicKey.String), &list)
		if err == nil {
			user.PublicKeys = list
		}
	}
	// A non-NULL permissions value is not relaxed: a decode failure aborts.
	if permissions.Valid {
		err = updateUserPermissionsFromDb(&user, permissions.String)
		if err != nil {
			return user, err
		}
	}
	if filters.Valid {
		var userFilters UserFilters
		err = json.Unmarshal([]byte(filters.String), &userFilters)
		if err == nil {
			user.Filters = userFilters
		}
	}
	if fsConfig.Valid {
		var fs Filesystem
		err = json.Unmarshal([]byte(fsConfig.String), &fs)
		if err == nil {
			user.FsConfig = fs
		}
	}
	if additionalInfo.Valid {
		user.AdditionalInfo = additionalInfo.String
	}
	user.SetEmptySecretsIfNil()
	// NOTE(review): err may still hold a lenient decode failure here; see the
	// function comment.
	return user, err
}
  427. func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
  428. var folder vfs.BaseVirtualFolder
  429. q := getFolderByPathQuery()
  430. stmt, err := dbHandle.PrepareContext(ctx, q)
  431. if err != nil {
  432. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  433. return folder, err
  434. }
  435. defer stmt.Close()
  436. row := stmt.QueryRowContext(ctx, name)
  437. err = row.Scan(&folder.ID, &folder.MappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, &folder.LastQuotaUpdate)
  438. if err == sql.ErrNoRows {
  439. return folder, &RecordNotFoundError{err: err.Error()}
  440. }
  441. return folder, err
  442. }
  443. func sqlCommonAddOrGetFolder(ctx context.Context, name string, usedQuotaSize int64, usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
  444. folder, err := sqlCommonCheckFolderExists(ctx, name, dbHandle)
  445. if _, ok := err.(*RecordNotFoundError); ok {
  446. f := vfs.BaseVirtualFolder{
  447. MappedPath: name,
  448. UsedQuotaSize: usedQuotaSize,
  449. UsedQuotaFiles: usedQuotaFiles,
  450. LastQuotaUpdate: lastQuotaUpdate,
  451. }
  452. err = sqlCommonAddFolder(f, dbHandle)
  453. if err != nil {
  454. return folder, err
  455. }
  456. return sqlCommonCheckFolderExists(ctx, name, dbHandle)
  457. }
  458. return folder, err
  459. }
  460. func sqlCommonAddFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
  461. err := validateFolder(&folder)
  462. if err != nil {
  463. return err
  464. }
  465. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  466. defer cancel()
  467. q := getAddFolderQuery()
  468. stmt, err := dbHandle.PrepareContext(ctx, q)
  469. if err != nil {
  470. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  471. return err
  472. }
  473. defer stmt.Close()
  474. _, err = stmt.ExecContext(ctx, folder.MappedPath, folder.UsedQuotaSize, folder.UsedQuotaFiles, folder.LastQuotaUpdate)
  475. return err
  476. }
  477. func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
  478. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  479. defer cancel()
  480. q := getDeleteFolderQuery()
  481. stmt, err := dbHandle.PrepareContext(ctx, q)
  482. if err != nil {
  483. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  484. return err
  485. }
  486. defer stmt.Close()
  487. _, err = stmt.ExecContext(ctx, folder.ID)
  488. return err
  489. }
  490. func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  491. folders := make([]vfs.BaseVirtualFolder, 0, 50)
  492. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  493. defer cancel()
  494. q := getDumpFoldersQuery()
  495. stmt, err := dbHandle.PrepareContext(ctx, q)
  496. if err != nil {
  497. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  498. return nil, err
  499. }
  500. defer stmt.Close()
  501. rows, err := stmt.QueryContext(ctx)
  502. if err != nil {
  503. return folders, err
  504. }
  505. defer rows.Close()
  506. for rows.Next() {
  507. var folder vfs.BaseVirtualFolder
  508. err = rows.Scan(&folder.ID, &folder.MappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, &folder.LastQuotaUpdate)
  509. if err != nil {
  510. return folders, err
  511. }
  512. folders = append(folders, folder)
  513. }
  514. err = rows.Err()
  515. if err != nil {
  516. return folders, err
  517. }
  518. return getVirtualFoldersWithUsers(folders, dbHandle)
  519. }
  520. func sqlCommonGetFolders(limit, offset int, order, folderPath string, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  521. folders := make([]vfs.BaseVirtualFolder, 0, limit)
  522. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  523. defer cancel()
  524. q := getFoldersQuery(order, folderPath)
  525. stmt, err := dbHandle.PrepareContext(ctx, q)
  526. if err != nil {
  527. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  528. return nil, err
  529. }
  530. defer stmt.Close()
  531. var rows *sql.Rows
  532. if len(folderPath) > 0 {
  533. rows, err = stmt.QueryContext(ctx, folderPath, limit, offset) //nolint:rowserrcheck // rows.Err() is checked
  534. } else {
  535. rows, err = stmt.QueryContext(ctx, limit, offset) //nolint:rowserrcheck // rows.Err() is checked
  536. }
  537. if err != nil {
  538. return folders, err
  539. }
  540. defer rows.Close()
  541. for rows.Next() {
  542. var folder vfs.BaseVirtualFolder
  543. err = rows.Scan(&folder.ID, &folder.MappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, &folder.LastQuotaUpdate)
  544. if err != nil {
  545. return folders, err
  546. }
  547. folders = append(folders, folder)
  548. }
  549. err = rows.Err()
  550. if err != nil {
  551. return folders, err
  552. }
  553. return getVirtualFoldersWithUsers(folders, dbHandle)
  554. }
  555. func sqlCommonClearFolderMapping(ctx context.Context, user User, dbHandle sqlQuerier) error {
  556. q := getClearFolderMappingQuery()
  557. stmt, err := dbHandle.PrepareContext(ctx, q)
  558. if err != nil {
  559. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  560. return err
  561. }
  562. defer stmt.Close()
  563. _, err = stmt.ExecContext(ctx, user.Username)
  564. return err
  565. }
  566. func sqlCommonAddFolderMapping(ctx context.Context, user User, folder vfs.VirtualFolder, dbHandle sqlQuerier) error {
  567. q := getAddFolderMappingQuery()
  568. stmt, err := dbHandle.PrepareContext(ctx, q)
  569. if err != nil {
  570. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  571. return err
  572. }
  573. defer stmt.Close()
  574. _, err = stmt.ExecContext(ctx, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.ID, user.Username)
  575. return err
  576. }
  577. func generateVirtualFoldersMapping(ctx context.Context, user User, dbHandle sqlQuerier) error {
  578. err := sqlCommonClearFolderMapping(ctx, user, dbHandle)
  579. if err != nil {
  580. return err
  581. }
  582. for _, vfolder := range user.VirtualFolders {
  583. f, err := sqlCommonAddOrGetFolder(ctx, vfolder.MappedPath, 0, 0, 0, dbHandle)
  584. if err != nil {
  585. return err
  586. }
  587. vfolder.BaseVirtualFolder = f
  588. err = sqlCommonAddFolderMapping(ctx, user, vfolder, dbHandle)
  589. if err != nil {
  590. return err
  591. }
  592. }
  593. return err
  594. }
  595. func getUserWithVirtualFolders(user User, dbHandle sqlQuerier) (User, error) {
  596. users, err := getUsersWithVirtualFolders([]User{user}, dbHandle)
  597. if err != nil {
  598. return user, err
  599. }
  600. if len(users) == 0 {
  601. return user, errSQLFoldersAssosaction
  602. }
  603. return users[0], err
  604. }
  605. func getUsersWithVirtualFolders(users []User, dbHandle sqlQuerier) ([]User, error) {
  606. var err error
  607. usersVirtualFolders := make(map[int64][]vfs.VirtualFolder)
  608. if len(users) == 0 {
  609. return users, err
  610. }
  611. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  612. defer cancel()
  613. q := getRelatedFoldersForUsersQuery(users)
  614. stmt, err := dbHandle.PrepareContext(ctx, q)
  615. if err != nil {
  616. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  617. return nil, err
  618. }
  619. defer stmt.Close()
  620. rows, err := stmt.QueryContext(ctx)
  621. if err != nil {
  622. return nil, err
  623. }
  624. defer rows.Close()
  625. for rows.Next() {
  626. var folder vfs.VirtualFolder
  627. var userID int64
  628. err = rows.Scan(&folder.ID, &folder.MappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
  629. &folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &userID)
  630. if err != nil {
  631. return users, err
  632. }
  633. usersVirtualFolders[userID] = append(usersVirtualFolders[userID], folder)
  634. }
  635. err = rows.Err()
  636. if err != nil {
  637. return users, err
  638. }
  639. if len(usersVirtualFolders) == 0 {
  640. return users, err
  641. }
  642. for idx := range users {
  643. ref := &users[idx]
  644. ref.VirtualFolders = usersVirtualFolders[ref.ID]
  645. }
  646. return users, err
  647. }
  648. func getVirtualFoldersWithUsers(folders []vfs.BaseVirtualFolder, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  649. var err error
  650. vFoldersUsers := make(map[int64][]string)
  651. if len(folders) == 0 {
  652. return folders, err
  653. }
  654. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  655. defer cancel()
  656. q := getRelatedUsersForFoldersQuery(folders)
  657. stmt, err := dbHandle.PrepareContext(ctx, q)
  658. if err != nil {
  659. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  660. return nil, err
  661. }
  662. defer stmt.Close()
  663. rows, err := stmt.QueryContext(ctx)
  664. if err != nil {
  665. return nil, err
  666. }
  667. defer rows.Close()
  668. for rows.Next() {
  669. var username string
  670. var folderID int64
  671. err = rows.Scan(&folderID, &username)
  672. if err != nil {
  673. return folders, err
  674. }
  675. vFoldersUsers[folderID] = append(vFoldersUsers[folderID], username)
  676. }
  677. err = rows.Err()
  678. if err != nil {
  679. return folders, err
  680. }
  681. if len(vFoldersUsers) == 0 {
  682. return folders, err
  683. }
  684. for idx := range folders {
  685. ref := &folders[idx]
  686. ref.Users = vFoldersUsers[ref.ID]
  687. }
  688. return folders, err
  689. }
  690. func sqlCommonUpdateFolderQuota(mappedPath string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
  691. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  692. defer cancel()
  693. q := getUpdateFolderQuotaQuery(reset)
  694. stmt, err := dbHandle.PrepareContext(ctx, q)
  695. if err != nil {
  696. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  697. return err
  698. }
  699. defer stmt.Close()
  700. _, err = stmt.ExecContext(ctx, sizeAdd, filesAdd, utils.GetTimeAsMsSinceEpoch(time.Now()), mappedPath)
  701. if err == nil {
  702. providerLog(logger.LevelDebug, "quota updated for folder %#v, files increment: %v size increment: %v is reset? %v",
  703. mappedPath, filesAdd, sizeAdd, reset)
  704. } else {
  705. providerLog(logger.LevelWarn, "error updating quota for folder %#v: %v", mappedPath, err)
  706. }
  707. return err
  708. }
  709. func sqlCommonGetFolderUsedQuota(mappedPath string, dbHandle *sql.DB) (int, int64, error) {
  710. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  711. defer cancel()
  712. q := getQuotaFolderQuery()
  713. stmt, err := dbHandle.PrepareContext(ctx, q)
  714. if err != nil {
  715. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  716. return 0, 0, err
  717. }
  718. defer stmt.Close()
  719. var usedFiles int
  720. var usedSize int64
  721. err = stmt.QueryRowContext(ctx, mappedPath).Scan(&usedSize, &usedFiles)
  722. if err != nil {
  723. providerLog(logger.LevelWarn, "error getting quota for folder: %v, error: %v", mappedPath, err)
  724. return 0, 0, err
  725. }
  726. return usedFiles, usedSize, err
  727. }
  728. func sqlCommonRollbackTransaction(tx *sql.Tx) {
  729. err := tx.Rollback()
  730. if err != nil {
  731. providerLog(logger.LevelWarn, "error rolling back transaction: %v", err)
  732. }
  733. }
  734. func sqlCommonGetDatabaseVersion(dbHandle *sql.DB, showInitWarn bool) (schemaVersion, error) {
  735. var result schemaVersion
  736. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  737. defer cancel()
  738. q := getDatabaseVersionQuery()
  739. stmt, err := dbHandle.PrepareContext(ctx, q)
  740. if err != nil {
  741. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  742. if showInitWarn && strings.Contains(err.Error(), sqlTableSchemaVersion) {
  743. logger.WarnToConsole("database query error, did you forgot to run the \"initprovider\" command?")
  744. }
  745. return result, err
  746. }
  747. defer stmt.Close()
  748. row := stmt.QueryRowContext(ctx)
  749. err = row.Scan(&result.Version)
  750. return result, err
  751. }
  752. func sqlCommonUpdateDatabaseVersion(ctx context.Context, dbHandle sqlQuerier, version int) error {
  753. q := getUpdateDBVersionQuery()
  754. stmt, err := dbHandle.PrepareContext(ctx, q)
  755. if err != nil {
  756. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  757. return err
  758. }
  759. defer stmt.Close()
  760. _, err = stmt.ExecContext(ctx, version)
  761. return err
  762. }
  763. func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sql []string, newVersion int) error {
  764. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  765. defer cancel()
  766. tx, err := dbHandle.BeginTx(ctx, nil)
  767. if err != nil {
  768. return err
  769. }
  770. for _, q := range sql {
  771. if len(strings.TrimSpace(q)) == 0 {
  772. continue
  773. }
  774. _, err = tx.ExecContext(ctx, q)
  775. if err != nil {
  776. sqlCommonRollbackTransaction(tx)
  777. return err
  778. }
  779. }
  780. err = sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion)
  781. if err != nil {
  782. sqlCommonRollbackTransaction(tx)
  783. return err
  784. }
  785. return tx.Commit()
  786. }
  787. func sqlCommonGetCompatVirtualFolders(dbHandle *sql.DB) ([]userCompactVFolders, error) {
  788. users := []userCompactVFolders{}
  789. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  790. defer cancel()
  791. q := getCompatVirtualFoldersQuery()
  792. stmt, err := dbHandle.PrepareContext(ctx, q)
  793. if err != nil {
  794. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  795. return nil, err
  796. }
  797. defer stmt.Close()
  798. rows, err := stmt.QueryContext(ctx)
  799. if err != nil {
  800. return nil, err
  801. }
  802. defer rows.Close()
  803. for rows.Next() {
  804. var user userCompactVFolders
  805. var virtualFolders sql.NullString
  806. err = rows.Scan(&user.ID, &user.Username, &virtualFolders)
  807. if err != nil {
  808. return nil, err
  809. }
  810. if virtualFolders.Valid {
  811. var list []virtualFoldersCompact
  812. err = json.Unmarshal([]byte(virtualFolders.String), &list)
  813. if err == nil && len(list) > 0 {
  814. user.VirtualFolders = list
  815. users = append(users, user)
  816. }
  817. }
  818. }
  819. return users, rows.Err()
  820. }
  821. func sqlCommonRestoreCompatVirtualFolders(ctx context.Context, users []userCompactVFolders, dbHandle sqlQuerier) ([]string, error) {
  822. foldersToScan := []string{}
  823. for _, user := range users {
  824. for _, vfolder := range user.VirtualFolders {
  825. providerLog(logger.LevelInfo, "restoring virtual folder: %+v for user %#v", vfolder, user.Username)
  826. // -1 means included in user quota, 0 means unlimited
  827. quotaSize := int64(-1)
  828. quotaFiles := -1
  829. if vfolder.ExcludeFromQuota {
  830. quotaFiles = 0
  831. quotaSize = 0
  832. }
  833. b, err := sqlCommonAddOrGetFolder(ctx, vfolder.MappedPath, 0, 0, 0, dbHandle)
  834. if err != nil {
  835. providerLog(logger.LevelWarn, "error restoring virtual folder for user %#v: %v", user.Username, err)
  836. return foldersToScan, err
  837. }
  838. u := User{
  839. ID: user.ID,
  840. Username: user.Username,
  841. }
  842. f := vfs.VirtualFolder{
  843. BaseVirtualFolder: b,
  844. VirtualPath: vfolder.VirtualPath,
  845. QuotaSize: quotaSize,
  846. QuotaFiles: quotaFiles,
  847. }
  848. err = sqlCommonAddFolderMapping(ctx, u, f, dbHandle)
  849. if err != nil {
  850. providerLog(logger.LevelWarn, "error adding virtual folder mapping for user %#v: %v", user.Username, err)
  851. return foldersToScan, err
  852. }
  853. if !utils.IsStringInSlice(vfolder.MappedPath, foldersToScan) {
  854. foldersToScan = append(foldersToScan, vfolder.MappedPath)
  855. }
  856. providerLog(logger.LevelInfo, "virtual folder: %+v for user %#v successfully restored", vfolder, user.Username)
  857. }
  858. }
  859. return foldersToScan, nil
  860. }
  861. func sqlCommonUpdateDatabaseFrom3To4(sqlV4 string, dbHandle *sql.DB) error {
  862. logger.InfoToConsole("updating database version: 3 -> 4")
  863. providerLog(logger.LevelInfo, "updating database version: 3 -> 4")
  864. users, err := sqlCommonGetCompatVirtualFolders(dbHandle)
  865. if err != nil {
  866. return err
  867. }
  868. sql := strings.ReplaceAll(sqlV4, "{{users}}", sqlTableUsers)
  869. sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
  870. sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
  871. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  872. defer cancel()
  873. tx, err := dbHandle.BeginTx(ctx, nil)
  874. if err != nil {
  875. return err
  876. }
  877. for _, q := range strings.Split(sql, ";") {
  878. if len(strings.TrimSpace(q)) == 0 {
  879. continue
  880. }
  881. _, err = tx.ExecContext(ctx, q)
  882. if err != nil {
  883. sqlCommonRollbackTransaction(tx)
  884. return err
  885. }
  886. }
  887. foldersToScan, err := sqlCommonRestoreCompatVirtualFolders(ctx, users, tx)
  888. if err != nil {
  889. sqlCommonRollbackTransaction(tx)
  890. return err
  891. }
  892. err = sqlCommonUpdateDatabaseVersion(ctx, tx, 4)
  893. if err != nil {
  894. sqlCommonRollbackTransaction(tx)
  895. return err
  896. }
  897. err = tx.Commit()
  898. if err == nil {
  899. go updateVFoldersQuotaAfterRestore(foldersToScan)
  900. }
  901. return err
  902. }
  903. //nolint:dupl
  904. func sqlCommonUpdateDatabaseFrom4To5(dbHandle *sql.DB) error {
  905. logger.InfoToConsole("updating database version: 4 -> 5")
  906. providerLog(logger.LevelInfo, "updating database version: 4 -> 5")
  907. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  908. defer cancel()
  909. q := getCompatV4FsConfigQuery()
  910. stmt, err := dbHandle.PrepareContext(ctx, q)
  911. if err != nil {
  912. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  913. return err
  914. }
  915. defer stmt.Close()
  916. rows, err := stmt.QueryContext(ctx)
  917. if err != nil {
  918. return err
  919. }
  920. defer rows.Close()
  921. users := []User{}
  922. for rows.Next() {
  923. var compatUser compatUserV4
  924. var fsConfigString sql.NullString
  925. err = rows.Scan(&compatUser.ID, &compatUser.Username, &fsConfigString)
  926. if err != nil {
  927. return err
  928. }
  929. if fsConfigString.Valid {
  930. err = json.Unmarshal([]byte(fsConfigString.String), &compatUser.FsConfig)
  931. if err != nil {
  932. logger.WarnToConsole("failed to unmarshal v4 user %#v, is it already migrated?", compatUser.Username)
  933. continue
  934. }
  935. fsConfig, err := convertFsConfigFromV4(compatUser.FsConfig, compatUser.Username)
  936. if err != nil {
  937. return err
  938. }
  939. users = append(users, createUserFromV4(compatUser, fsConfig))
  940. }
  941. }
  942. if err := rows.Err(); err != nil {
  943. return err
  944. }
  945. for _, user := range users {
  946. err = sqlCommonUpdateV4User(dbHandle, user)
  947. if err != nil {
  948. return err
  949. }
  950. providerLog(logger.LevelInfo, "filesystem config updated for user %#v", user.Username)
  951. }
  952. ctxVersion, cancelVersion := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  953. defer cancelVersion()
  954. return sqlCommonUpdateDatabaseVersion(ctxVersion, dbHandle, 5)
  955. }
  956. func sqlCommonUpdateV4CompatUser(dbHandle *sql.DB, user compatUserV4) error {
  957. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  958. defer cancel()
  959. q := updateCompatV4FsConfigQuery()
  960. stmt, err := dbHandle.PrepareContext(ctx, q)
  961. if err != nil {
  962. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  963. return err
  964. }
  965. defer stmt.Close()
  966. fsConfig, err := json.Marshal(user.FsConfig)
  967. if err != nil {
  968. return err
  969. }
  970. _, err = stmt.ExecContext(ctx, string(fsConfig), user.ID)
  971. return err
  972. }
  973. func sqlCommonUpdateV4User(dbHandle *sql.DB, user User) error {
  974. err := validateFilesystemConfig(&user)
  975. if err != nil {
  976. return err
  977. }
  978. err = saveGCSCredentials(&user)
  979. if err != nil {
  980. return err
  981. }
  982. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  983. defer cancel()
  984. q := updateCompatV4FsConfigQuery()
  985. stmt, err := dbHandle.PrepareContext(ctx, q)
  986. if err != nil {
  987. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  988. return err
  989. }
  990. defer stmt.Close()
  991. fsConfig, err := user.GetFsConfigAsJSON()
  992. if err != nil {
  993. return err
  994. }
  995. _, err = stmt.ExecContext(ctx, string(fsConfig), user.ID)
  996. return err
  997. }
  998. //nolint:dupl
  999. func sqlCommonDowngradeDatabaseFrom5To4(dbHandle *sql.DB) error {
  1000. logger.InfoToConsole("downgrading database version: 5 -> 4")
  1001. providerLog(logger.LevelInfo, "downgrading database version: 5 -> 4")
  1002. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  1003. defer cancel()
  1004. q := getCompatV4FsConfigQuery()
  1005. stmt, err := dbHandle.PrepareContext(ctx, q)
  1006. if err != nil {
  1007. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1008. return err
  1009. }
  1010. defer stmt.Close()
  1011. rows, err := stmt.QueryContext(ctx)
  1012. if err != nil {
  1013. return err
  1014. }
  1015. defer rows.Close()
  1016. users := []compatUserV4{}
  1017. for rows.Next() {
  1018. var user User
  1019. var fsConfigString sql.NullString
  1020. err = rows.Scan(&user.ID, &user.Username, &fsConfigString)
  1021. if err != nil {
  1022. return err
  1023. }
  1024. if fsConfigString.Valid {
  1025. err = json.Unmarshal([]byte(fsConfigString.String), &user.FsConfig)
  1026. if err != nil {
  1027. logger.WarnToConsole("failed to unmarshal user %#v to v4, is it already migrated?", user.Username)
  1028. continue
  1029. }
  1030. fsConfig, err := convertFsConfigToV4(user.FsConfig, user.Username)
  1031. if err != nil {
  1032. return err
  1033. }
  1034. users = append(users, convertUserToV4(user, fsConfig))
  1035. }
  1036. }
  1037. if err := rows.Err(); err != nil {
  1038. return err
  1039. }
  1040. for _, user := range users {
  1041. err = sqlCommonUpdateV4CompatUser(dbHandle, user)
  1042. if err != nil {
  1043. return err
  1044. }
  1045. providerLog(logger.LevelInfo, "filesystem config downgraded for user %#v", user.Username)
  1046. }
  1047. ctxVersion, cancelVersion := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1048. defer cancelVersion()
  1049. return sqlCommonUpdateDatabaseVersion(ctxVersion, dbHandle, 4)
  1050. }