sqlcommon.go 9.1 KB


  1. package dataprovider
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "errors"
  7. "time"
  8. "github.com/drakkan/sftpgo/logger"
  9. "github.com/drakkan/sftpgo/utils"
  10. )
  11. func getUserByUsername(username string, dbHandle *sql.DB) (User, error) {
  12. var user User
  13. q := getUserByUsernameQuery()
  14. stmt, err := dbHandle.Prepare(q)
  15. if err != nil {
  16. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  17. return user, err
  18. }
  19. defer stmt.Close()
  20. row := stmt.QueryRow(username)
  21. return getUserFromDbRow(row, nil)
  22. }
  23. func sqlCommonValidateUserAndPass(username string, password string, dbHandle *sql.DB) (User, error) {
  24. var user User
  25. if len(password) == 0 {
  26. return user, errors.New("Credentials cannot be null or empty")
  27. }
  28. user, err := getUserByUsername(username, dbHandle)
  29. if err != nil {
  30. providerLog(logger.LevelWarn, "error authenticating user: %v, error: %v", username, err)
  31. return user, err
  32. }
  33. return checkUserAndPass(user, password)
  34. }
  35. func sqlCommonValidateUserAndPubKey(username string, pubKey string, dbHandle *sql.DB) (User, string, error) {
  36. var user User
  37. if len(pubKey) == 0 {
  38. return user, "", errors.New("Credentials cannot be null or empty")
  39. }
  40. user, err := getUserByUsername(username, dbHandle)
  41. if err != nil {
  42. providerLog(logger.LevelWarn, "error authenticating user: %v, error: %v", username, err)
  43. return user, "", err
  44. }
  45. return checkUserAndPubKey(user, pubKey)
  46. }
  // sqlCommonCheckAvailability pings the database through dbHandle, bounding
// the attempt with a 10 second deadline. It returns whatever error
// PingContext reports, or nil when the database answered.
func sqlCommonCheckAvailability(dbHandle *sql.DB) error {
	const pingTimeout = 10 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
	defer cancel()
	return dbHandle.PingContext(ctx)
}
  52. func sqlCommonGetUserByID(ID int64, dbHandle *sql.DB) (User, error) {
  53. var user User
  54. q := getUserByIDQuery()
  55. stmt, err := dbHandle.Prepare(q)
  56. if err != nil {
  57. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  58. return user, err
  59. }
  60. defer stmt.Close()
  61. row := stmt.QueryRow(ID)
  62. return getUserFromDbRow(row, nil)
  63. }
  64. func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
  65. q := getUpdateQuotaQuery(reset)
  66. stmt, err := dbHandle.Prepare(q)
  67. if err != nil {
  68. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  69. return err
  70. }
  71. defer stmt.Close()
  72. _, err = stmt.Exec(sizeAdd, filesAdd, utils.GetTimeAsMsSinceEpoch(time.Now()), username)
  73. if err == nil {
  74. providerLog(logger.LevelDebug, "quota updated for user %#v, files increment: %v size increment: %v is reset? %v",
  75. username, filesAdd, sizeAdd, reset)
  76. } else {
  77. providerLog(logger.LevelWarn, "error updating quota for user %#v: %v", username, err)
  78. }
  79. return err
  80. }
  81. func sqlCommonUpdateLastLogin(username string, dbHandle *sql.DB) error {
  82. q := getUpdateLastLoginQuery()
  83. stmt, err := dbHandle.Prepare(q)
  84. if err != nil {
  85. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  86. return err
  87. }
  88. defer stmt.Close()
  89. _, err = stmt.Exec(utils.GetTimeAsMsSinceEpoch(time.Now()), username)
  90. if err == nil {
  91. providerLog(logger.LevelDebug, "last login updated for user %#v", username)
  92. } else {
  93. providerLog(logger.LevelWarn, "error updating last login for user %#v: %v", username, err)
  94. }
  95. return err
  96. }
  97. func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, error) {
  98. q := getQuotaQuery()
  99. stmt, err := dbHandle.Prepare(q)
  100. if err != nil {
  101. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  102. return 0, 0, err
  103. }
  104. defer stmt.Close()
  105. var usedFiles int
  106. var usedSize int64
  107. err = stmt.QueryRow(username).Scan(&usedSize, &usedFiles)
  108. if err != nil {
  109. providerLog(logger.LevelWarn, "error getting quota for user: %v, error: %v", username, err)
  110. return 0, 0, err
  111. }
  112. return usedFiles, usedSize, err
  113. }
  114. func sqlCommonCheckUserExists(username string, dbHandle *sql.DB) (User, error) {
  115. var user User
  116. q := getUserByUsernameQuery()
  117. stmt, err := dbHandle.Prepare(q)
  118. if err != nil {
  119. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  120. return user, err
  121. }
  122. defer stmt.Close()
  123. row := stmt.QueryRow(username)
  124. return getUserFromDbRow(row, nil)
  125. }
  126. func sqlCommonAddUser(user User, dbHandle *sql.DB) error {
  127. err := validateUser(&user)
  128. if err != nil {
  129. return err
  130. }
  131. q := getAddUserQuery()
  132. stmt, err := dbHandle.Prepare(q)
  133. if err != nil {
  134. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  135. return err
  136. }
  137. defer stmt.Close()
  138. permissions, err := user.GetPermissionsAsJSON()
  139. if err != nil {
  140. return err
  141. }
  142. publicKeys, err := user.GetPublicKeysAsJSON()
  143. if err != nil {
  144. return err
  145. }
  146. filters, err := user.GetFiltersAsJSON()
  147. if err != nil {
  148. return err
  149. }
  150. _, err = stmt.Exec(user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
  151. user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters))
  152. return err
  153. }
  154. func sqlCommonUpdateUser(user User, dbHandle *sql.DB) error {
  155. err := validateUser(&user)
  156. if err != nil {
  157. return err
  158. }
  159. q := getUpdateUserQuery()
  160. stmt, err := dbHandle.Prepare(q)
  161. if err != nil {
  162. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  163. return err
  164. }
  165. defer stmt.Close()
  166. permissions, err := user.GetPermissionsAsJSON()
  167. if err != nil {
  168. return err
  169. }
  170. publicKeys, err := user.GetPublicKeysAsJSON()
  171. if err != nil {
  172. return err
  173. }
  174. filters, err := user.GetFiltersAsJSON()
  175. if err != nil {
  176. return err
  177. }
  178. _, err = stmt.Exec(user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
  179. user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate,
  180. string(filters), user.ID)
  181. return err
  182. }
  183. func sqlCommonDeleteUser(user User, dbHandle *sql.DB) error {
  184. q := getDeleteUserQuery()
  185. stmt, err := dbHandle.Prepare(q)
  186. if err != nil {
  187. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  188. return err
  189. }
  190. defer stmt.Close()
  191. _, err = stmt.Exec(user.ID)
  192. return err
  193. }
  194. func sqlCommonDumpUsers(dbHandle *sql.DB) ([]User, error) {
  195. users := []User{}
  196. q := getDumpUsersQuery()
  197. stmt, err := dbHandle.Prepare(q)
  198. if err != nil {
  199. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  200. return nil, err
  201. }
  202. defer stmt.Close()
  203. rows, err := stmt.Query()
  204. if err == nil {
  205. defer rows.Close()
  206. for rows.Next() {
  207. u, err := getUserFromDbRow(nil, rows)
  208. if err == nil {
  209. users = append(users, u)
  210. } else {
  211. break
  212. }
  213. }
  214. }
  215. return users, err
  216. }
  217. func sqlCommonGetUsers(limit int, offset int, order string, username string, dbHandle *sql.DB) ([]User, error) {
  218. users := []User{}
  219. q := getUsersQuery(order, username)
  220. stmt, err := dbHandle.Prepare(q)
  221. if err != nil {
  222. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  223. return nil, err
  224. }
  225. defer stmt.Close()
  226. var rows *sql.Rows
  227. if len(username) > 0 {
  228. rows, err = stmt.Query(username, limit, offset)
  229. } else {
  230. rows, err = stmt.Query(limit, offset)
  231. }
  232. if err == nil {
  233. defer rows.Close()
  234. for rows.Next() {
  235. u, err := getUserFromDbRow(nil, rows)
  236. // hide password
  237. if err == nil {
  238. u.Password = ""
  239. users = append(users, u)
  240. } else {
  241. break
  242. }
  243. }
  244. }
  245. return users, err
  246. }
  247. func getUserFromDbRow(row *sql.Row, rows *sql.Rows) (User, error) {
  248. var user User
  249. var permissions sql.NullString
  250. var password sql.NullString
  251. var publicKey sql.NullString
  252. var filters sql.NullString
  253. var err error
  254. if row != nil {
  255. err = row.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions,
  256. &user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
  257. &user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters)
  258. } else {
  259. err = rows.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions,
  260. &user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
  261. &user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters)
  262. }
  263. if err != nil {
  264. if err == sql.ErrNoRows {
  265. return user, &RecordNotFoundError{err: err.Error()}
  266. }
  267. return user, err
  268. }
  269. if password.Valid {
  270. user.Password = password.String
  271. }
  272. if publicKey.Valid {
  273. var list []string
  274. err = json.Unmarshal([]byte(publicKey.String), &list)
  275. if err == nil {
  276. user.PublicKeys = list
  277. }
  278. }
  279. if permissions.Valid {
  280. perms := make(map[string][]string)
  281. err = json.Unmarshal([]byte(permissions.String), &perms)
  282. if err == nil {
  283. user.Permissions = perms
  284. } else {
  285. // compatibility layer: until version 0.9.4 permissions were a string list
  286. var list []string
  287. err = json.Unmarshal([]byte(permissions.String), &list)
  288. if err == nil {
  289. perms["/"] = list
  290. user.Permissions = perms
  291. }
  292. }
  293. }
  294. if filters.Valid {
  295. var userFilters UserFilters
  296. err = json.Unmarshal([]byte(filters.String), &userFilters)
  297. if err == nil {
  298. user.Filters = userFilters
  299. }
  300. } else {
  301. user.Filters = UserFilters{
  302. AllowedIP: []string{},
  303. DeniedIP: []string{},
  304. }
  305. }
  306. return user, err
  307. }