sqlcommon.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. package dataprovider
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "errors"
  7. "time"
  8. "github.com/drakkan/sftpgo/logger"
  9. "github.com/drakkan/sftpgo/utils"
  10. "github.com/drakkan/sftpgo/vfs"
  11. )
  12. const (
  13. sqlDatabaseVersion = 3
  14. initialDBVersionSQL = "INSERT INTO schema_version (version) VALUES (1);"
  15. )
  16. func getUserByUsername(username string, dbHandle *sql.DB) (User, error) {
  17. var user User
  18. q := getUserByUsernameQuery()
  19. stmt, err := dbHandle.Prepare(q)
  20. if err != nil {
  21. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  22. return user, err
  23. }
  24. defer stmt.Close()
  25. row := stmt.QueryRow(username)
  26. return getUserFromDbRow(row, nil)
  27. }
  28. func sqlCommonValidateUserAndPass(username string, password string, dbHandle *sql.DB) (User, error) {
  29. var user User
  30. if len(password) == 0 {
  31. return user, errors.New("Credentials cannot be null or empty")
  32. }
  33. user, err := getUserByUsername(username, dbHandle)
  34. if err != nil {
  35. providerLog(logger.LevelWarn, "error authenticating user: %v, error: %v", username, err)
  36. return user, err
  37. }
  38. return checkUserAndPass(user, password)
  39. }
  40. func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, dbHandle *sql.DB) (User, string, error) {
  41. var user User
  42. if len(pubKey) == 0 {
  43. return user, "", errors.New("Credentials cannot be null or empty")
  44. }
  45. user, err := getUserByUsername(username, dbHandle)
  46. if err != nil {
  47. providerLog(logger.LevelWarn, "error authenticating user: %v, error: %v", username, err)
  48. return user, "", err
  49. }
  50. return checkUserAndPubKey(user, pubKey)
  51. }
  52. func sqlCommonCheckAvailability(dbHandle *sql.DB) error {
  53. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  54. defer cancel()
  55. return dbHandle.PingContext(ctx)
  56. }
  57. func sqlCommonGetUserByID(ID int64, dbHandle *sql.DB) (User, error) {
  58. var user User
  59. q := getUserByIDQuery()
  60. stmt, err := dbHandle.Prepare(q)
  61. if err != nil {
  62. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  63. return user, err
  64. }
  65. defer stmt.Close()
  66. row := stmt.QueryRow(ID)
  67. return getUserFromDbRow(row, nil)
  68. }
  69. func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
  70. q := getUpdateQuotaQuery(reset)
  71. stmt, err := dbHandle.Prepare(q)
  72. if err != nil {
  73. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  74. return err
  75. }
  76. defer stmt.Close()
  77. _, err = stmt.Exec(sizeAdd, filesAdd, utils.GetTimeAsMsSinceEpoch(time.Now()), username)
  78. if err == nil {
  79. providerLog(logger.LevelDebug, "quota updated for user %#v, files increment: %v size increment: %v is reset? %v",
  80. username, filesAdd, sizeAdd, reset)
  81. } else {
  82. providerLog(logger.LevelWarn, "error updating quota for user %#v: %v", username, err)
  83. }
  84. return err
  85. }
  86. func sqlCommonUpdateLastLogin(username string, dbHandle *sql.DB) error {
  87. q := getUpdateLastLoginQuery()
  88. stmt, err := dbHandle.Prepare(q)
  89. if err != nil {
  90. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  91. return err
  92. }
  93. defer stmt.Close()
  94. _, err = stmt.Exec(utils.GetTimeAsMsSinceEpoch(time.Now()), username)
  95. if err == nil {
  96. providerLog(logger.LevelDebug, "last login updated for user %#v", username)
  97. } else {
  98. providerLog(logger.LevelWarn, "error updating last login for user %#v: %v", username, err)
  99. }
  100. return err
  101. }
  102. func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, error) {
  103. q := getQuotaQuery()
  104. stmt, err := dbHandle.Prepare(q)
  105. if err != nil {
  106. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  107. return 0, 0, err
  108. }
  109. defer stmt.Close()
  110. var usedFiles int
  111. var usedSize int64
  112. err = stmt.QueryRow(username).Scan(&usedSize, &usedFiles)
  113. if err != nil {
  114. providerLog(logger.LevelWarn, "error getting quota for user: %v, error: %v", username, err)
  115. return 0, 0, err
  116. }
  117. return usedFiles, usedSize, err
  118. }
  119. func sqlCommonCheckUserExists(username string, dbHandle *sql.DB) (User, error) {
  120. var user User
  121. q := getUserByUsernameQuery()
  122. stmt, err := dbHandle.Prepare(q)
  123. if err != nil {
  124. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  125. return user, err
  126. }
  127. defer stmt.Close()
  128. row := stmt.QueryRow(username)
  129. return getUserFromDbRow(row, nil)
  130. }
  131. func sqlCommonAddUser(user User, dbHandle *sql.DB) error {
  132. err := validateUser(&user)
  133. if err != nil {
  134. return err
  135. }
  136. q := getAddUserQuery()
  137. stmt, err := dbHandle.Prepare(q)
  138. if err != nil {
  139. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  140. return err
  141. }
  142. defer stmt.Close()
  143. permissions, err := user.GetPermissionsAsJSON()
  144. if err != nil {
  145. return err
  146. }
  147. publicKeys, err := user.GetPublicKeysAsJSON()
  148. if err != nil {
  149. return err
  150. }
  151. filters, err := user.GetFiltersAsJSON()
  152. if err != nil {
  153. return err
  154. }
  155. fsConfig, err := user.GetFsConfigAsJSON()
  156. if err != nil {
  157. return err
  158. }
  159. virtualFolders, err := user.GetVirtualFoldersAsJSON()
  160. if err != nil {
  161. return err
  162. }
  163. _, err = stmt.Exec(user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
  164. user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters),
  165. string(fsConfig), string(virtualFolders))
  166. return err
  167. }
  168. func sqlCommonUpdateUser(user User, dbHandle *sql.DB) error {
  169. err := validateUser(&user)
  170. if err != nil {
  171. return err
  172. }
  173. q := getUpdateUserQuery()
  174. stmt, err := dbHandle.Prepare(q)
  175. if err != nil {
  176. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  177. return err
  178. }
  179. defer stmt.Close()
  180. permissions, err := user.GetPermissionsAsJSON()
  181. if err != nil {
  182. return err
  183. }
  184. publicKeys, err := user.GetPublicKeysAsJSON()
  185. if err != nil {
  186. return err
  187. }
  188. filters, err := user.GetFiltersAsJSON()
  189. if err != nil {
  190. return err
  191. }
  192. fsConfig, err := user.GetFsConfigAsJSON()
  193. if err != nil {
  194. return err
  195. }
  196. virtualFolders, err := user.GetVirtualFoldersAsJSON()
  197. if err != nil {
  198. return err
  199. }
  200. _, err = stmt.Exec(user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
  201. user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate,
  202. string(filters), string(fsConfig), string(virtualFolders), user.ID)
  203. return err
  204. }
  205. func sqlCommonDeleteUser(user User, dbHandle *sql.DB) error {
  206. q := getDeleteUserQuery()
  207. stmt, err := dbHandle.Prepare(q)
  208. if err != nil {
  209. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  210. return err
  211. }
  212. defer stmt.Close()
  213. _, err = stmt.Exec(user.ID)
  214. return err
  215. }
  216. func sqlCommonDumpUsers(dbHandle *sql.DB) ([]User, error) {
  217. users := []User{}
  218. q := getDumpUsersQuery()
  219. stmt, err := dbHandle.Prepare(q)
  220. if err != nil {
  221. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  222. return nil, err
  223. }
  224. defer stmt.Close()
  225. rows, err := stmt.Query()
  226. if err == nil {
  227. defer rows.Close()
  228. for rows.Next() {
  229. u, err := getUserFromDbRow(nil, rows)
  230. if err != nil {
  231. return users, err
  232. }
  233. err = addCredentialsToUser(&u)
  234. if err != nil {
  235. return users, err
  236. }
  237. users = append(users, u)
  238. }
  239. }
  240. return users, err
  241. }
  242. func sqlCommonGetUsers(limit int, offset int, order string, username string, dbHandle *sql.DB) ([]User, error) {
  243. users := []User{}
  244. q := getUsersQuery(order, username)
  245. stmt, err := dbHandle.Prepare(q)
  246. if err != nil {
  247. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  248. return nil, err
  249. }
  250. defer stmt.Close()
  251. var rows *sql.Rows
  252. if len(username) > 0 {
  253. rows, err = stmt.Query(username, limit, offset) //nolint:rowserrcheck // err is checked
  254. } else {
  255. rows, err = stmt.Query(limit, offset) //nolint:rowserrcheck // err is checked
  256. }
  257. if err == nil {
  258. defer rows.Close()
  259. for rows.Next() {
  260. u, err := getUserFromDbRow(nil, rows)
  261. if err == nil {
  262. users = append(users, HideUserSensitiveData(&u))
  263. } else {
  264. break
  265. }
  266. }
  267. }
  268. return users, err
  269. }
  270. func updateUserPermissionsFromDb(user *User, permissions string) error {
  271. var err error
  272. perms := make(map[string][]string)
  273. err = json.Unmarshal([]byte(permissions), &perms)
  274. if err == nil {
  275. user.Permissions = perms
  276. } else {
  277. // compatibility layer: until version 0.9.4 permissions were a string list
  278. var list []string
  279. err = json.Unmarshal([]byte(permissions), &list)
  280. if err != nil {
  281. return err
  282. }
  283. perms["/"] = list
  284. user.Permissions = perms
  285. }
  286. return err
  287. }
  288. func getUserFromDbRow(row *sql.Row, rows *sql.Rows) (User, error) {
  289. var user User
  290. var permissions sql.NullString
  291. var password sql.NullString
  292. var publicKey sql.NullString
  293. var filters sql.NullString
  294. var fsConfig sql.NullString
  295. var virtualFolders sql.NullString
  296. var err error
  297. if row != nil {
  298. err = row.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions,
  299. &user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
  300. &user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig,
  301. &virtualFolders)
  302. } else {
  303. err = rows.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions,
  304. &user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
  305. &user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig,
  306. &virtualFolders)
  307. }
  308. if err != nil {
  309. if err == sql.ErrNoRows {
  310. return user, &RecordNotFoundError{err: err.Error()}
  311. }
  312. return user, err
  313. }
  314. if password.Valid {
  315. user.Password = password.String
  316. }
  317. // we can have a empty string or an invalid json in null string
  318. // so we do a relaxed test if the field is optional, for example we
  319. // populate public keys only if unmarshal does not return an error
  320. if publicKey.Valid {
  321. var list []string
  322. err = json.Unmarshal([]byte(publicKey.String), &list)
  323. if err == nil {
  324. user.PublicKeys = list
  325. }
  326. }
  327. if permissions.Valid {
  328. err = updateUserPermissionsFromDb(&user, permissions.String)
  329. if err != nil {
  330. return user, err
  331. }
  332. }
  333. if filters.Valid {
  334. var userFilters UserFilters
  335. err = json.Unmarshal([]byte(filters.String), &userFilters)
  336. if err == nil {
  337. user.Filters = userFilters
  338. }
  339. }
  340. if fsConfig.Valid {
  341. var fs Filesystem
  342. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  343. if err == nil {
  344. user.FsConfig = fs
  345. }
  346. }
  347. if virtualFolders.Valid {
  348. var list []vfs.VirtualFolder
  349. err = json.Unmarshal([]byte(virtualFolders.String), &list)
  350. if err == nil {
  351. user.VirtualFolders = list
  352. }
  353. }
  354. return user, err
  355. }
  356. func sqlCommonRollbackTransaction(tx *sql.Tx) {
  357. err := tx.Rollback()
  358. if err != nil {
  359. providerLog(logger.LevelWarn, "error rolling back transaction: %v", err)
  360. }
  361. }
  362. func sqlCommonGetDatabaseVersion(dbHandle *sql.DB) (schemaVersion, error) {
  363. var result schemaVersion
  364. q := getDatabaseVersionQuery()
  365. stmt, err := dbHandle.Prepare(q)
  366. if err != nil {
  367. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  368. return result, err
  369. }
  370. defer stmt.Close()
  371. row := stmt.QueryRow()
  372. err = row.Scan(&result.Version)
  373. return result, err
  374. }
  375. func sqlCommonUpdateDatabaseVersion(dbHandle *sql.DB, version int) error {
  376. q := getUpdateDBVersionQuery()
  377. stmt, err := dbHandle.Prepare(q)
  378. if err != nil {
  379. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  380. return err
  381. }
  382. defer stmt.Close()
  383. _, err = stmt.Exec(version)
  384. return err
  385. }
  386. func sqlCommonUpdateDatabaseVersionWithTX(tx *sql.Tx, version int) error {
  387. q := getUpdateDBVersionQuery()
  388. stmt, err := tx.Prepare(q)
  389. if err != nil {
  390. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  391. return err
  392. }
  393. defer stmt.Close()
  394. _, err = stmt.Exec(version)
  395. return err
  396. }