sqlcommon.go 97 KB


  1. // Copyright (C) 2019-2022 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package dataprovider
  15. import (
  16. "context"
  17. "crypto/x509"
  18. "database/sql"
  19. "encoding/json"
  20. "errors"
  21. "fmt"
  22. "runtime/debug"
  23. "strings"
  24. "time"
  25. "github.com/cockroachdb/cockroach-go/v2/crdb"
  26. "github.com/sftpgo/sdk"
  27. "github.com/drakkan/sftpgo/v2/internal/logger"
  28. "github.com/drakkan/sftpgo/v2/internal/util"
  29. "github.com/drakkan/sftpgo/v2/internal/vfs"
  30. )
const (
	// sqlDatabaseVersion is the SQL schema version number.
	// NOTE(review): presumably the version this code expects to find in the
	// schema version table; its use is not visible in this part of the file.
	sqlDatabaseVersion = 23
	// defaultSQLQueryTimeout is the context timeout applied to most queries in this file.
	defaultSQLQueryTimeout = 10 * time.Second
	// longSQLQueryTimeout is a longer context timeout; in the visible code it is
	// used by sqlCommonDumpGroups.
	longSQLQueryTimeout = 60 * time.Second
)
// Sentinel errors shared by the SQL providers.
var (
	// errSQLFoldersAssociation reports a failure associating virtual folders to a user.
	errSQLFoldersAssociation = errors.New("unable to associate virtual folders to user")
	// errSQLGroupsAssociation reports a failure associating groups to a user.
	errSQLGroupsAssociation = errors.New("unable to associate groups to user")
	// errSQLUsersAssociation reports a failure associating users to a group.
	errSQLUsersAssociation = errors.New("unable to associate users to group")
	// errSchemaVersionEmpty reports that the schema version cannot be determined
	// because the schema_migration table has no rows.
	errSchemaVersionEmpty = errors.New("we can't determine schema version because the schema_migration table is empty. The SFTPGo database might be corrupted. Consider using the \"resetprovider\" sub-command")
)
// sqlQuerier is the set of context-aware query methods the helpers in this
// file need from a database handle. Functions in this file pass a *sql.DB
// where a sqlQuerier is expected; *sql.Tx from database/sql exposes the same
// methods.
type sqlQuerier interface {
	// QueryRowContext runs a query that is expected to return at most one row.
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
	// QueryContext runs a query that returns rows.
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	// ExecContext runs a statement that does not return rows.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	// PrepareContext creates a prepared statement.
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
// sqlScanner abstracts the Scan method, which both *sql.Row and *sql.Rows
// provide.
// NOTE(review): presumably the parameter type of the get*FromDbRow helpers,
// which this file calls with both kinds of value; confirm against their
// definitions.
type sqlScanner interface {
	// Scan copies the columns of the current row into dest.
	Scan(dest ...any) error
}
  51. func sqlReplaceAll(sql string) string {
  52. sql = strings.ReplaceAll(sql, "{{schema_version}}", sqlTableSchemaVersion)
  53. sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins)
  54. sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
  55. sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
  56. sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
  57. sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping)
  58. sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping)
  59. sql = strings.ReplaceAll(sql, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping)
  60. sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping)
  61. sql = strings.ReplaceAll(sql, "{{api_keys}}", sqlTableAPIKeys)
  62. sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares)
  63. sql = strings.ReplaceAll(sql, "{{defender_events}}", sqlTableDefenderEvents)
  64. sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts)
  65. sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers)
  66. sql = strings.ReplaceAll(sql, "{{shared_sessions}}", sqlTableSharedSessions)
  67. sql = strings.ReplaceAll(sql, "{{events_actions}}", sqlTableEventsActions)
  68. sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
  69. sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
  70. sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
  71. sql = strings.ReplaceAll(sql, "{{nodes}}", sqlTableNodes)
  72. sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
  73. return sql
  74. }
  75. func sqlCommonGetShareByID(shareID, username string, dbHandle sqlQuerier) (Share, error) {
  76. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  77. defer cancel()
  78. filterUser := username != ""
  79. q := getShareByIDQuery(filterUser)
  80. var row *sql.Row
  81. if filterUser {
  82. row = dbHandle.QueryRowContext(ctx, q, shareID, username)
  83. } else {
  84. row = dbHandle.QueryRowContext(ctx, q, shareID)
  85. }
  86. return getShareFromDbRow(row)
  87. }
  88. func sqlCommonAddShare(share *Share, dbHandle *sql.DB) error {
  89. err := share.validate()
  90. if err != nil {
  91. return err
  92. }
  93. user, err := provider.userExists(share.Username)
  94. if err != nil {
  95. return util.NewValidationError(fmt.Sprintf("unable to validate user %#v", share.Username))
  96. }
  97. paths, err := json.Marshal(share.Paths)
  98. if err != nil {
  99. return err
  100. }
  101. allowFrom := ""
  102. if len(share.AllowFrom) > 0 {
  103. res, err := json.Marshal(share.AllowFrom)
  104. if err == nil {
  105. allowFrom = string(res)
  106. }
  107. }
  108. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  109. defer cancel()
  110. q := getAddShareQuery()
  111. usedTokens := 0
  112. createdAt := util.GetTimeAsMsSinceEpoch(time.Now())
  113. updatedAt := createdAt
  114. lastUseAt := int64(0)
  115. if share.IsRestore {
  116. usedTokens = share.UsedTokens
  117. if share.CreatedAt > 0 {
  118. createdAt = share.CreatedAt
  119. }
  120. if share.UpdatedAt > 0 {
  121. updatedAt = share.UpdatedAt
  122. }
  123. lastUseAt = share.LastUseAt
  124. }
  125. _, err = dbHandle.ExecContext(ctx, q, share.ShareID, share.Name, share.Description, share.Scope,
  126. string(paths), createdAt, updatedAt, lastUseAt, share.ExpiresAt, share.Password,
  127. share.MaxTokens, usedTokens, allowFrom, user.ID)
  128. return err
  129. }
  130. func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error {
  131. err := share.validate()
  132. if err != nil {
  133. return err
  134. }
  135. paths, err := json.Marshal(share.Paths)
  136. if err != nil {
  137. return err
  138. }
  139. allowFrom := ""
  140. if len(share.AllowFrom) > 0 {
  141. res, err := json.Marshal(share.AllowFrom)
  142. if err == nil {
  143. allowFrom = string(res)
  144. }
  145. }
  146. user, err := provider.userExists(share.Username)
  147. if err != nil {
  148. return util.NewValidationError(fmt.Sprintf("unable to validate user %#v", share.Username))
  149. }
  150. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  151. defer cancel()
  152. var q string
  153. if share.IsRestore {
  154. q = getUpdateShareRestoreQuery()
  155. } else {
  156. q = getUpdateShareQuery()
  157. }
  158. if share.IsRestore {
  159. if share.CreatedAt == 0 {
  160. share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
  161. }
  162. if share.UpdatedAt == 0 {
  163. share.UpdatedAt = share.CreatedAt
  164. }
  165. _, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, string(paths),
  166. share.CreatedAt, share.UpdatedAt, share.LastUseAt, share.ExpiresAt, share.Password, share.MaxTokens,
  167. share.UsedTokens, allowFrom, user.ID, share.ShareID)
  168. } else {
  169. _, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, string(paths),
  170. util.GetTimeAsMsSinceEpoch(time.Now()), share.ExpiresAt, share.Password, share.MaxTokens,
  171. allowFrom, user.ID, share.ShareID)
  172. }
  173. return err
  174. }
  175. func sqlCommonDeleteShare(share Share, dbHandle *sql.DB) error {
  176. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  177. defer cancel()
  178. q := getDeleteShareQuery()
  179. res, err := dbHandle.ExecContext(ctx, q, share.ShareID)
  180. if err != nil {
  181. return err
  182. }
  183. return sqlCommonRequireRowAffected(res)
  184. }
  185. func sqlCommonGetShares(limit, offset int, order, username string, dbHandle sqlQuerier) ([]Share, error) {
  186. shares := make([]Share, 0, limit)
  187. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  188. defer cancel()
  189. q := getSharesQuery(order)
  190. rows, err := dbHandle.QueryContext(ctx, q, username, limit, offset)
  191. if err != nil {
  192. return shares, err
  193. }
  194. defer rows.Close()
  195. for rows.Next() {
  196. s, err := getShareFromDbRow(rows)
  197. if err != nil {
  198. return shares, err
  199. }
  200. s.HideConfidentialData()
  201. shares = append(shares, s)
  202. }
  203. return shares, rows.Err()
  204. }
  205. func sqlCommonDumpShares(dbHandle sqlQuerier) ([]Share, error) {
  206. shares := make([]Share, 0, 30)
  207. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  208. defer cancel()
  209. q := getDumpSharesQuery()
  210. rows, err := dbHandle.QueryContext(ctx, q)
  211. if err != nil {
  212. return shares, err
  213. }
  214. defer rows.Close()
  215. for rows.Next() {
  216. s, err := getShareFromDbRow(rows)
  217. if err != nil {
  218. return shares, err
  219. }
  220. shares = append(shares, s)
  221. }
  222. return shares, rows.Err()
  223. }
  224. func sqlCommonGetAPIKeyByID(keyID string, dbHandle sqlQuerier) (APIKey, error) {
  225. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  226. defer cancel()
  227. q := getAPIKeyByIDQuery()
  228. row := dbHandle.QueryRowContext(ctx, q, keyID)
  229. apiKey, err := getAPIKeyFromDbRow(row)
  230. if err != nil {
  231. return apiKey, err
  232. }
  233. return getAPIKeyWithRelatedFields(ctx, apiKey, dbHandle)
  234. }
  235. func sqlCommonAddAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
  236. err := apiKey.validate()
  237. if err != nil {
  238. return err
  239. }
  240. userID, adminID, err := sqlCommonGetAPIKeyRelatedIDs(apiKey)
  241. if err != nil {
  242. return err
  243. }
  244. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  245. defer cancel()
  246. q := getAddAPIKeyQuery()
  247. _, err = dbHandle.ExecContext(ctx, q, apiKey.KeyID, apiKey.Name, apiKey.Key, apiKey.Scope,
  248. util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.LastUseAt,
  249. apiKey.ExpiresAt, apiKey.Description, userID, adminID)
  250. return err
  251. }
  252. func sqlCommonUpdateAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
  253. err := apiKey.validate()
  254. if err != nil {
  255. return err
  256. }
  257. userID, adminID, err := sqlCommonGetAPIKeyRelatedIDs(apiKey)
  258. if err != nil {
  259. return err
  260. }
  261. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  262. defer cancel()
  263. q := getUpdateAPIKeyQuery()
  264. _, err = dbHandle.ExecContext(ctx, q, apiKey.Name, apiKey.Scope, apiKey.ExpiresAt, userID, adminID,
  265. apiKey.Description, util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.KeyID)
  266. return err
  267. }
  268. func sqlCommonDeleteAPIKey(apiKey APIKey, dbHandle *sql.DB) error {
  269. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  270. defer cancel()
  271. q := getDeleteAPIKeyQuery()
  272. res, err := dbHandle.ExecContext(ctx, q, apiKey.KeyID)
  273. if err != nil {
  274. return err
  275. }
  276. return sqlCommonRequireRowAffected(res)
  277. }
  278. func sqlCommonGetAPIKeys(limit, offset int, order string, dbHandle sqlQuerier) ([]APIKey, error) {
  279. apiKeys := make([]APIKey, 0, limit)
  280. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  281. defer cancel()
  282. q := getAPIKeysQuery(order)
  283. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  284. if err != nil {
  285. return apiKeys, err
  286. }
  287. defer rows.Close()
  288. for rows.Next() {
  289. k, err := getAPIKeyFromDbRow(rows)
  290. if err != nil {
  291. return apiKeys, err
  292. }
  293. k.HideConfidentialData()
  294. apiKeys = append(apiKeys, k)
  295. }
  296. err = rows.Err()
  297. if err != nil {
  298. return apiKeys, err
  299. }
  300. apiKeys, err = getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeAdmin)
  301. if err != nil {
  302. return apiKeys, err
  303. }
  304. return getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeUser)
  305. }
  306. func sqlCommonDumpAPIKeys(dbHandle sqlQuerier) ([]APIKey, error) {
  307. apiKeys := make([]APIKey, 0, 30)
  308. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  309. defer cancel()
  310. q := getDumpAPIKeysQuery()
  311. rows, err := dbHandle.QueryContext(ctx, q)
  312. if err != nil {
  313. return apiKeys, err
  314. }
  315. defer rows.Close()
  316. for rows.Next() {
  317. k, err := getAPIKeyFromDbRow(rows)
  318. if err != nil {
  319. return apiKeys, err
  320. }
  321. apiKeys = append(apiKeys, k)
  322. }
  323. err = rows.Err()
  324. if err != nil {
  325. return apiKeys, err
  326. }
  327. apiKeys, err = getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeAdmin)
  328. if err != nil {
  329. return apiKeys, err
  330. }
  331. return getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeUser)
  332. }
  333. func sqlCommonGetAdminByUsername(username string, dbHandle sqlQuerier) (Admin, error) {
  334. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  335. defer cancel()
  336. q := getAdminByUsernameQuery()
  337. row := dbHandle.QueryRowContext(ctx, q, username)
  338. admin, err := getAdminFromDbRow(row)
  339. if err != nil {
  340. return admin, err
  341. }
  342. return getAdminWithGroups(ctx, admin, dbHandle)
  343. }
  344. func sqlCommonValidateAdminAndPass(username, password, ip string, dbHandle *sql.DB) (Admin, error) {
  345. admin, err := sqlCommonGetAdminByUsername(username, dbHandle)
  346. if err != nil {
  347. providerLog(logger.LevelWarn, "error authenticating admin %#v: %v", username, err)
  348. return admin, ErrInvalidCredentials
  349. }
  350. err = admin.checkUserAndPass(password, ip)
  351. return admin, err
  352. }
  353. func sqlCommonAddAdmin(admin *Admin, dbHandle *sql.DB) error {
  354. err := admin.validate()
  355. if err != nil {
  356. return err
  357. }
  358. perms, err := json.Marshal(admin.Permissions)
  359. if err != nil {
  360. return err
  361. }
  362. filters, err := json.Marshal(admin.Filters)
  363. if err != nil {
  364. return err
  365. }
  366. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  367. defer cancel()
  368. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  369. q := getAddAdminQuery()
  370. _, err = tx.ExecContext(ctx, q, admin.Username, admin.Password, admin.Status, admin.Email, string(perms),
  371. string(filters), admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
  372. util.GetTimeAsMsSinceEpoch(time.Now()))
  373. if err != nil {
  374. return err
  375. }
  376. return generateAdminGroupMapping(ctx, admin, tx)
  377. })
  378. }
  379. func sqlCommonUpdateAdmin(admin *Admin, dbHandle *sql.DB) error {
  380. err := admin.validate()
  381. if err != nil {
  382. return err
  383. }
  384. perms, err := json.Marshal(admin.Permissions)
  385. if err != nil {
  386. return err
  387. }
  388. filters, err := json.Marshal(admin.Filters)
  389. if err != nil {
  390. return err
  391. }
  392. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  393. defer cancel()
  394. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  395. q := getUpdateAdminQuery()
  396. _, err = tx.ExecContext(ctx, q, admin.Password, admin.Status, admin.Email, string(perms), string(filters),
  397. admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()), admin.Username)
  398. if err != nil {
  399. return err
  400. }
  401. return generateAdminGroupMapping(ctx, admin, tx)
  402. })
  403. }
  404. func sqlCommonDeleteAdmin(admin Admin, dbHandle *sql.DB) error {
  405. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  406. defer cancel()
  407. q := getDeleteAdminQuery()
  408. res, err := dbHandle.ExecContext(ctx, q, admin.Username)
  409. if err != nil {
  410. return err
  411. }
  412. return sqlCommonRequireRowAffected(res)
  413. }
  414. func sqlCommonGetAdmins(limit, offset int, order string, dbHandle sqlQuerier) ([]Admin, error) {
  415. admins := make([]Admin, 0, limit)
  416. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  417. defer cancel()
  418. q := getAdminsQuery(order)
  419. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  420. if err != nil {
  421. return admins, err
  422. }
  423. defer rows.Close()
  424. for rows.Next() {
  425. a, err := getAdminFromDbRow(rows)
  426. if err != nil {
  427. return admins, err
  428. }
  429. a.HideConfidentialData()
  430. admins = append(admins, a)
  431. }
  432. err = rows.Err()
  433. if err != nil {
  434. return admins, err
  435. }
  436. return getAdminsWithGroups(ctx, admins, dbHandle)
  437. }
  438. func sqlCommonDumpAdmins(dbHandle sqlQuerier) ([]Admin, error) {
  439. admins := make([]Admin, 0, 30)
  440. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  441. defer cancel()
  442. q := getDumpAdminsQuery()
  443. rows, err := dbHandle.QueryContext(ctx, q)
  444. if err != nil {
  445. return admins, err
  446. }
  447. defer rows.Close()
  448. for rows.Next() {
  449. a, err := getAdminFromDbRow(rows)
  450. if err != nil {
  451. return admins, err
  452. }
  453. admins = append(admins, a)
  454. }
  455. err = rows.Err()
  456. if err != nil {
  457. return admins, err
  458. }
  459. return getAdminsWithGroups(ctx, admins, dbHandle)
  460. }
  461. func sqlCommonGetGroupByName(name string, dbHandle sqlQuerier) (Group, error) {
  462. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  463. defer cancel()
  464. q := getGroupByNameQuery()
  465. row := dbHandle.QueryRowContext(ctx, q, name)
  466. group, err := getGroupFromDbRow(row)
  467. if err != nil {
  468. return group, err
  469. }
  470. group, err = getGroupWithVirtualFolders(ctx, group, dbHandle)
  471. if err != nil {
  472. return group, err
  473. }
  474. group, err = getGroupWithUsers(ctx, group, dbHandle)
  475. if err != nil {
  476. return group, err
  477. }
  478. return getGroupWithAdmins(ctx, group, dbHandle)
  479. }
  480. func sqlCommonDumpGroups(dbHandle sqlQuerier) ([]Group, error) {
  481. groups := make([]Group, 0, 50)
  482. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  483. defer cancel()
  484. q := getDumpGroupsQuery()
  485. rows, err := dbHandle.QueryContext(ctx, q)
  486. if err != nil {
  487. return groups, err
  488. }
  489. defer rows.Close()
  490. for rows.Next() {
  491. group, err := getGroupFromDbRow(rows)
  492. if err != nil {
  493. return groups, err
  494. }
  495. groups = append(groups, group)
  496. }
  497. err = rows.Err()
  498. if err != nil {
  499. return groups, err
  500. }
  501. return getGroupsWithVirtualFolders(ctx, groups, dbHandle)
  502. }
  503. func sqlCommonGetUsersInGroups(names []string, dbHandle sqlQuerier) ([]string, error) {
  504. if len(names) == 0 {
  505. return nil, nil
  506. }
  507. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  508. defer cancel()
  509. q := getUsersInGroupsQuery(len(names))
  510. args := make([]any, 0, len(names))
  511. for _, name := range names {
  512. args = append(args, name)
  513. }
  514. usernames := make([]string, 0, len(names))
  515. rows, err := dbHandle.QueryContext(ctx, q, args...)
  516. if err != nil {
  517. return nil, err
  518. }
  519. defer rows.Close()
  520. for rows.Next() {
  521. var username string
  522. err = rows.Scan(&username)
  523. if err != nil {
  524. return usernames, err
  525. }
  526. usernames = append(usernames, username)
  527. }
  528. return usernames, rows.Err()
  529. }
  530. func sqlCommonGetGroupsWithNames(names []string, dbHandle sqlQuerier) ([]Group, error) {
  531. if len(names) == 0 {
  532. return nil, nil
  533. }
  534. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  535. defer cancel()
  536. q := getGroupsWithNamesQuery(len(names))
  537. args := make([]any, 0, len(names))
  538. for _, name := range names {
  539. args = append(args, name)
  540. }
  541. groups := make([]Group, 0, len(names))
  542. rows, err := dbHandle.QueryContext(ctx, q, args...)
  543. if err != nil {
  544. return groups, err
  545. }
  546. defer rows.Close()
  547. for rows.Next() {
  548. group, err := getGroupFromDbRow(rows)
  549. if err != nil {
  550. return groups, err
  551. }
  552. groups = append(groups, group)
  553. }
  554. err = rows.Err()
  555. if err != nil {
  556. return groups, err
  557. }
  558. return getGroupsWithVirtualFolders(ctx, groups, dbHandle)
  559. }
  560. func sqlCommonGetGroups(limit int, offset int, order string, minimal bool, dbHandle sqlQuerier) ([]Group, error) {
  561. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  562. defer cancel()
  563. q := getGroupsQuery(order, minimal)
  564. groups := make([]Group, 0, limit)
  565. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  566. if err != nil {
  567. return groups, err
  568. }
  569. defer rows.Close()
  570. for rows.Next() {
  571. var group Group
  572. if minimal {
  573. err = rows.Scan(&group.ID, &group.Name)
  574. } else {
  575. group, err = getGroupFromDbRow(rows)
  576. }
  577. if err != nil {
  578. return groups, err
  579. }
  580. groups = append(groups, group)
  581. }
  582. err = rows.Err()
  583. if err != nil {
  584. return groups, err
  585. }
  586. if minimal {
  587. return groups, nil
  588. }
  589. groups, err = getGroupsWithVirtualFolders(ctx, groups, dbHandle)
  590. if err != nil {
  591. return groups, err
  592. }
  593. groups, err = getGroupsWithUsers(ctx, groups, dbHandle)
  594. if err != nil {
  595. return groups, err
  596. }
  597. groups, err = getGroupsWithAdmins(ctx, groups, dbHandle)
  598. if err != nil {
  599. return groups, err
  600. }
  601. for idx := range groups {
  602. groups[idx].PrepareForRendering()
  603. }
  604. return groups, nil
  605. }
  606. func sqlCommonAddGroup(group *Group, dbHandle *sql.DB) error {
  607. if err := group.validate(); err != nil {
  608. return err
  609. }
  610. settings, err := json.Marshal(group.UserSettings)
  611. if err != nil {
  612. return err
  613. }
  614. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  615. defer cancel()
  616. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  617. q := getAddGroupQuery()
  618. _, err := tx.ExecContext(ctx, q, group.Name, group.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
  619. util.GetTimeAsMsSinceEpoch(time.Now()), string(settings))
  620. if err != nil {
  621. return err
  622. }
  623. return generateGroupVirtualFoldersMapping(ctx, group, tx)
  624. })
  625. }
  626. func sqlCommonUpdateGroup(group *Group, dbHandle *sql.DB) error {
  627. if err := group.validate(); err != nil {
  628. return err
  629. }
  630. settings, err := json.Marshal(group.UserSettings)
  631. if err != nil {
  632. return err
  633. }
  634. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  635. defer cancel()
  636. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  637. q := getUpdateGroupQuery()
  638. _, err := tx.ExecContext(ctx, q, group.Description, settings, util.GetTimeAsMsSinceEpoch(time.Now()), group.Name)
  639. if err != nil {
  640. return err
  641. }
  642. return generateGroupVirtualFoldersMapping(ctx, group, tx)
  643. })
  644. }
  645. func sqlCommonDeleteGroup(group Group, dbHandle *sql.DB) error {
  646. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  647. defer cancel()
  648. q := getDeleteGroupQuery()
  649. res, err := dbHandle.ExecContext(ctx, q, group.Name)
  650. if err != nil {
  651. return err
  652. }
  653. return sqlCommonRequireRowAffected(res)
  654. }
  655. func sqlCommonGetUserByUsername(username string, dbHandle sqlQuerier) (User, error) {
  656. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  657. defer cancel()
  658. q := getUserByUsernameQuery()
  659. row := dbHandle.QueryRowContext(ctx, q, username)
  660. user, err := getUserFromDbRow(row)
  661. if err != nil {
  662. return user, err
  663. }
  664. user, err = getUserWithVirtualFolders(ctx, user, dbHandle)
  665. if err != nil {
  666. return user, err
  667. }
  668. return getUserWithGroups(ctx, user, dbHandle)
  669. }
  670. func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHandle *sql.DB) (User, error) {
  671. user, err := sqlCommonGetUserByUsername(username, dbHandle)
  672. if err != nil {
  673. providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
  674. return user, err
  675. }
  676. return checkUserAndPass(&user, password, ip, protocol)
  677. }
  678. func sqlCommonValidateUserAndTLSCertificate(username, protocol string, tlsCert *x509.Certificate, dbHandle *sql.DB) (User, error) {
  679. var user User
  680. if tlsCert == nil {
  681. return user, errors.New("TLS certificate cannot be null or empty")
  682. }
  683. user, err := sqlCommonGetUserByUsername(username, dbHandle)
  684. if err != nil {
  685. providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
  686. return user, err
  687. }
  688. return checkUserAndTLSCertificate(&user, protocol, tlsCert)
  689. }
  690. func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, isSSHCert bool, dbHandle *sql.DB) (User, string, error) {
  691. var user User
  692. if len(pubKey) == 0 {
  693. return user, "", errors.New("credentials cannot be null or empty")
  694. }
  695. user, err := sqlCommonGetUserByUsername(username, dbHandle)
  696. if err != nil {
  697. providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
  698. return user, "", err
  699. }
  700. return checkUserAndPubKey(&user, pubKey, isSSHCert)
  701. }
  702. func sqlCommonCheckAvailability(dbHandle *sql.DB) (err error) {
  703. defer func() {
  704. if r := recover(); r != nil {
  705. providerLog(logger.LevelError, "panic in check provider availability, stack trace: %v", string(debug.Stack()))
  706. err = errors.New("unable to check provider status")
  707. }
  708. }()
  709. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  710. defer cancel()
  711. err = dbHandle.PingContext(ctx)
  712. return
  713. }
  714. func sqlCommonUpdateTransferQuota(username string, uploadSize, downloadSize int64, reset bool, dbHandle *sql.DB) error {
  715. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  716. defer cancel()
  717. q := getUpdateTransferQuotaQuery(reset)
  718. _, err := dbHandle.ExecContext(ctx, q, uploadSize, downloadSize, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  719. if err == nil {
  720. providerLog(logger.LevelDebug, "transfer quota updated for user %#v, ul increment: %v dl increment: %v is reset? %v",
  721. username, uploadSize, downloadSize, reset)
  722. } else {
  723. providerLog(logger.LevelError, "error updating quota for user %#v: %v", username, err)
  724. }
  725. return err
  726. }
  727. func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
  728. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  729. defer cancel()
  730. q := getUpdateQuotaQuery(reset)
  731. _, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  732. if err == nil {
  733. providerLog(logger.LevelDebug, "quota updated for user %#v, files increment: %v size increment: %v is reset? %v",
  734. username, filesAdd, sizeAdd, reset)
  735. } else {
  736. providerLog(logger.LevelError, "error updating quota for user %#v: %v", username, err)
  737. }
  738. return err
  739. }
  740. func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, int64, int64, error) {
  741. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  742. defer cancel()
  743. q := getQuotaQuery()
  744. var usedFiles int
  745. var usedSize, usedUploadSize, usedDownloadSize int64
  746. err := dbHandle.QueryRowContext(ctx, q, username).Scan(&usedSize, &usedFiles, &usedUploadSize, &usedDownloadSize)
  747. if err != nil {
  748. providerLog(logger.LevelError, "error getting quota for user: %v, error: %v", username, err)
  749. return 0, 0, 0, 0, err
  750. }
  751. return usedFiles, usedSize, usedUploadSize, usedDownloadSize, err
  752. }
  753. func sqlCommonUpdateShareLastUse(shareID string, numTokens int, dbHandle *sql.DB) error {
  754. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  755. defer cancel()
  756. q := getUpdateShareLastUseQuery()
  757. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), numTokens, shareID)
  758. if err == nil {
  759. providerLog(logger.LevelDebug, "last use updated for shared object %#v", shareID)
  760. } else {
  761. providerLog(logger.LevelWarn, "error updating last use for shared object %#v: %v", shareID, err)
  762. }
  763. return err
  764. }
  765. func sqlCommonUpdateAPIKeyLastUse(keyID string, dbHandle *sql.DB) error {
  766. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  767. defer cancel()
  768. q := getUpdateAPIKeyLastUseQuery()
  769. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), keyID)
  770. if err == nil {
  771. providerLog(logger.LevelDebug, "last use updated for key %#v", keyID)
  772. } else {
  773. providerLog(logger.LevelWarn, "error updating last use for key %#v: %v", keyID, err)
  774. }
  775. return err
  776. }
  777. func sqlCommonUpdateAdminLastLogin(username string, dbHandle *sql.DB) error {
  778. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  779. defer cancel()
  780. q := getUpdateAdminLastLoginQuery()
  781. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  782. if err == nil {
  783. providerLog(logger.LevelDebug, "last login updated for admin %#v", username)
  784. } else {
  785. providerLog(logger.LevelWarn, "error updating last login for admin %#v: %v", username, err)
  786. }
  787. return err
  788. }
  789. func sqlCommonSetUpdatedAt(username string, dbHandle *sql.DB) {
  790. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  791. defer cancel()
  792. q := getSetUpdateAtQuery()
  793. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  794. if err == nil {
  795. providerLog(logger.LevelDebug, "updated_at set for user %#v", username)
  796. } else {
  797. providerLog(logger.LevelWarn, "error setting updated_at for user %#v: %v", username, err)
  798. }
  799. }
  800. func sqlCommonSetFirstDownloadTimestamp(username string, dbHandle *sql.DB) error {
  801. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  802. defer cancel()
  803. q := getSetFirstDownloadQuery()
  804. res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  805. if err != nil {
  806. return err
  807. }
  808. return sqlCommonRequireRowAffected(res)
  809. }
  810. func sqlCommonSetFirstUploadTimestamp(username string, dbHandle *sql.DB) error {
  811. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  812. defer cancel()
  813. q := getSetFirstUploadQuery()
  814. res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  815. if err != nil {
  816. return err
  817. }
  818. return sqlCommonRequireRowAffected(res)
  819. }
  820. func sqlCommonUpdateLastLogin(username string, dbHandle *sql.DB) error {
  821. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  822. defer cancel()
  823. q := getUpdateLastLoginQuery()
  824. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  825. if err == nil {
  826. providerLog(logger.LevelDebug, "last login updated for user %#v", username)
  827. } else {
  828. providerLog(logger.LevelWarn, "error updating last login for user %#v: %v", username, err)
  829. }
  830. return err
  831. }
  832. func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
  833. err := ValidateUser(user)
  834. if err != nil {
  835. return err
  836. }
  837. permissions, err := user.GetPermissionsAsJSON()
  838. if err != nil {
  839. return err
  840. }
  841. publicKeys, err := user.GetPublicKeysAsJSON()
  842. if err != nil {
  843. return err
  844. }
  845. filters, err := user.GetFiltersAsJSON()
  846. if err != nil {
  847. return err
  848. }
  849. fsConfig, err := user.GetFsConfigAsJSON()
  850. if err != nil {
  851. return err
  852. }
  853. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  854. defer cancel()
  855. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  856. q := getAddUserQuery()
  857. _, err := tx.ExecContext(ctx, q, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID,
  858. user.MaxSessions, user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth,
  859. user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo,
  860. user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()),
  861. user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer)
  862. if err != nil {
  863. return err
  864. }
  865. if err := generateUserVirtualFoldersMapping(ctx, user, tx); err != nil {
  866. return err
  867. }
  868. return generateUserGroupMapping(ctx, user, tx)
  869. })
  870. }
  871. func sqlCommonUpdateUserPassword(username, password string, dbHandle *sql.DB) error {
  872. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  873. defer cancel()
  874. q := getUpdateUserPasswordQuery()
  875. _, err := dbHandle.ExecContext(ctx, q, password, username)
  876. return err
  877. }
  878. func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
  879. err := ValidateUser(user)
  880. if err != nil {
  881. return err
  882. }
  883. permissions, err := user.GetPermissionsAsJSON()
  884. if err != nil {
  885. return err
  886. }
  887. publicKeys, err := user.GetPublicKeysAsJSON()
  888. if err != nil {
  889. return err
  890. }
  891. filters, err := user.GetFiltersAsJSON()
  892. if err != nil {
  893. return err
  894. }
  895. fsConfig, err := user.GetFsConfigAsJSON()
  896. if err != nil {
  897. return err
  898. }
  899. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  900. defer cancel()
  901. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  902. q := getUpdateUserQuery()
  903. _, err := tx.ExecContext(ctx, q, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions,
  904. user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status,
  905. user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo, user.Description, user.Email,
  906. util.GetTimeAsMsSinceEpoch(time.Now()), user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer,
  907. user.ID)
  908. if err != nil {
  909. return err
  910. }
  911. if err := generateUserVirtualFoldersMapping(ctx, user, tx); err != nil {
  912. return err
  913. }
  914. return generateUserGroupMapping(ctx, user, tx)
  915. })
  916. }
  917. func sqlCommonDeleteUser(user User, softDelete bool, dbHandle *sql.DB) error {
  918. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  919. defer cancel()
  920. q := getDeleteUserQuery(softDelete)
  921. if softDelete {
  922. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  923. if err := sqlCommonClearUserFolderMapping(ctx, &user, tx); err != nil {
  924. return err
  925. }
  926. if err := sqlCommonClearUserGroupMapping(ctx, &user, tx); err != nil {
  927. return err
  928. }
  929. ts := util.GetTimeAsMsSinceEpoch(time.Now())
  930. res, err := tx.ExecContext(ctx, q, ts, ts, user.Username)
  931. if err != nil {
  932. return err
  933. }
  934. return sqlCommonRequireRowAffected(res)
  935. })
  936. }
  937. res, err := dbHandle.ExecContext(ctx, q, user.ID)
  938. if err != nil {
  939. return err
  940. }
  941. return sqlCommonRequireRowAffected(res)
  942. }
  943. func sqlCommonDumpUsers(dbHandle sqlQuerier) ([]User, error) {
  944. users := make([]User, 0, 100)
  945. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  946. defer cancel()
  947. q := getDumpUsersQuery()
  948. rows, err := dbHandle.QueryContext(ctx, q)
  949. if err != nil {
  950. return users, err
  951. }
  952. defer rows.Close()
  953. for rows.Next() {
  954. u, err := getUserFromDbRow(rows)
  955. if err != nil {
  956. return users, err
  957. }
  958. users = append(users, u)
  959. }
  960. err = rows.Err()
  961. if err != nil {
  962. return users, err
  963. }
  964. users, err = getUsersWithVirtualFolders(ctx, users, dbHandle)
  965. if err != nil {
  966. return users, err
  967. }
  968. return getUsersWithGroups(ctx, users, dbHandle)
  969. }
  970. func sqlCommonGetRecentlyUpdatedUsers(after int64, dbHandle sqlQuerier) ([]User, error) {
  971. users := make([]User, 0, 10)
  972. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  973. defer cancel()
  974. q := getRecentlyUpdatedUsersQuery()
  975. rows, err := dbHandle.QueryContext(ctx, q, after)
  976. if err != nil {
  977. return users, err
  978. }
  979. defer rows.Close()
  980. for rows.Next() {
  981. u, err := getUserFromDbRow(rows)
  982. if err != nil {
  983. return users, err
  984. }
  985. users = append(users, u)
  986. }
  987. err = rows.Err()
  988. if err != nil {
  989. return users, err
  990. }
  991. users, err = getUsersWithVirtualFolders(ctx, users, dbHandle)
  992. if err != nil {
  993. return users, err
  994. }
  995. users, err = getUsersWithGroups(ctx, users, dbHandle)
  996. if err != nil {
  997. return users, err
  998. }
  999. var groupNames []string
  1000. for _, u := range users {
  1001. for _, g := range u.Groups {
  1002. groupNames = append(groupNames, g.Name)
  1003. }
  1004. }
  1005. groupNames = util.RemoveDuplicates(groupNames, false)
  1006. groups, err := sqlCommonGetGroupsWithNames(groupNames, dbHandle)
  1007. if err != nil {
  1008. return users, err
  1009. }
  1010. if len(groups) == 0 {
  1011. return users, nil
  1012. }
  1013. groupsMapping := make(map[string]Group)
  1014. for idx := range groups {
  1015. groupsMapping[groups[idx].Name] = groups[idx]
  1016. }
  1017. for idx := range users {
  1018. ref := &users[idx]
  1019. ref.applyGroupSettings(groupsMapping)
  1020. }
  1021. return users, nil
  1022. }
  1023. func sqlCommonGetUsersForQuotaCheck(toFetch map[string]bool, dbHandle sqlQuerier) ([]User, error) {
  1024. users := make([]User, 0, 30)
  1025. usernames := make([]string, 0, len(toFetch))
  1026. for k := range toFetch {
  1027. usernames = append(usernames, k)
  1028. }
  1029. maxUsers := 30
  1030. for len(usernames) > 0 {
  1031. if maxUsers > len(usernames) {
  1032. maxUsers = len(usernames)
  1033. }
  1034. usersRange, err := sqlCommonGetUsersRangeForQuotaCheck(usernames[:maxUsers], dbHandle)
  1035. if err != nil {
  1036. return users, err
  1037. }
  1038. users = append(users, usersRange...)
  1039. usernames = usernames[maxUsers:]
  1040. }
  1041. var usersWithFolders []User
  1042. validIdx := 0
  1043. for _, user := range users {
  1044. if toFetch[user.Username] {
  1045. usersWithFolders = append(usersWithFolders, user)
  1046. } else {
  1047. users[validIdx] = user
  1048. validIdx++
  1049. }
  1050. }
  1051. users = users[:validIdx]
  1052. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1053. defer cancel()
  1054. usersWithFolders, err := getUsersWithVirtualFolders(ctx, usersWithFolders, dbHandle)
  1055. if err != nil {
  1056. return users, err
  1057. }
  1058. users = append(users, usersWithFolders...)
  1059. users, err = getUsersWithGroups(ctx, users, dbHandle)
  1060. if err != nil {
  1061. return users, err
  1062. }
  1063. var groupNames []string
  1064. for _, u := range users {
  1065. for _, g := range u.Groups {
  1066. groupNames = append(groupNames, g.Name)
  1067. }
  1068. }
  1069. groupNames = util.RemoveDuplicates(groupNames, false)
  1070. if len(groupNames) == 0 {
  1071. return users, nil
  1072. }
  1073. groups, err := sqlCommonGetGroupsWithNames(groupNames, dbHandle)
  1074. if err != nil {
  1075. return users, err
  1076. }
  1077. groupsMapping := make(map[string]Group)
  1078. for idx := range groups {
  1079. groupsMapping[groups[idx].Name] = groups[idx]
  1080. }
  1081. for idx := range users {
  1082. ref := &users[idx]
  1083. ref.applyGroupSettings(groupsMapping)
  1084. }
  1085. return users, nil
  1086. }
  1087. func sqlCommonGetUsersRangeForQuotaCheck(usernames []string, dbHandle sqlQuerier) ([]User, error) {
  1088. users := make([]User, 0, len(usernames))
  1089. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1090. defer cancel()
  1091. q := getUsersForQuotaCheckQuery(len(usernames))
  1092. queryArgs := make([]any, 0, len(usernames))
  1093. for idx := range usernames {
  1094. queryArgs = append(queryArgs, usernames[idx])
  1095. }
  1096. rows, err := dbHandle.QueryContext(ctx, q, queryArgs...)
  1097. if err != nil {
  1098. return nil, err
  1099. }
  1100. defer rows.Close()
  1101. for rows.Next() {
  1102. var user User
  1103. var filters sql.NullString
  1104. err = rows.Scan(&user.ID, &user.Username, &user.QuotaSize, &user.UsedQuotaSize, &user.TotalDataTransfer,
  1105. &user.UploadDataTransfer, &user.DownloadDataTransfer, &user.UsedUploadDataTransfer,
  1106. &user.UsedDownloadDataTransfer, &filters)
  1107. if err != nil {
  1108. return users, err
  1109. }
  1110. if filters.Valid {
  1111. var userFilters UserFilters
  1112. err = json.Unmarshal([]byte(filters.String), &userFilters)
  1113. if err == nil {
  1114. user.Filters = userFilters
  1115. }
  1116. }
  1117. users = append(users, user)
  1118. }
  1119. return users, rows.Err()
  1120. }
  1121. func sqlCommonAddActiveTransfer(transfer ActiveTransfer, dbHandle *sql.DB) error {
  1122. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1123. defer cancel()
  1124. q := getAddActiveTransferQuery()
  1125. now := util.GetTimeAsMsSinceEpoch(time.Now())
  1126. _, err := dbHandle.ExecContext(ctx, q, transfer.ID, transfer.ConnID, transfer.Type, transfer.Username,
  1127. transfer.FolderName, transfer.IP, transfer.TruncatedSize, transfer.CurrentULSize, transfer.CurrentDLSize,
  1128. now, now)
  1129. return err
  1130. }
  1131. func sqlCommonUpdateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string, dbHandle *sql.DB) error {
  1132. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1133. defer cancel()
  1134. q := getUpdateActiveTransferSizesQuery()
  1135. _, err := dbHandle.ExecContext(ctx, q, ulSize, dlSize, util.GetTimeAsMsSinceEpoch(time.Now()), connectionID, transferID)
  1136. return err
  1137. }
  1138. func sqlCommonRemoveActiveTransfer(transferID int64, connectionID string, dbHandle *sql.DB) error {
  1139. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1140. defer cancel()
  1141. q := getRemoveActiveTransferQuery()
  1142. _, err := dbHandle.ExecContext(ctx, q, connectionID, transferID)
  1143. return err
  1144. }
  1145. func sqlCommonCleanupActiveTransfers(before time.Time, dbHandle *sql.DB) error {
  1146. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1147. defer cancel()
  1148. q := getCleanupActiveTransfersQuery()
  1149. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(before))
  1150. return err
  1151. }
  1152. func sqlCommonGetActiveTransfers(from time.Time, dbHandle sqlQuerier) ([]ActiveTransfer, error) {
  1153. transfers := make([]ActiveTransfer, 0, 30)
  1154. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  1155. defer cancel()
  1156. q := getActiveTransfersQuery()
  1157. rows, err := dbHandle.QueryContext(ctx, q, util.GetTimeAsMsSinceEpoch(from))
  1158. if err != nil {
  1159. return nil, err
  1160. }
  1161. defer rows.Close()
  1162. for rows.Next() {
  1163. var transfer ActiveTransfer
  1164. var folderName sql.NullString
  1165. err = rows.Scan(&transfer.ID, &transfer.ConnID, &transfer.Type, &transfer.Username, &folderName, &transfer.IP,
  1166. &transfer.TruncatedSize, &transfer.CurrentULSize, &transfer.CurrentDLSize, &transfer.CreatedAt,
  1167. &transfer.UpdatedAt)
  1168. if err != nil {
  1169. return transfers, err
  1170. }
  1171. if folderName.Valid {
  1172. transfer.FolderName = folderName.String
  1173. }
  1174. transfers = append(transfers, transfer)
  1175. }
  1176. return transfers, rows.Err()
  1177. }
  1178. func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier) ([]User, error) {
  1179. users := make([]User, 0, limit)
  1180. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1181. defer cancel()
  1182. q := getUsersQuery(order)
  1183. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  1184. if err != nil {
  1185. return users, err
  1186. }
  1187. defer rows.Close()
  1188. for rows.Next() {
  1189. u, err := getUserFromDbRow(rows)
  1190. if err != nil {
  1191. return users, err
  1192. }
  1193. users = append(users, u)
  1194. }
  1195. err = rows.Err()
  1196. if err != nil {
  1197. return users, err
  1198. }
  1199. users, err = getUsersWithVirtualFolders(ctx, users, dbHandle)
  1200. if err != nil {
  1201. return users, err
  1202. }
  1203. users, err = getUsersWithGroups(ctx, users, dbHandle)
  1204. if err != nil {
  1205. return users, err
  1206. }
  1207. for idx := range users {
  1208. users[idx].PrepareForRendering()
  1209. }
  1210. return users, nil
  1211. }
  1212. func sqlCommonGetDefenderHosts(from int64, limit int, dbHandle sqlQuerier) ([]DefenderEntry, error) {
  1213. hosts := make([]DefenderEntry, 0, 100)
  1214. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1215. defer cancel()
  1216. q := getDefenderHostsQuery()
  1217. rows, err := dbHandle.QueryContext(ctx, q, from, limit)
  1218. if err != nil {
  1219. providerLog(logger.LevelError, "unable to get defender hosts: %v", err)
  1220. return hosts, err
  1221. }
  1222. defer rows.Close()
  1223. var idForScores []int64
  1224. for rows.Next() {
  1225. var banTime sql.NullInt64
  1226. host := DefenderEntry{}
  1227. err = rows.Scan(&host.ID, &host.IP, &banTime)
  1228. if err != nil {
  1229. providerLog(logger.LevelError, "unable to scan defender host row: %v", err)
  1230. return hosts, err
  1231. }
  1232. var hostBanTime time.Time
  1233. if banTime.Valid && banTime.Int64 > 0 {
  1234. hostBanTime = util.GetTimeFromMsecSinceEpoch(banTime.Int64)
  1235. }
  1236. if hostBanTime.IsZero() || hostBanTime.Before(time.Now()) {
  1237. idForScores = append(idForScores, host.ID)
  1238. } else {
  1239. host.BanTime = hostBanTime
  1240. }
  1241. hosts = append(hosts, host)
  1242. }
  1243. err = rows.Err()
  1244. if err != nil {
  1245. providerLog(logger.LevelError, "unable to iterate over defender host rows: %v", err)
  1246. return hosts, err
  1247. }
  1248. return getDefenderHostsWithScores(ctx, hosts, from, idForScores, dbHandle)
  1249. }
  1250. func sqlCommonIsDefenderHostBanned(ip string, dbHandle sqlQuerier) (DefenderEntry, error) {
  1251. var host DefenderEntry
  1252. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1253. defer cancel()
  1254. q := getDefenderIsHostBannedQuery()
  1255. row := dbHandle.QueryRowContext(ctx, q, ip, util.GetTimeAsMsSinceEpoch(time.Now()))
  1256. err := row.Scan(&host.ID)
  1257. if err != nil {
  1258. if errors.Is(err, sql.ErrNoRows) {
  1259. return host, util.NewRecordNotFoundError("host not found")
  1260. }
  1261. providerLog(logger.LevelError, "unable to check ban status for host %#v: %v", ip, err)
  1262. return host, err
  1263. }
  1264. return host, nil
  1265. }
  1266. func sqlCommonGetDefenderHostByIP(ip string, from int64, dbHandle sqlQuerier) (DefenderEntry, error) {
  1267. var host DefenderEntry
  1268. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1269. defer cancel()
  1270. q := getDefenderHostQuery()
  1271. row := dbHandle.QueryRowContext(ctx, q, ip, from)
  1272. var banTime sql.NullInt64
  1273. err := row.Scan(&host.ID, &host.IP, &banTime)
  1274. if err != nil {
  1275. if errors.Is(err, sql.ErrNoRows) {
  1276. return host, util.NewRecordNotFoundError("host not found")
  1277. }
  1278. providerLog(logger.LevelError, "unable to get host for ip %#v: %v", ip, err)
  1279. return host, err
  1280. }
  1281. if banTime.Valid && banTime.Int64 > 0 {
  1282. hostBanTime := util.GetTimeFromMsecSinceEpoch(banTime.Int64)
  1283. if !hostBanTime.IsZero() && hostBanTime.After(time.Now()) {
  1284. host.BanTime = hostBanTime
  1285. return host, nil
  1286. }
  1287. }
  1288. hosts, err := getDefenderHostsWithScores(ctx, []DefenderEntry{host}, from, []int64{host.ID}, dbHandle)
  1289. if err != nil {
  1290. return host, err
  1291. }
  1292. if len(hosts) == 0 {
  1293. return host, util.NewRecordNotFoundError("host not found")
  1294. }
  1295. return hosts[0], nil
  1296. }
  1297. func sqlCommonDefenderIncrementBanTime(ip string, minutesToAdd int, dbHandle *sql.DB) error {
  1298. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1299. defer cancel()
  1300. q := getDefenderIncrementBanTimeQuery()
  1301. _, err := dbHandle.ExecContext(ctx, q, minutesToAdd*60000, ip)
  1302. if err == nil {
  1303. providerLog(logger.LevelDebug, "ban time updated for ip %#v, increment (minutes): %v",
  1304. ip, minutesToAdd)
  1305. } else {
  1306. providerLog(logger.LevelError, "error updating ban time for ip %#v: %v", ip, err)
  1307. }
  1308. return err
  1309. }
  1310. func sqlCommonSetDefenderBanTime(ip string, banTime int64, dbHandle *sql.DB) error {
  1311. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1312. defer cancel()
  1313. q := getDefenderSetBanTimeQuery()
  1314. _, err := dbHandle.ExecContext(ctx, q, banTime, ip)
  1315. if err == nil {
  1316. providerLog(logger.LevelDebug, "ip %#v banned until %v", ip, util.GetTimeFromMsecSinceEpoch(banTime))
  1317. } else {
  1318. providerLog(logger.LevelError, "error setting ban time for ip %#v: %v", ip, err)
  1319. }
  1320. return err
  1321. }
  1322. func sqlCommonDeleteDefenderHost(ip string, dbHandle sqlQuerier) error {
  1323. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1324. defer cancel()
  1325. q := getDeleteDefenderHostQuery()
  1326. res, err := dbHandle.ExecContext(ctx, q, ip)
  1327. if err != nil {
  1328. providerLog(logger.LevelError, "unable to delete defender host %#v: %v", ip, err)
  1329. return err
  1330. }
  1331. return sqlCommonRequireRowAffected(res)
  1332. }
  1333. func sqlCommonAddDefenderHostAndEvent(ip string, score int, dbHandle *sql.DB) error {
  1334. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1335. defer cancel()
  1336. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  1337. if err := sqlCommonAddDefenderHost(ctx, ip, tx); err != nil {
  1338. return err
  1339. }
  1340. return sqlCommonAddDefenderEvent(ctx, ip, score, tx)
  1341. })
  1342. }
  1343. func sqlCommonDefenderCleanup(from int64, dbHandler *sql.DB) error {
  1344. if err := sqlCommonCleanupDefenderEvents(from, dbHandler); err != nil {
  1345. return err
  1346. }
  1347. return sqlCommonCleanupDefenderHosts(from, dbHandler)
  1348. }
  1349. func sqlCommonAddDefenderHost(ctx context.Context, ip string, tx *sql.Tx) error {
  1350. q := getAddDefenderHostQuery()
  1351. _, err := tx.ExecContext(ctx, q, ip, util.GetTimeAsMsSinceEpoch(time.Now()))
  1352. if err != nil {
  1353. providerLog(logger.LevelError, "unable to add defender host %#v: %v", ip, err)
  1354. }
  1355. return err
  1356. }
  1357. func sqlCommonAddDefenderEvent(ctx context.Context, ip string, score int, tx *sql.Tx) error {
  1358. q := getAddDefenderEventQuery()
  1359. _, err := tx.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), score, ip)
  1360. if err != nil {
  1361. providerLog(logger.LevelError, "unable to add defender event for %#v: %v", ip, err)
  1362. }
  1363. return err
  1364. }
  1365. func sqlCommonCleanupDefenderHosts(from int64, dbHandle *sql.DB) error {
  1366. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1367. defer cancel()
  1368. q := getDefenderHostsCleanupQuery()
  1369. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), from)
  1370. if err != nil {
  1371. providerLog(logger.LevelError, "unable to cleanup defender hosts: %v", err)
  1372. }
  1373. return err
  1374. }
  1375. func sqlCommonCleanupDefenderEvents(from int64, dbHandle *sql.DB) error {
  1376. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1377. defer cancel()
  1378. q := getDefenderEventsCleanupQuery()
  1379. _, err := dbHandle.ExecContext(ctx, q, from)
  1380. if err != nil {
  1381. providerLog(logger.LevelError, "unable to cleanup defender events: %v", err)
  1382. }
  1383. return err
  1384. }
  1385. func getShareFromDbRow(row sqlScanner) (Share, error) {
  1386. var share Share
  1387. var description, password, allowFrom, paths sql.NullString
  1388. err := row.Scan(&share.ShareID, &share.Name, &description, &share.Scope,
  1389. &paths, &share.Username, &share.CreatedAt, &share.UpdatedAt,
  1390. &share.LastUseAt, &share.ExpiresAt, &password, &share.MaxTokens,
  1391. &share.UsedTokens, &allowFrom)
  1392. if err != nil {
  1393. if errors.Is(err, sql.ErrNoRows) {
  1394. return share, util.NewRecordNotFoundError(err.Error())
  1395. }
  1396. return share, err
  1397. }
  1398. if paths.Valid {
  1399. var list []string
  1400. err = json.Unmarshal([]byte(paths.String), &list)
  1401. if err != nil {
  1402. return share, err
  1403. }
  1404. share.Paths = list
  1405. } else {
  1406. return share, errors.New("unable to decode shared paths")
  1407. }
  1408. if description.Valid {
  1409. share.Description = description.String
  1410. }
  1411. if password.Valid {
  1412. share.Password = password.String
  1413. }
  1414. if allowFrom.Valid {
  1415. var list []string
  1416. err = json.Unmarshal([]byte(allowFrom.String), &list)
  1417. if err == nil {
  1418. share.AllowFrom = list
  1419. }
  1420. }
  1421. return share, nil
  1422. }
  1423. func getAPIKeyFromDbRow(row sqlScanner) (APIKey, error) {
  1424. var apiKey APIKey
  1425. var userID, adminID sql.NullInt64
  1426. var description sql.NullString
  1427. err := row.Scan(&apiKey.KeyID, &apiKey.Name, &apiKey.Key, &apiKey.Scope, &apiKey.CreatedAt, &apiKey.UpdatedAt,
  1428. &apiKey.LastUseAt, &apiKey.ExpiresAt, &description, &userID, &adminID)
  1429. if err != nil {
  1430. if errors.Is(err, sql.ErrNoRows) {
  1431. return apiKey, util.NewRecordNotFoundError(err.Error())
  1432. }
  1433. return apiKey, err
  1434. }
  1435. if userID.Valid {
  1436. apiKey.userID = userID.Int64
  1437. }
  1438. if adminID.Valid {
  1439. apiKey.adminID = adminID.Int64
  1440. }
  1441. if description.Valid {
  1442. apiKey.Description = description.String
  1443. }
  1444. return apiKey, nil
  1445. }
  1446. func getAdminFromDbRow(row sqlScanner) (Admin, error) {
  1447. var admin Admin
  1448. var email, filters, additionalInfo, permissions, description sql.NullString
  1449. err := row.Scan(&admin.ID, &admin.Username, &admin.Password, &admin.Status, &email, &permissions,
  1450. &filters, &additionalInfo, &description, &admin.CreatedAt, &admin.UpdatedAt, &admin.LastLogin)
  1451. if err != nil {
  1452. if errors.Is(err, sql.ErrNoRows) {
  1453. return admin, util.NewRecordNotFoundError(err.Error())
  1454. }
  1455. return admin, err
  1456. }
  1457. if permissions.Valid {
  1458. var perms []string
  1459. err = json.Unmarshal([]byte(permissions.String), &perms)
  1460. if err != nil {
  1461. return admin, err
  1462. }
  1463. admin.Permissions = perms
  1464. }
  1465. if email.Valid {
  1466. admin.Email = email.String
  1467. }
  1468. if filters.Valid {
  1469. var adminFilters AdminFilters
  1470. err = json.Unmarshal([]byte(filters.String), &adminFilters)
  1471. if err == nil {
  1472. admin.Filters = adminFilters
  1473. }
  1474. }
  1475. if additionalInfo.Valid {
  1476. admin.AdditionalInfo = additionalInfo.String
  1477. }
  1478. if description.Valid {
  1479. admin.Description = description.String
  1480. }
  1481. admin.SetEmptySecretsIfNil()
  1482. return admin, nil
  1483. }
  1484. func getEventActionFromDbRow(row sqlScanner) (BaseEventAction, error) {
  1485. var action BaseEventAction
  1486. var description sql.NullString
  1487. var options []byte
  1488. err := row.Scan(&action.ID, &action.Name, &description, &action.Type, &options)
  1489. if err != nil {
  1490. if errors.Is(err, sql.ErrNoRows) {
  1491. return action, util.NewRecordNotFoundError(err.Error())
  1492. }
  1493. return action, err
  1494. }
  1495. if description.Valid {
  1496. action.Description = description.String
  1497. }
  1498. if len(options) > 0 {
  1499. err = json.Unmarshal(options, &action.Options)
  1500. if err != nil {
  1501. return action, err
  1502. }
  1503. }
  1504. return action, nil
  1505. }
  1506. func getEventRuleFromDbRow(row sqlScanner) (EventRule, error) {
  1507. var rule EventRule
  1508. var description sql.NullString
  1509. var conditions []byte
  1510. err := row.Scan(&rule.ID, &rule.Name, &description, &rule.CreatedAt, &rule.UpdatedAt, &rule.Trigger,
  1511. &conditions, &rule.DeletedAt)
  1512. if err != nil {
  1513. if errors.Is(err, sql.ErrNoRows) {
  1514. return rule, util.NewRecordNotFoundError(err.Error())
  1515. }
  1516. return rule, err
  1517. }
  1518. if len(conditions) > 0 {
  1519. err = json.Unmarshal(conditions, &rule.Conditions)
  1520. if err != nil {
  1521. return rule, err
  1522. }
  1523. }
  1524. if description.Valid {
  1525. rule.Description = description.String
  1526. }
  1527. return rule, nil
  1528. }
  1529. func getGroupFromDbRow(row sqlScanner) (Group, error) {
  1530. var group Group
  1531. var userSettings, description sql.NullString
  1532. err := row.Scan(&group.ID, &group.Name, &description, &group.CreatedAt, &group.UpdatedAt, &userSettings)
  1533. if err != nil {
  1534. if errors.Is(err, sql.ErrNoRows) {
  1535. return group, util.NewRecordNotFoundError(err.Error())
  1536. }
  1537. return group, err
  1538. }
  1539. if description.Valid {
  1540. group.Description = description.String
  1541. }
  1542. if userSettings.Valid {
  1543. var settings GroupUserSettings
  1544. err = json.Unmarshal([]byte(userSettings.String), &settings)
  1545. if err == nil {
  1546. group.UserSettings = settings
  1547. }
  1548. }
  1549. return group, nil
  1550. }
  1551. func getUserFromDbRow(row sqlScanner) (User, error) {
  1552. var user User
  1553. var permissions sql.NullString
  1554. var password sql.NullString
  1555. var publicKey sql.NullString
  1556. var filters sql.NullString
  1557. var fsConfig sql.NullString
  1558. var additionalInfo, description, email sql.NullString
  1559. err := row.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions,
  1560. &user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
  1561. &user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig,
  1562. &additionalInfo, &description, &email, &user.CreatedAt, &user.UpdatedAt, &user.UploadDataTransfer, &user.DownloadDataTransfer,
  1563. &user.TotalDataTransfer, &user.UsedUploadDataTransfer, &user.UsedDownloadDataTransfer, &user.DeletedAt, &user.FirstDownload,
  1564. &user.FirstUpload)
  1565. if err != nil {
  1566. if errors.Is(err, sql.ErrNoRows) {
  1567. return user, util.NewRecordNotFoundError(err.Error())
  1568. }
  1569. return user, err
  1570. }
  1571. if password.Valid {
  1572. user.Password = password.String
  1573. }
  1574. // we can have a empty string or an invalid json in null string
  1575. // so we do a relaxed test if the field is optional, for example we
  1576. // populate public keys only if unmarshal does not return an error
  1577. if publicKey.Valid {
  1578. var list []string
  1579. err = json.Unmarshal([]byte(publicKey.String), &list)
  1580. if err == nil {
  1581. user.PublicKeys = list
  1582. }
  1583. }
  1584. if permissions.Valid {
  1585. perms := make(map[string][]string)
  1586. err = json.Unmarshal([]byte(permissions.String), &perms)
  1587. if err != nil {
  1588. providerLog(logger.LevelError, "unable to deserialize permissions for user %#v: %v", user.Username, err)
  1589. return user, fmt.Errorf("unable to deserialize permissions for user %#v: %v", user.Username, err)
  1590. }
  1591. user.Permissions = perms
  1592. }
  1593. if filters.Valid {
  1594. var userFilters UserFilters
  1595. err = json.Unmarshal([]byte(filters.String), &userFilters)
  1596. if err == nil {
  1597. user.Filters = userFilters
  1598. }
  1599. }
  1600. if fsConfig.Valid {
  1601. var fs vfs.Filesystem
  1602. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  1603. if err == nil {
  1604. user.FsConfig = fs
  1605. }
  1606. }
  1607. if additionalInfo.Valid {
  1608. user.AdditionalInfo = additionalInfo.String
  1609. }
  1610. if description.Valid {
  1611. user.Description = description.String
  1612. }
  1613. if email.Valid {
  1614. user.Email = email.String
  1615. }
  1616. user.SetEmptySecretsIfNil()
  1617. return user, nil
  1618. }
  1619. func sqlCommonGetFolder(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
  1620. var folder vfs.BaseVirtualFolder
  1621. q := getFolderByNameQuery()
  1622. row := dbHandle.QueryRowContext(ctx, q, name)
  1623. var mappedPath, description, fsConfig sql.NullString
  1624. err := row.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, &folder.LastQuotaUpdate,
  1625. &folder.Name, &description, &fsConfig)
  1626. if err != nil {
  1627. if errors.Is(err, sql.ErrNoRows) {
  1628. return folder, util.NewRecordNotFoundError(err.Error())
  1629. }
  1630. return folder, err
  1631. }
  1632. if mappedPath.Valid {
  1633. folder.MappedPath = mappedPath.String
  1634. }
  1635. if description.Valid {
  1636. folder.Description = description.String
  1637. }
  1638. if fsConfig.Valid {
  1639. var fs vfs.Filesystem
  1640. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  1641. if err == nil {
  1642. folder.FsConfig = fs
  1643. }
  1644. }
  1645. return folder, err
  1646. }
  1647. func sqlCommonGetFolderByName(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
  1648. folder, err := sqlCommonGetFolder(ctx, name, dbHandle)
  1649. if err != nil {
  1650. return folder, err
  1651. }
  1652. folders, err := getVirtualFoldersWithUsers([]vfs.BaseVirtualFolder{folder}, dbHandle)
  1653. if err != nil {
  1654. return folder, err
  1655. }
  1656. if len(folders) != 1 {
  1657. return folder, fmt.Errorf("unable to associate users with folder %#v", name)
  1658. }
  1659. folders, err = getVirtualFoldersWithGroups([]vfs.BaseVirtualFolder{folders[0]}, dbHandle)
  1660. if err != nil {
  1661. return folder, err
  1662. }
  1663. if len(folders) != 1 {
  1664. return folder, fmt.Errorf("unable to associate groups with folder %#v", name)
  1665. }
  1666. return folders[0], nil
  1667. }
  1668. func sqlCommonAddOrUpdateFolder(ctx context.Context, baseFolder *vfs.BaseVirtualFolder, usedQuotaSize int64,
  1669. usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier,
  1670. ) error {
  1671. fsConfig, err := json.Marshal(baseFolder.FsConfig)
  1672. if err != nil {
  1673. return err
  1674. }
  1675. q := getUpsertFolderQuery()
  1676. _, err = dbHandle.ExecContext(ctx, q, baseFolder.MappedPath, usedQuotaSize, usedQuotaFiles,
  1677. lastQuotaUpdate, baseFolder.Name, baseFolder.Description, string(fsConfig))
  1678. return err
  1679. }
  1680. func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
  1681. err := ValidateFolder(folder)
  1682. if err != nil {
  1683. return err
  1684. }
  1685. fsConfig, err := json.Marshal(folder.FsConfig)
  1686. if err != nil {
  1687. return err
  1688. }
  1689. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1690. defer cancel()
  1691. q := getAddFolderQuery()
  1692. _, err = dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.UsedQuotaSize, folder.UsedQuotaFiles,
  1693. folder.LastQuotaUpdate, folder.Name, folder.Description, string(fsConfig))
  1694. return err
  1695. }
  1696. func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
  1697. err := ValidateFolder(folder)
  1698. if err != nil {
  1699. return err
  1700. }
  1701. fsConfig, err := json.Marshal(folder.FsConfig)
  1702. if err != nil {
  1703. return err
  1704. }
  1705. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1706. defer cancel()
  1707. q := getUpdateFolderQuery()
  1708. _, err = dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.Description, string(fsConfig), folder.Name)
  1709. return err
  1710. }
  1711. func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
  1712. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1713. defer cancel()
  1714. q := getDeleteFolderQuery()
  1715. res, err := dbHandle.ExecContext(ctx, q, folder.ID)
  1716. if err != nil {
  1717. return err
  1718. }
  1719. return sqlCommonRequireRowAffected(res)
  1720. }
  1721. func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  1722. folders := make([]vfs.BaseVirtualFolder, 0, 50)
  1723. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  1724. defer cancel()
  1725. q := getDumpFoldersQuery()
  1726. rows, err := dbHandle.QueryContext(ctx, q)
  1727. if err != nil {
  1728. return folders, err
  1729. }
  1730. defer rows.Close()
  1731. for rows.Next() {
  1732. var folder vfs.BaseVirtualFolder
  1733. var mappedPath, description, fsConfig sql.NullString
  1734. err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
  1735. &folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
  1736. if err != nil {
  1737. return folders, err
  1738. }
  1739. if mappedPath.Valid {
  1740. folder.MappedPath = mappedPath.String
  1741. }
  1742. if description.Valid {
  1743. folder.Description = description.String
  1744. }
  1745. if fsConfig.Valid {
  1746. var fs vfs.Filesystem
  1747. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  1748. if err == nil {
  1749. folder.FsConfig = fs
  1750. }
  1751. }
  1752. folders = append(folders, folder)
  1753. }
  1754. return folders, rows.Err()
  1755. }
  1756. func sqlCommonGetFolders(limit, offset int, order string, minimal bool, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  1757. folders := make([]vfs.BaseVirtualFolder, 0, limit)
  1758. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1759. defer cancel()
  1760. q := getFoldersQuery(order, minimal)
  1761. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  1762. if err != nil {
  1763. return folders, err
  1764. }
  1765. defer rows.Close()
  1766. for rows.Next() {
  1767. var folder vfs.BaseVirtualFolder
  1768. if minimal {
  1769. err = rows.Scan(&folder.ID, &folder.Name)
  1770. if err != nil {
  1771. return folders, err
  1772. }
  1773. } else {
  1774. var mappedPath, description, fsConfig sql.NullString
  1775. err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
  1776. &folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
  1777. if err != nil {
  1778. return folders, err
  1779. }
  1780. if mappedPath.Valid {
  1781. folder.MappedPath = mappedPath.String
  1782. }
  1783. if description.Valid {
  1784. folder.Description = description.String
  1785. }
  1786. if fsConfig.Valid {
  1787. var fs vfs.Filesystem
  1788. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  1789. if err == nil {
  1790. folder.FsConfig = fs
  1791. }
  1792. }
  1793. }
  1794. folder.PrepareForRendering()
  1795. folders = append(folders, folder)
  1796. }
  1797. err = rows.Err()
  1798. if err != nil {
  1799. return folders, err
  1800. }
  1801. if minimal {
  1802. return folders, nil
  1803. }
  1804. folders, err = getVirtualFoldersWithUsers(folders, dbHandle)
  1805. if err != nil {
  1806. return folders, err
  1807. }
  1808. return getVirtualFoldersWithGroups(folders, dbHandle)
  1809. }
  1810. func sqlCommonClearUserFolderMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
  1811. q := getClearUserFolderMappingQuery()
  1812. _, err := dbHandle.ExecContext(ctx, q, user.Username)
  1813. return err
  1814. }
  1815. func sqlCommonClearGroupFolderMapping(ctx context.Context, group *Group, dbHandle sqlQuerier) error {
  1816. q := getClearGroupFolderMappingQuery()
  1817. _, err := dbHandle.ExecContext(ctx, q, group.Name)
  1818. return err
  1819. }
  1820. func sqlCommonClearUserGroupMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
  1821. q := getClearUserGroupMappingQuery()
  1822. _, err := dbHandle.ExecContext(ctx, q, user.Username)
  1823. return err
  1824. }
  1825. func sqlCommonAddUserFolderMapping(ctx context.Context, user *User, folder *vfs.VirtualFolder, dbHandle sqlQuerier) error {
  1826. q := getAddUserFolderMappingQuery()
  1827. _, err := dbHandle.ExecContext(ctx, q, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name, user.Username)
  1828. return err
  1829. }
  1830. func sqlCommonClearAdminGroupMapping(ctx context.Context, admin *Admin, dbHandle sqlQuerier) error {
  1831. q := getClearAdminGroupMappingQuery()
  1832. _, err := dbHandle.ExecContext(ctx, q, admin.Username)
  1833. return err
  1834. }
  1835. func sqlCommonAddGroupFolderMapping(ctx context.Context, group *Group, folder *vfs.VirtualFolder, dbHandle sqlQuerier) error {
  1836. q := getAddGroupFolderMappingQuery()
  1837. _, err := dbHandle.ExecContext(ctx, q, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name, group.Name)
  1838. return err
  1839. }
  1840. func sqlCommonAddUserGroupMapping(ctx context.Context, username, groupName string, groupType int, dbHandle sqlQuerier) error {
  1841. q := getAddUserGroupMappingQuery()
  1842. _, err := dbHandle.ExecContext(ctx, q, username, groupName, groupType)
  1843. return err
  1844. }
  1845. func sqlCommonAddAdminGroupMapping(ctx context.Context, username, groupName string, mappingOptions AdminGroupMappingOptions,
  1846. dbHandle sqlQuerier,
  1847. ) error {
  1848. options, err := json.Marshal(mappingOptions)
  1849. if err != nil {
  1850. return err
  1851. }
  1852. q := getAddAdminGroupMappingQuery()
  1853. _, err = dbHandle.ExecContext(ctx, q, username, groupName, string(options))
  1854. return err
  1855. }
  1856. func generateGroupVirtualFoldersMapping(ctx context.Context, group *Group, dbHandle sqlQuerier) error {
  1857. err := sqlCommonClearGroupFolderMapping(ctx, group, dbHandle)
  1858. if err != nil {
  1859. return err
  1860. }
  1861. for idx := range group.VirtualFolders {
  1862. vfolder := &group.VirtualFolders[idx]
  1863. err = sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle)
  1864. if err != nil {
  1865. return err
  1866. }
  1867. err = sqlCommonAddGroupFolderMapping(ctx, group, vfolder, dbHandle)
  1868. if err != nil {
  1869. return err
  1870. }
  1871. }
  1872. return err
  1873. }
  1874. func generateUserVirtualFoldersMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
  1875. err := sqlCommonClearUserFolderMapping(ctx, user, dbHandle)
  1876. if err != nil {
  1877. return err
  1878. }
  1879. for idx := range user.VirtualFolders {
  1880. vfolder := &user.VirtualFolders[idx]
  1881. err := sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle)
  1882. if err != nil {
  1883. return err
  1884. }
  1885. err = sqlCommonAddUserFolderMapping(ctx, user, vfolder, dbHandle)
  1886. if err != nil {
  1887. return err
  1888. }
  1889. }
  1890. return err
  1891. }
  1892. func generateUserGroupMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
  1893. err := sqlCommonClearUserGroupMapping(ctx, user, dbHandle)
  1894. if err != nil {
  1895. return err
  1896. }
  1897. for _, group := range user.Groups {
  1898. err = sqlCommonAddUserGroupMapping(ctx, user.Username, group.Name, group.Type, dbHandle)
  1899. if err != nil {
  1900. return err
  1901. }
  1902. }
  1903. return err
  1904. }
  1905. func generateAdminGroupMapping(ctx context.Context, admin *Admin, dbHandle sqlQuerier) error {
  1906. err := sqlCommonClearAdminGroupMapping(ctx, admin, dbHandle)
  1907. if err != nil {
  1908. return err
  1909. }
  1910. for _, group := range admin.Groups {
  1911. err = sqlCommonAddAdminGroupMapping(ctx, admin.Username, group.Name, group.Options, dbHandle)
  1912. if err != nil {
  1913. return err
  1914. }
  1915. }
  1916. return err
  1917. }
  1918. func getDefenderHostsWithScores(ctx context.Context, hosts []DefenderEntry, from int64, idForScores []int64,
  1919. dbHandle sqlQuerier) (
  1920. []DefenderEntry,
  1921. error,
  1922. ) {
  1923. if len(idForScores) == 0 {
  1924. return hosts, nil
  1925. }
  1926. hostsWithScores := make(map[int64]int)
  1927. q := getDefenderEventsQuery(idForScores)
  1928. rows, err := dbHandle.QueryContext(ctx, q, from)
  1929. if err != nil {
  1930. providerLog(logger.LevelError, "unable to get score for hosts with id %+v: %v", idForScores, err)
  1931. return nil, err
  1932. }
  1933. defer rows.Close()
  1934. for rows.Next() {
  1935. var hostID int64
  1936. var score int
  1937. err = rows.Scan(&hostID, &score)
  1938. if err != nil {
  1939. providerLog(logger.LevelError, "error scanning host score row: %v", err)
  1940. return hosts, err
  1941. }
  1942. if score > 0 {
  1943. hostsWithScores[hostID] = score
  1944. }
  1945. }
  1946. err = rows.Err()
  1947. if err != nil {
  1948. return hosts, err
  1949. }
  1950. result := make([]DefenderEntry, 0, len(hosts))
  1951. for idx := range hosts {
  1952. hosts[idx].Score = hostsWithScores[hosts[idx].ID]
  1953. if hosts[idx].Score > 0 || !hosts[idx].BanTime.IsZero() {
  1954. result = append(result, hosts[idx])
  1955. }
  1956. }
  1957. return result, nil
  1958. }
  1959. func getAdminWithGroups(ctx context.Context, admin Admin, dbHandle sqlQuerier) (Admin, error) {
  1960. admins, err := getAdminsWithGroups(ctx, []Admin{admin}, dbHandle)
  1961. if err != nil {
  1962. return admin, err
  1963. }
  1964. if len(admins) == 0 {
  1965. return admin, errSQLGroupsAssociation
  1966. }
  1967. return admins[0], err
  1968. }
  1969. func getAdminsWithGroups(ctx context.Context, admins []Admin, dbHandle sqlQuerier) ([]Admin, error) {
  1970. if len(admins) == 0 {
  1971. return admins, nil
  1972. }
  1973. adminsGroups := make(map[int64][]AdminGroupMapping)
  1974. q := getRelatedGroupsForAdminsQuery(admins)
  1975. rows, err := dbHandle.QueryContext(ctx, q)
  1976. if err != nil {
  1977. return nil, err
  1978. }
  1979. defer rows.Close()
  1980. for rows.Next() {
  1981. var group AdminGroupMapping
  1982. var adminID int64
  1983. var options []byte
  1984. err = rows.Scan(&group.Name, &options, &adminID)
  1985. if err != nil {
  1986. return admins, err
  1987. }
  1988. err = json.Unmarshal(options, &group.Options)
  1989. if err != nil {
  1990. return admins, err
  1991. }
  1992. adminsGroups[adminID] = append(adminsGroups[adminID], group)
  1993. }
  1994. err = rows.Err()
  1995. if err != nil {
  1996. return admins, err
  1997. }
  1998. if len(adminsGroups) == 0 {
  1999. return admins, err
  2000. }
  2001. for idx := range admins {
  2002. ref := &admins[idx]
  2003. ref.Groups = adminsGroups[ref.ID]
  2004. }
  2005. return admins, err
  2006. }
  2007. func getUserWithVirtualFolders(ctx context.Context, user User, dbHandle sqlQuerier) (User, error) {
  2008. users, err := getUsersWithVirtualFolders(ctx, []User{user}, dbHandle)
  2009. if err != nil {
  2010. return user, err
  2011. }
  2012. if len(users) == 0 {
  2013. return user, errSQLFoldersAssociation
  2014. }
  2015. return users[0], err
  2016. }
  2017. func getUsersWithVirtualFolders(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
  2018. if len(users) == 0 {
  2019. return users, nil
  2020. }
  2021. usersVirtualFolders := make(map[int64][]vfs.VirtualFolder)
  2022. q := getRelatedFoldersForUsersQuery(users)
  2023. rows, err := dbHandle.QueryContext(ctx, q)
  2024. if err != nil {
  2025. return nil, err
  2026. }
  2027. defer rows.Close()
  2028. for rows.Next() {
  2029. var folder vfs.VirtualFolder
  2030. var userID int64
  2031. var mappedPath, fsConfig, description sql.NullString
  2032. err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
  2033. &folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &userID, &fsConfig,
  2034. &description)
  2035. if err != nil {
  2036. return users, err
  2037. }
  2038. if mappedPath.Valid {
  2039. folder.MappedPath = mappedPath.String
  2040. }
  2041. if description.Valid {
  2042. folder.Description = description.String
  2043. }
  2044. if fsConfig.Valid {
  2045. var fs vfs.Filesystem
  2046. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  2047. if err == nil {
  2048. folder.FsConfig = fs
  2049. }
  2050. }
  2051. usersVirtualFolders[userID] = append(usersVirtualFolders[userID], folder)
  2052. }
  2053. err = rows.Err()
  2054. if err != nil {
  2055. return users, err
  2056. }
  2057. if len(usersVirtualFolders) == 0 {
  2058. return users, err
  2059. }
  2060. for idx := range users {
  2061. ref := &users[idx]
  2062. ref.VirtualFolders = usersVirtualFolders[ref.ID]
  2063. }
  2064. return users, err
  2065. }
  2066. func getUserWithGroups(ctx context.Context, user User, dbHandle sqlQuerier) (User, error) {
  2067. users, err := getUsersWithGroups(ctx, []User{user}, dbHandle)
  2068. if err != nil {
  2069. return user, err
  2070. }
  2071. if len(users) == 0 {
  2072. return user, errSQLGroupsAssociation
  2073. }
  2074. return users[0], err
  2075. }
  2076. func getUsersWithGroups(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
  2077. if len(users) == 0 {
  2078. return users, nil
  2079. }
  2080. usersGroups := make(map[int64][]sdk.GroupMapping)
  2081. q := getRelatedGroupsForUsersQuery(users)
  2082. rows, err := dbHandle.QueryContext(ctx, q)
  2083. if err != nil {
  2084. return nil, err
  2085. }
  2086. defer rows.Close()
  2087. for rows.Next() {
  2088. var group sdk.GroupMapping
  2089. var userID int64
  2090. err = rows.Scan(&group.Name, &group.Type, &userID)
  2091. if err != nil {
  2092. return users, err
  2093. }
  2094. usersGroups[userID] = append(usersGroups[userID], group)
  2095. }
  2096. err = rows.Err()
  2097. if err != nil {
  2098. return users, err
  2099. }
  2100. if len(usersGroups) == 0 {
  2101. return users, err
  2102. }
  2103. for idx := range users {
  2104. ref := &users[idx]
  2105. ref.Groups = usersGroups[ref.ID]
  2106. }
  2107. return users, err
  2108. }
  2109. func getGroupWithUsers(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
  2110. groups, err := getGroupsWithUsers(ctx, []Group{group}, dbHandle)
  2111. if err != nil {
  2112. return group, err
  2113. }
  2114. if len(groups) == 0 {
  2115. return group, errSQLUsersAssociation
  2116. }
  2117. return groups[0], err
  2118. }
  2119. func getGroupWithAdmins(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
  2120. groups, err := getGroupsWithAdmins(ctx, []Group{group}, dbHandle)
  2121. if err != nil {
  2122. return group, err
  2123. }
  2124. if len(groups) == 0 {
  2125. return group, errSQLUsersAssociation
  2126. }
  2127. return groups[0], err
  2128. }
  2129. func getGroupWithVirtualFolders(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
  2130. groups, err := getGroupsWithVirtualFolders(ctx, []Group{group}, dbHandle)
  2131. if err != nil {
  2132. return group, err
  2133. }
  2134. if len(groups) == 0 {
  2135. return group, errSQLFoldersAssociation
  2136. }
  2137. return groups[0], err
  2138. }
  2139. func getGroupsWithVirtualFolders(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
  2140. if len(groups) == 0 {
  2141. return groups, nil
  2142. }
  2143. q := getRelatedFoldersForGroupsQuery(groups)
  2144. rows, err := dbHandle.QueryContext(ctx, q)
  2145. if err != nil {
  2146. return nil, err
  2147. }
  2148. defer rows.Close()
  2149. groupsVirtualFolders := make(map[int64][]vfs.VirtualFolder)
  2150. for rows.Next() {
  2151. var groupID int64
  2152. var folder vfs.VirtualFolder
  2153. var mappedPath, fsConfig, description sql.NullString
  2154. err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
  2155. &folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &groupID, &fsConfig,
  2156. &description)
  2157. if err != nil {
  2158. return groups, err
  2159. }
  2160. if mappedPath.Valid {
  2161. folder.MappedPath = mappedPath.String
  2162. }
  2163. if description.Valid {
  2164. folder.Description = description.String
  2165. }
  2166. if fsConfig.Valid {
  2167. var fs vfs.Filesystem
  2168. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  2169. if err == nil {
  2170. folder.FsConfig = fs
  2171. }
  2172. }
  2173. groupsVirtualFolders[groupID] = append(groupsVirtualFolders[groupID], folder)
  2174. }
  2175. err = rows.Err()
  2176. if err != nil {
  2177. return groups, err
  2178. }
  2179. if len(groupsVirtualFolders) == 0 {
  2180. return groups, err
  2181. }
  2182. for idx := range groups {
  2183. ref := &groups[idx]
  2184. ref.VirtualFolders = groupsVirtualFolders[ref.ID]
  2185. }
  2186. return groups, err
  2187. }
  2188. func getGroupsWithUsers(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
  2189. if len(groups) == 0 {
  2190. return groups, nil
  2191. }
  2192. q := getRelatedUsersForGroupsQuery(groups)
  2193. rows, err := dbHandle.QueryContext(ctx, q)
  2194. if err != nil {
  2195. return nil, err
  2196. }
  2197. defer rows.Close()
  2198. groupsUsers := make(map[int64][]string)
  2199. for rows.Next() {
  2200. var username string
  2201. var groupID int64
  2202. err = rows.Scan(&groupID, &username)
  2203. if err != nil {
  2204. return groups, err
  2205. }
  2206. groupsUsers[groupID] = append(groupsUsers[groupID], username)
  2207. }
  2208. err = rows.Err()
  2209. if err != nil {
  2210. return groups, err
  2211. }
  2212. if len(groupsUsers) == 0 {
  2213. return groups, err
  2214. }
  2215. for idx := range groups {
  2216. ref := &groups[idx]
  2217. ref.Users = groupsUsers[ref.ID]
  2218. }
  2219. return groups, err
  2220. }
  2221. func getGroupsWithAdmins(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
  2222. if len(groups) == 0 {
  2223. return groups, nil
  2224. }
  2225. q := getRelatedAdminsForGroupsQuery(groups)
  2226. rows, err := dbHandle.QueryContext(ctx, q)
  2227. if err != nil {
  2228. return nil, err
  2229. }
  2230. defer rows.Close()
  2231. groupsAdmins := make(map[int64][]string)
  2232. for rows.Next() {
  2233. var groupID int64
  2234. var username string
  2235. err = rows.Scan(&groupID, &username)
  2236. if err != nil {
  2237. return groups, err
  2238. }
  2239. groupsAdmins[groupID] = append(groupsAdmins[groupID], username)
  2240. }
  2241. err = rows.Err()
  2242. if err != nil {
  2243. return groups, err
  2244. }
  2245. if len(groupsAdmins) > 0 {
  2246. for idx := range groups {
  2247. ref := &groups[idx]
  2248. ref.Admins = groupsAdmins[ref.ID]
  2249. }
  2250. }
  2251. return groups, nil
  2252. }
  2253. func getVirtualFoldersWithGroups(folders []vfs.BaseVirtualFolder, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  2254. if len(folders) == 0 {
  2255. return folders, nil
  2256. }
  2257. vFoldersGroups := make(map[int64][]string)
  2258. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2259. defer cancel()
  2260. q := getRelatedGroupsForFoldersQuery(folders)
  2261. rows, err := dbHandle.QueryContext(ctx, q)
  2262. if err != nil {
  2263. return nil, err
  2264. }
  2265. defer rows.Close()
  2266. for rows.Next() {
  2267. var name string
  2268. var folderID int64
  2269. err = rows.Scan(&folderID, &name)
  2270. if err != nil {
  2271. return folders, err
  2272. }
  2273. vFoldersGroups[folderID] = append(vFoldersGroups[folderID], name)
  2274. }
  2275. err = rows.Err()
  2276. if err != nil {
  2277. return folders, err
  2278. }
  2279. if len(vFoldersGroups) == 0 {
  2280. return folders, err
  2281. }
  2282. for idx := range folders {
  2283. ref := &folders[idx]
  2284. ref.Groups = vFoldersGroups[ref.ID]
  2285. }
  2286. return folders, err
  2287. }
  2288. func getVirtualFoldersWithUsers(folders []vfs.BaseVirtualFolder, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  2289. if len(folders) == 0 {
  2290. return folders, nil
  2291. }
  2292. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2293. defer cancel()
  2294. q := getRelatedUsersForFoldersQuery(folders)
  2295. rows, err := dbHandle.QueryContext(ctx, q)
  2296. if err != nil {
  2297. return nil, err
  2298. }
  2299. defer rows.Close()
  2300. vFoldersUsers := make(map[int64][]string)
  2301. for rows.Next() {
  2302. var username string
  2303. var folderID int64
  2304. err = rows.Scan(&folderID, &username)
  2305. if err != nil {
  2306. return folders, err
  2307. }
  2308. vFoldersUsers[folderID] = append(vFoldersUsers[folderID], username)
  2309. }
  2310. err = rows.Err()
  2311. if err != nil {
  2312. return folders, err
  2313. }
  2314. if len(vFoldersUsers) == 0 {
  2315. return folders, err
  2316. }
  2317. for idx := range folders {
  2318. ref := &folders[idx]
  2319. ref.Users = vFoldersUsers[ref.ID]
  2320. }
  2321. return folders, err
  2322. }
  2323. func sqlCommonUpdateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
  2324. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2325. defer cancel()
  2326. q := getUpdateFolderQuotaQuery(reset)
  2327. _, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), name)
  2328. if err == nil {
  2329. providerLog(logger.LevelDebug, "quota updated for folder %#v, files increment: %v size increment: %v is reset? %v",
  2330. name, filesAdd, sizeAdd, reset)
  2331. } else {
  2332. providerLog(logger.LevelWarn, "error updating quota for folder %#v: %v", name, err)
  2333. }
  2334. return err
  2335. }
  2336. func sqlCommonGetFolderUsedQuota(mappedPath string, dbHandle *sql.DB) (int, int64, error) {
  2337. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2338. defer cancel()
  2339. q := getQuotaFolderQuery()
  2340. var usedFiles int
  2341. var usedSize int64
  2342. err := dbHandle.QueryRowContext(ctx, q, mappedPath).Scan(&usedSize, &usedFiles)
  2343. if err != nil {
  2344. providerLog(logger.LevelError, "error getting quota for folder: %v, error: %v", mappedPath, err)
  2345. return 0, 0, err
  2346. }
  2347. return usedFiles, usedSize, err
  2348. }
  2349. func getAPIKeyWithRelatedFields(ctx context.Context, apiKey APIKey, dbHandle sqlQuerier) (APIKey, error) {
  2350. var apiKeys []APIKey
  2351. var err error
  2352. scope := APIKeyScopeAdmin
  2353. if apiKey.userID > 0 {
  2354. scope = APIKeyScopeUser
  2355. }
  2356. apiKeys, err = getRelatedValuesForAPIKeys(ctx, []APIKey{apiKey}, dbHandle, scope)
  2357. if err != nil {
  2358. return apiKey, err
  2359. }
  2360. if len(apiKeys) > 0 {
  2361. apiKey = apiKeys[0]
  2362. }
  2363. return apiKey, nil
  2364. }
  2365. func getRelatedValuesForAPIKeys(ctx context.Context, apiKeys []APIKey, dbHandle sqlQuerier, scope APIKeyScope) ([]APIKey, error) {
  2366. if len(apiKeys) == 0 {
  2367. return apiKeys, nil
  2368. }
  2369. values := make(map[int64]string)
  2370. var q string
  2371. if scope == APIKeyScopeUser {
  2372. q = getRelatedUsersForAPIKeysQuery(apiKeys)
  2373. } else {
  2374. q = getRelatedAdminsForAPIKeysQuery(apiKeys)
  2375. }
  2376. rows, err := dbHandle.QueryContext(ctx, q)
  2377. if err != nil {
  2378. return nil, err
  2379. }
  2380. defer rows.Close()
  2381. for rows.Next() {
  2382. var valueID int64
  2383. var valueName string
  2384. err = rows.Scan(&valueID, &valueName)
  2385. if err != nil {
  2386. return apiKeys, err
  2387. }
  2388. values[valueID] = valueName
  2389. }
  2390. err = rows.Err()
  2391. if err != nil {
  2392. return apiKeys, err
  2393. }
  2394. if len(values) == 0 {
  2395. return apiKeys, nil
  2396. }
  2397. for idx := range apiKeys {
  2398. ref := &apiKeys[idx]
  2399. if scope == APIKeyScopeUser {
  2400. ref.User = values[ref.userID]
  2401. } else {
  2402. ref.Admin = values[ref.adminID]
  2403. }
  2404. }
  2405. return apiKeys, nil
  2406. }
  2407. func sqlCommonGetAPIKeyRelatedIDs(apiKey *APIKey) (sql.NullInt64, sql.NullInt64, error) {
  2408. var userID, adminID sql.NullInt64
  2409. if apiKey.User != "" {
  2410. u, err := provider.userExists(apiKey.User)
  2411. if err != nil {
  2412. return userID, adminID, util.NewValidationError(fmt.Sprintf("unable to validate user %v", apiKey.User))
  2413. }
  2414. userID.Valid = true
  2415. userID.Int64 = u.ID
  2416. }
  2417. if apiKey.Admin != "" {
  2418. a, err := provider.adminExists(apiKey.Admin)
  2419. if err != nil {
  2420. return userID, adminID, util.NewValidationError(fmt.Sprintf("unable to validate admin %v", apiKey.Admin))
  2421. }
  2422. adminID.Valid = true
  2423. adminID.Int64 = a.ID
  2424. }
  2425. return userID, adminID, nil
  2426. }
  2427. func sqlCommonAddSession(session Session, dbHandle *sql.DB) error {
  2428. if err := session.validate(); err != nil {
  2429. return err
  2430. }
  2431. data, err := json.Marshal(session.Data)
  2432. if err != nil {
  2433. return err
  2434. }
  2435. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2436. defer cancel()
  2437. q := getAddSessionQuery()
  2438. _, err = dbHandle.ExecContext(ctx, q, session.Key, data, session.Type, session.Timestamp)
  2439. return err
  2440. }
  2441. func sqlCommonGetSession(key string, dbHandle sqlQuerier) (Session, error) {
  2442. var session Session
  2443. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2444. defer cancel()
  2445. q := getSessionQuery()
  2446. var data []byte // type hint, some driver will use string instead of []byte if the type is any
  2447. err := dbHandle.QueryRowContext(ctx, q, key).Scan(&session.Key, &data, &session.Type, &session.Timestamp)
  2448. if err != nil {
  2449. return session, err
  2450. }
  2451. session.Data = data
  2452. return session, nil
  2453. }
  2454. func sqlCommonDeleteSession(key string, dbHandle *sql.DB) error {
  2455. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2456. defer cancel()
  2457. q := getDeleteSessionQuery()
  2458. res, err := dbHandle.ExecContext(ctx, q, key)
  2459. if err != nil {
  2460. return err
  2461. }
  2462. return sqlCommonRequireRowAffected(res)
  2463. }
  2464. func sqlCommonCleanupSessions(sessionType SessionType, before int64, dbHandle *sql.DB) error {
  2465. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2466. defer cancel()
  2467. q := getCleanupSessionsQuery()
  2468. _, err := dbHandle.ExecContext(ctx, q, sessionType, before)
  2469. return err
  2470. }
  2471. func getActionsWithRuleNames(ctx context.Context, actions []BaseEventAction, dbHandle sqlQuerier,
  2472. ) ([]BaseEventAction, error) {
  2473. if len(actions) == 0 {
  2474. return actions, nil
  2475. }
  2476. q := getRelatedRulesForActionsQuery(actions)
  2477. rows, err := dbHandle.QueryContext(ctx, q)
  2478. if err != nil {
  2479. return nil, err
  2480. }
  2481. defer rows.Close()
  2482. actionsRules := make(map[int64][]string)
  2483. for rows.Next() {
  2484. var name string
  2485. var actionID int64
  2486. if err = rows.Scan(&actionID, &name); err != nil {
  2487. return nil, err
  2488. }
  2489. actionsRules[actionID] = append(actionsRules[actionID], name)
  2490. }
  2491. err = rows.Err()
  2492. if err != nil {
  2493. return nil, err
  2494. }
  2495. if len(actionsRules) == 0 {
  2496. return actions, nil
  2497. }
  2498. for idx := range actions {
  2499. ref := &actions[idx]
  2500. ref.Rules = actionsRules[ref.ID]
  2501. }
  2502. return actions, nil
  2503. }
  2504. func getRulesWithActions(ctx context.Context, rules []EventRule, dbHandle sqlQuerier) ([]EventRule, error) {
  2505. if len(rules) == 0 {
  2506. return rules, nil
  2507. }
  2508. rulesActions := make(map[int64][]EventAction)
  2509. q := getRelatedActionsForRulesQuery(rules)
  2510. rows, err := dbHandle.QueryContext(ctx, q)
  2511. if err != nil {
  2512. return nil, err
  2513. }
  2514. defer rows.Close()
  2515. for rows.Next() {
  2516. var action EventAction
  2517. var ruleID int64
  2518. var description sql.NullString
  2519. var baseOptions, options []byte
  2520. err = rows.Scan(&action.ID, &action.Name, &description, &action.Type, &baseOptions, &options,
  2521. &action.Order, &ruleID)
  2522. if err != nil {
  2523. return rules, err
  2524. }
  2525. if len(baseOptions) > 0 {
  2526. err = json.Unmarshal(baseOptions, &action.BaseEventAction.Options)
  2527. if err != nil {
  2528. return rules, err
  2529. }
  2530. }
  2531. if len(options) > 0 {
  2532. err = json.Unmarshal(options, &action.Options)
  2533. if err != nil {
  2534. return rules, err
  2535. }
  2536. }
  2537. action.BaseEventAction.Options.SetEmptySecretsIfNil()
  2538. rulesActions[ruleID] = append(rulesActions[ruleID], action)
  2539. }
  2540. err = rows.Err()
  2541. if err != nil {
  2542. return rules, err
  2543. }
  2544. if len(rulesActions) == 0 {
  2545. return rules, nil
  2546. }
  2547. for idx := range rules {
  2548. ref := &rules[idx]
  2549. ref.Actions = rulesActions[ref.ID]
  2550. }
  2551. return rules, nil
  2552. }
  2553. func generateEventRuleActionsMapping(ctx context.Context, rule *EventRule, dbHandle sqlQuerier) error {
  2554. q := getClearRuleActionMappingQuery()
  2555. _, err := dbHandle.ExecContext(ctx, q, rule.Name)
  2556. if err != nil {
  2557. return err
  2558. }
  2559. for _, action := range rule.Actions {
  2560. options, err := json.Marshal(action.Options)
  2561. if err != nil {
  2562. return err
  2563. }
  2564. q = getAddRuleActionMappingQuery()
  2565. _, err = dbHandle.ExecContext(ctx, q, rule.Name, action.Name, action.Order, string(options))
  2566. if err != nil {
  2567. return err
  2568. }
  2569. }
  2570. return nil
  2571. }
  2572. func sqlCommonGetEventActionByName(name string, dbHandle sqlQuerier) (BaseEventAction, error) {
  2573. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2574. defer cancel()
  2575. q := getEventActionByNameQuery()
  2576. row := dbHandle.QueryRowContext(ctx, q, name)
  2577. action, err := getEventActionFromDbRow(row)
  2578. if err != nil {
  2579. return action, err
  2580. }
  2581. actions, err := getActionsWithRuleNames(ctx, []BaseEventAction{action}, dbHandle)
  2582. if err != nil {
  2583. return action, err
  2584. }
  2585. if len(actions) != 1 {
  2586. return action, fmt.Errorf("unable to associate rules with action %q", name)
  2587. }
  2588. return actions[0], nil
  2589. }
  2590. func sqlCommonDumpEventActions(dbHandle sqlQuerier) ([]BaseEventAction, error) {
  2591. actions := make([]BaseEventAction, 0, 10)
  2592. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  2593. defer cancel()
  2594. q := getDumpEventActionsQuery()
  2595. rows, err := dbHandle.QueryContext(ctx, q)
  2596. if err != nil {
  2597. return actions, err
  2598. }
  2599. defer rows.Close()
  2600. for rows.Next() {
  2601. action, err := getEventActionFromDbRow(rows)
  2602. if err != nil {
  2603. return actions, err
  2604. }
  2605. actions = append(actions, action)
  2606. }
  2607. return actions, rows.Err()
  2608. }
  2609. func sqlCommonGetEventActions(limit int, offset int, order string, minimal bool,
  2610. dbHandle sqlQuerier,
  2611. ) ([]BaseEventAction, error) {
  2612. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2613. defer cancel()
  2614. q := getEventsActionsQuery(order, minimal)
  2615. actions := make([]BaseEventAction, 0, limit)
  2616. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  2617. if err != nil {
  2618. return actions, err
  2619. }
  2620. defer rows.Close()
  2621. for rows.Next() {
  2622. var action BaseEventAction
  2623. if minimal {
  2624. err = rows.Scan(&action.ID, &action.Name)
  2625. } else {
  2626. action, err = getEventActionFromDbRow(rows)
  2627. }
  2628. if err != nil {
  2629. return actions, err
  2630. }
  2631. actions = append(actions, action)
  2632. }
  2633. err = rows.Err()
  2634. if err != nil {
  2635. return nil, err
  2636. }
  2637. if minimal {
  2638. return actions, nil
  2639. }
  2640. actions, err = getActionsWithRuleNames(ctx, actions, dbHandle)
  2641. if err != nil {
  2642. return nil, err
  2643. }
  2644. for idx := range actions {
  2645. actions[idx].PrepareForRendering()
  2646. }
  2647. return actions, nil
  2648. }
  2649. func sqlCommonAddEventAction(action *BaseEventAction, dbHandle *sql.DB) error {
  2650. if err := action.validate(); err != nil {
  2651. return err
  2652. }
  2653. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2654. defer cancel()
  2655. q := getAddEventActionQuery()
  2656. options, err := json.Marshal(action.Options)
  2657. if err != nil {
  2658. return err
  2659. }
  2660. _, err = dbHandle.ExecContext(ctx, q, action.Name, action.Description, action.Type, string(options))
  2661. return err
  2662. }
  2663. func sqlCommonUpdateEventAction(action *BaseEventAction, dbHandle *sql.DB) error {
  2664. if err := action.validate(); err != nil {
  2665. return err
  2666. }
  2667. options, err := json.Marshal(action.Options)
  2668. if err != nil {
  2669. return err
  2670. }
  2671. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2672. defer cancel()
  2673. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  2674. q := getUpdateEventActionQuery()
  2675. _, err = tx.ExecContext(ctx, q, action.Description, action.Type, string(options), action.Name)
  2676. if err != nil {
  2677. return err
  2678. }
  2679. q = getUpdateRulesTimestampQuery()
  2680. _, err = tx.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), action.ID)
  2681. return err
  2682. })
  2683. }
  2684. func sqlCommonDeleteEventAction(action BaseEventAction, dbHandle *sql.DB) error {
  2685. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2686. defer cancel()
  2687. q := getDeleteEventActionQuery()
  2688. res, err := dbHandle.ExecContext(ctx, q, action.Name)
  2689. if err != nil {
  2690. return err
  2691. }
  2692. return sqlCommonRequireRowAffected(res)
  2693. }
  2694. func sqlCommonGetEventRuleByName(name string, dbHandle sqlQuerier) (EventRule, error) {
  2695. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2696. defer cancel()
  2697. q := getEventRulesByNameQuery()
  2698. row := dbHandle.QueryRowContext(ctx, q, name)
  2699. rule, err := getEventRuleFromDbRow(row)
  2700. if err != nil {
  2701. return rule, err
  2702. }
  2703. rules, err := getRulesWithActions(ctx, []EventRule{rule}, dbHandle)
  2704. if err != nil {
  2705. return rule, err
  2706. }
  2707. if len(rules) != 1 {
  2708. return rule, fmt.Errorf("unable to associate rule %q with actions", name)
  2709. }
  2710. return rules[0], nil
  2711. }
  2712. func sqlCommonDumpEventRules(dbHandle sqlQuerier) ([]EventRule, error) {
  2713. rules := make([]EventRule, 0, 10)
  2714. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  2715. defer cancel()
  2716. q := getDumpEventRulesQuery()
  2717. rows, err := dbHandle.QueryContext(ctx, q)
  2718. if err != nil {
  2719. return rules, err
  2720. }
  2721. defer rows.Close()
  2722. for rows.Next() {
  2723. rule, err := getEventRuleFromDbRow(rows)
  2724. if err != nil {
  2725. return rules, err
  2726. }
  2727. rules = append(rules, rule)
  2728. }
  2729. err = rows.Err()
  2730. if err != nil {
  2731. return rules, err
  2732. }
  2733. return getRulesWithActions(ctx, rules, dbHandle)
  2734. }
  2735. func sqlCommonGetRecentlyUpdatedRules(after int64, dbHandle sqlQuerier) ([]EventRule, error) {
  2736. rules := make([]EventRule, 0, 10)
  2737. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  2738. defer cancel()
  2739. q := getRecentlyUpdatedRulesQuery()
  2740. rows, err := dbHandle.QueryContext(ctx, q, after)
  2741. if err != nil {
  2742. return rules, err
  2743. }
  2744. defer rows.Close()
  2745. for rows.Next() {
  2746. rule, err := getEventRuleFromDbRow(rows)
  2747. if err != nil {
  2748. return rules, err
  2749. }
  2750. rules = append(rules, rule)
  2751. }
  2752. err = rows.Err()
  2753. if err != nil {
  2754. return rules, err
  2755. }
  2756. return getRulesWithActions(ctx, rules, dbHandle)
  2757. }
  2758. func sqlCommonGetEventRules(limit int, offset int, order string, dbHandle sqlQuerier) ([]EventRule, error) {
  2759. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2760. defer cancel()
  2761. q := getEventRulesQuery(order)
  2762. rules := make([]EventRule, 0, limit)
  2763. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  2764. if err != nil {
  2765. return rules, err
  2766. }
  2767. defer rows.Close()
  2768. for rows.Next() {
  2769. rule, err := getEventRuleFromDbRow(rows)
  2770. if err != nil {
  2771. return rules, err
  2772. }
  2773. rules = append(rules, rule)
  2774. }
  2775. err = rows.Err()
  2776. if err != nil {
  2777. return rules, err
  2778. }
  2779. rules, err = getRulesWithActions(ctx, rules, dbHandle)
  2780. if err != nil {
  2781. return rules, err
  2782. }
  2783. for idx := range rules {
  2784. rules[idx].PrepareForRendering()
  2785. }
  2786. return rules, nil
  2787. }
  2788. func sqlCommonAddEventRule(rule *EventRule, dbHandle *sql.DB) error {
  2789. if err := rule.validate(); err != nil {
  2790. return err
  2791. }
  2792. conditions, err := json.Marshal(rule.Conditions)
  2793. if err != nil {
  2794. return err
  2795. }
  2796. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2797. defer cancel()
  2798. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  2799. q := getAddEventRuleQuery()
  2800. _, err := tx.ExecContext(ctx, q, rule.Name, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
  2801. util.GetTimeAsMsSinceEpoch(time.Now()), rule.Trigger, string(conditions))
  2802. if err != nil {
  2803. return err
  2804. }
  2805. return generateEventRuleActionsMapping(ctx, rule, tx)
  2806. })
  2807. }
  2808. func sqlCommonUpdateEventRule(rule *EventRule, dbHandle *sql.DB) error {
  2809. if err := rule.validate(); err != nil {
  2810. return err
  2811. }
  2812. conditions, err := json.Marshal(rule.Conditions)
  2813. if err != nil {
  2814. return err
  2815. }
  2816. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2817. defer cancel()
  2818. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  2819. q := getUpdateEventRuleQuery()
  2820. _, err := tx.ExecContext(ctx, q, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
  2821. rule.Trigger, string(conditions), rule.Name)
  2822. if err != nil {
  2823. return err
  2824. }
  2825. return generateEventRuleActionsMapping(ctx, rule, tx)
  2826. })
  2827. }
  2828. func sqlCommonDeleteEventRule(rule EventRule, softDelete bool, dbHandle *sql.DB) error {
  2829. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2830. defer cancel()
  2831. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  2832. if softDelete {
  2833. q := getClearRuleActionMappingQuery()
  2834. _, err := tx.ExecContext(ctx, q, rule.Name)
  2835. if err != nil {
  2836. return err
  2837. }
  2838. }
  2839. q := getDeleteEventRuleQuery(softDelete)
  2840. if softDelete {
  2841. ts := util.GetTimeAsMsSinceEpoch(time.Now())
  2842. res, err := tx.ExecContext(ctx, q, ts, ts, rule.Name)
  2843. if err != nil {
  2844. return err
  2845. }
  2846. return sqlCommonRequireRowAffected(res)
  2847. }
  2848. res, err := tx.ExecContext(ctx, q, rule.Name)
  2849. if err != nil {
  2850. return err
  2851. }
  2852. if err = sqlCommonRequireRowAffected(res); err != nil {
  2853. return err
  2854. }
  2855. return sqlCommonDeleteTask(rule.Name, tx)
  2856. })
  2857. }
  2858. func sqlCommonGetTaskByName(name string, dbHandle sqlQuerier) (Task, error) {
  2859. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2860. defer cancel()
  2861. task := Task{
  2862. Name: name,
  2863. }
  2864. q := getTaskByNameQuery()
  2865. row := dbHandle.QueryRowContext(ctx, q, name)
  2866. err := row.Scan(&task.UpdateAt, &task.Version)
  2867. if err != nil {
  2868. if errors.Is(err, sql.ErrNoRows) {
  2869. return task, util.NewRecordNotFoundError(err.Error())
  2870. }
  2871. }
  2872. return task, err
  2873. }
  2874. func sqlCommonAddTask(name string, dbHandle *sql.DB) error {
  2875. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2876. defer cancel()
  2877. q := getAddTaskQuery()
  2878. _, err := dbHandle.ExecContext(ctx, q, name, util.GetTimeAsMsSinceEpoch(time.Now()))
  2879. return err
  2880. }
  2881. func sqlCommonUpdateTask(name string, version int64, dbHandle *sql.DB) error {
  2882. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2883. defer cancel()
  2884. q := getUpdateTaskQuery()
  2885. res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), name, version)
  2886. if err != nil {
  2887. return err
  2888. }
  2889. return sqlCommonRequireRowAffected(res)
  2890. }
  2891. func sqlCommonUpdateTaskTimestamp(name string, dbHandle *sql.DB) error {
  2892. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2893. defer cancel()
  2894. q := getUpdateTaskTimestampQuery()
  2895. res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), name)
  2896. if err != nil {
  2897. return err
  2898. }
  2899. return sqlCommonRequireRowAffected(res)
  2900. }
  2901. func sqlCommonDeleteTask(name string, dbHandle sqlQuerier) error {
  2902. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2903. defer cancel()
  2904. q := getDeleteTaskQuery()
  2905. _, err := dbHandle.ExecContext(ctx, q, name)
  2906. return err
  2907. }
  2908. func sqlCommonAddNode(dbHandle *sql.DB) error {
  2909. if err := currentNode.validate(); err != nil {
  2910. return fmt.Errorf("unable to register cluster node: %w", err)
  2911. }
  2912. data, err := json.Marshal(currentNode.Data)
  2913. if err != nil {
  2914. return err
  2915. }
  2916. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2917. defer cancel()
  2918. q := getAddNodeQuery()
  2919. _, err = dbHandle.ExecContext(ctx, q, currentNode.Name, string(data), util.GetTimeAsMsSinceEpoch(time.Now()),
  2920. util.GetTimeAsMsSinceEpoch(time.Now()))
  2921. if err != nil {
  2922. return fmt.Errorf("unable to register cluster node: %w", err)
  2923. }
  2924. providerLog(logger.LevelInfo, "registered as cluster node %q, port: %d, proto: %s",
  2925. currentNode.Name, currentNode.Data.Port, currentNode.Data.Proto)
  2926. return nil
  2927. }
  2928. func sqlCommonGetNodeByName(name string, dbHandle *sql.DB) (Node, error) {
  2929. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2930. defer cancel()
  2931. var data []byte
  2932. var node Node
  2933. q := getNodeByNameQuery()
  2934. row := dbHandle.QueryRowContext(ctx, q, name, util.GetTimeAsMsSinceEpoch(time.Now().Add(activeNodeTimeDiff)))
  2935. err := row.Scan(&node.Name, &data, &node.CreatedAt, &node.UpdatedAt)
  2936. if err != nil {
  2937. if errors.Is(err, sql.ErrNoRows) {
  2938. return node, util.NewRecordNotFoundError(err.Error())
  2939. }
  2940. return node, err
  2941. }
  2942. err = json.Unmarshal(data, &node.Data)
  2943. return node, err
  2944. }
  2945. func sqlCommonGetNodes(dbHandle *sql.DB) ([]Node, error) {
  2946. var nodes []Node
  2947. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2948. defer cancel()
  2949. q := getNodesQuery()
  2950. rows, err := dbHandle.QueryContext(ctx, q, currentNode.Name,
  2951. util.GetTimeAsMsSinceEpoch(time.Now().Add(activeNodeTimeDiff)))
  2952. if err != nil {
  2953. return nodes, err
  2954. }
  2955. defer rows.Close()
  2956. for rows.Next() {
  2957. var node Node
  2958. var data []byte
  2959. err = rows.Scan(&node.Name, &data, &node.CreatedAt, &node.UpdatedAt)
  2960. if err != nil {
  2961. return nodes, err
  2962. }
  2963. err = json.Unmarshal(data, &node.Data)
  2964. if err != nil {
  2965. return nodes, err
  2966. }
  2967. nodes = append(nodes, node)
  2968. }
  2969. return nodes, rows.Err()
  2970. }
  2971. func sqlCommonUpdateNodeTimestamp(dbHandle *sql.DB) error {
  2972. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2973. defer cancel()
  2974. q := getUpdateNodeTimestampQuery()
  2975. res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), currentNode.Name)
  2976. if err != nil {
  2977. return err
  2978. }
  2979. return sqlCommonRequireRowAffected(res)
  2980. }
  2981. func sqlCommonCleanupNodes(dbHandle *sql.DB) error {
  2982. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2983. defer cancel()
  2984. q := getCleanupNodesQuery()
  2985. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now().Add(10*activeNodeTimeDiff)))
  2986. return err
  2987. }
  2988. func sqlCommonGetDatabaseVersion(dbHandle sqlQuerier, showInitWarn bool) (schemaVersion, error) {
  2989. var result schemaVersion
  2990. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2991. defer cancel()
  2992. q := getDatabaseVersionQuery()
  2993. stmt, err := dbHandle.PrepareContext(ctx, q)
  2994. if err != nil {
  2995. providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
  2996. if showInitWarn && strings.Contains(err.Error(), sqlTableSchemaVersion) {
  2997. logger.WarnToConsole("database query error, did you forgot to run the \"initprovider\" command?")
  2998. }
  2999. return result, err
  3000. }
  3001. defer stmt.Close()
  3002. row := stmt.QueryRowContext(ctx)
  3003. err = row.Scan(&result.Version)
  3004. return result, err
  3005. }
  3006. func sqlCommonRequireRowAffected(res sql.Result) error {
  3007. // MariaDB/MySQL returns 0 rows affected for updates that don't change anything
  3008. // so we don't check rows affected for updates
  3009. affected, err := res.RowsAffected()
  3010. if err == nil && affected == 0 {
  3011. return util.NewRecordNotFoundError(sql.ErrNoRows.Error())
  3012. }
  3013. return nil
  3014. }
  3015. func sqlCommonUpdateDatabaseVersion(ctx context.Context, dbHandle sqlQuerier, version int) error {
  3016. q := getUpdateDBVersionQuery()
  3017. _, err := dbHandle.ExecContext(ctx, q, version)
  3018. return err
  3019. }
  3020. func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, newVersion int, isUp bool) error {
  3021. if err := sqlAcquireLock(dbHandle); err != nil {
  3022. return err
  3023. }
  3024. defer sqlReleaseLock(dbHandle)
  3025. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  3026. defer cancel()
  3027. if newVersion > 0 {
  3028. currentVersion, err := sqlCommonGetDatabaseVersion(dbHandle, false)
  3029. if err == nil {
  3030. if (isUp && currentVersion.Version >= newVersion) || (!isUp && currentVersion.Version <= newVersion) {
  3031. providerLog(logger.LevelInfo, "current schema version: %v, requested: %v, did you execute simultaneous migrations?",
  3032. currentVersion.Version, newVersion)
  3033. return nil
  3034. }
  3035. }
  3036. }
  3037. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  3038. for _, q := range sqlQueries {
  3039. if strings.TrimSpace(q) == "" {
  3040. continue
  3041. }
  3042. _, err := tx.ExecContext(ctx, q)
  3043. if err != nil {
  3044. return err
  3045. }
  3046. }
  3047. if newVersion == 0 {
  3048. return nil
  3049. }
  3050. return sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion)
  3051. })
  3052. }
  3053. func sqlAcquireLock(dbHandle *sql.DB) error {
  3054. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  3055. defer cancel()
  3056. switch config.Driver {
  3057. case PGSQLDataProviderName:
  3058. _, err := dbHandle.ExecContext(ctx, `SELECT pg_advisory_lock(101,1)`)
  3059. if err != nil {
  3060. return fmt.Errorf("unable to get advisory lock: %w", err)
  3061. }
  3062. providerLog(logger.LevelInfo, "acquired database lock")
  3063. case MySQLDataProviderName:
  3064. var lockResult sql.NullInt64
  3065. err := dbHandle.QueryRowContext(ctx, `SELECT GET_LOCK('sftpgo.migration',30)`).Scan(&lockResult)
  3066. if err != nil {
  3067. return fmt.Errorf("unable to get lock: %w", err)
  3068. }
  3069. if !lockResult.Valid {
  3070. return errors.New("unable to get lock: null value returned")
  3071. }
  3072. if lockResult.Int64 != 1 {
  3073. return fmt.Errorf("unable to get lock, result: %d", lockResult.Int64)
  3074. }
  3075. providerLog(logger.LevelInfo, "acquired database lock")
  3076. }
  3077. return nil
  3078. }
  3079. func sqlReleaseLock(dbHandle *sql.DB) {
  3080. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  3081. defer cancel()
  3082. switch config.Driver {
  3083. case PGSQLDataProviderName:
  3084. _, err := dbHandle.ExecContext(ctx, `SELECT pg_advisory_unlock(101,1)`)
  3085. if err != nil {
  3086. providerLog(logger.LevelWarn, "unable to release lock: %v", err)
  3087. } else {
  3088. providerLog(logger.LevelInfo, "released database lock")
  3089. }
  3090. case MySQLDataProviderName:
  3091. _, err := dbHandle.ExecContext(ctx, `SELECT RELEASE_LOCK('sftpgo.migration')`)
  3092. if err != nil {
  3093. providerLog(logger.LevelWarn, "unable to release lock: %v", err)
  3094. } else {
  3095. providerLog(logger.LevelInfo, "released database lock")
  3096. }
  3097. }
  3098. }
  3099. func sqlCommonExecuteTx(ctx context.Context, dbHandle *sql.DB, txFn func(*sql.Tx) error) error {
  3100. if config.Driver == CockroachDataProviderName {
  3101. return crdb.ExecuteTx(ctx, dbHandle, nil, txFn)
  3102. }
  3103. tx, err := dbHandle.BeginTx(ctx, nil)
  3104. if err != nil {
  3105. return err
  3106. }
  3107. err = txFn(tx)
  3108. if err != nil {
  3109. // we don't change the returned error
  3110. tx.Rollback() //nolint:errcheck
  3111. return err
  3112. }
  3113. return tx.Commit()
  3114. }