sqlcommon.go 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721
  1. package dataprovider
  2. import (
  3. "context"
  4. "crypto/x509"
  5. "database/sql"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "strings"
  10. "time"
  11. "github.com/cockroachdb/cockroach-go/v2/crdb"
  12. "github.com/drakkan/sftpgo/v2/logger"
  13. "github.com/drakkan/sftpgo/v2/sdk"
  14. "github.com/drakkan/sftpgo/v2/util"
  15. "github.com/drakkan/sftpgo/v2/vfs"
  16. )
const (
	// sqlDatabaseVersion is presumably the schema version this code expects
	// (its consumers are not visible here) - TODO confirm.
	sqlDatabaseVersion = 14
	// defaultSQLQueryTimeout bounds the context used by the query helpers in
	// this file.
	defaultSQLQueryTimeout = 10 * time.Second
	// longSQLQueryTimeout is a larger timeout; NOTE(review): its users are not
	// visible in this part of the file.
	longSQLQueryTimeout = 60 * time.Second
)
var (
	// errSQLFoldersAssosaction reports that virtual folders could not be
	// associated to a user.
	errSQLFoldersAssosaction = errors.New("unable to associate virtual folders to user")
	// errSchemaVersionEmpty reports that the schema version cannot be
	// determined because the schema_migration table has no rows.
	errSchemaVersionEmpty = errors.New("we can't determine schema version because the schema_migration table is empty. The SFTPGo database might be corrupted. Consider using the \"resetprovider\" sub-command")
)
// sqlQuerier is the minimal handle needed to prepare a statement. In this
// file it is satisfied by both *sql.DB and *sql.Tx values.
type sqlQuerier interface {
	// PrepareContext creates a prepared statement for query, bound to ctx.
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
// sqlScanner abstracts a value whose current row can be scanned into
// destinations; both *sql.Row and *sql.Rows provide this method.
// NOTE(review): presumably the parameter type of the get*FromDbRow helpers,
// which are not visible here - confirm.
type sqlScanner interface {
	// Scan copies the columns of the current row into dest.
	Scan(dest ...interface{}) error
}
  32. func sqlCommonGetShareByID(shareID, username string, dbHandle sqlQuerier) (Share, error) {
  33. var share Share
  34. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  35. defer cancel()
  36. filterUser := username != ""
  37. q := getShareByIDQuery(filterUser)
  38. stmt, err := dbHandle.PrepareContext(ctx, q)
  39. if err != nil {
  40. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  41. return share, err
  42. }
  43. defer stmt.Close()
  44. var row *sql.Row
  45. if filterUser {
  46. row = stmt.QueryRowContext(ctx, shareID, username)
  47. } else {
  48. row = stmt.QueryRowContext(ctx, shareID)
  49. }
  50. return getShareFromDbRow(row)
  51. }
  52. func sqlCommonAddShare(share *Share, dbHandle *sql.DB) error {
  53. err := share.validate()
  54. if err != nil {
  55. return err
  56. }
  57. user, err := provider.userExists(share.Username)
  58. if err != nil {
  59. return util.NewValidationError(fmt.Sprintf("unable to validate user %#v", share.Username))
  60. }
  61. paths, err := json.Marshal(share.Paths)
  62. if err != nil {
  63. return err
  64. }
  65. allowFrom := ""
  66. if len(share.AllowFrom) > 0 {
  67. res, err := json.Marshal(share.AllowFrom)
  68. if err == nil {
  69. allowFrom = string(res)
  70. }
  71. }
  72. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  73. defer cancel()
  74. q := getAddShareQuery()
  75. stmt, err := dbHandle.PrepareContext(ctx, q)
  76. if err != nil {
  77. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  78. return err
  79. }
  80. defer stmt.Close()
  81. _, err = stmt.ExecContext(ctx, share.ShareID, share.Name, share.Description, share.Scope,
  82. string(paths), util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()),
  83. share.LastUseAt, share.ExpiresAt, share.Password, share.MaxTokens, allowFrom, user.ID)
  84. return err
  85. }
  86. func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error {
  87. err := share.validate()
  88. if err != nil {
  89. return err
  90. }
  91. paths, err := json.Marshal(share.Paths)
  92. if err != nil {
  93. return err
  94. }
  95. allowFrom := ""
  96. if len(share.AllowFrom) > 0 {
  97. res, err := json.Marshal(share.AllowFrom)
  98. if err == nil {
  99. allowFrom = string(res)
  100. }
  101. }
  102. user, err := provider.userExists(share.Username)
  103. if err != nil {
  104. return util.NewValidationError(fmt.Sprintf("unable to validate user %#v", share.Username))
  105. }
  106. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  107. defer cancel()
  108. q := getUpdateShareQuery()
  109. stmt, err := dbHandle.PrepareContext(ctx, q)
  110. if err != nil {
  111. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  112. return err
  113. }
  114. defer stmt.Close()
  115. _, err = stmt.ExecContext(ctx, share.Name, share.Description, share.Scope, string(paths),
  116. util.GetTimeAsMsSinceEpoch(time.Now()), share.ExpiresAt, share.Password, share.MaxTokens,
  117. allowFrom, user.ID, share.ShareID)
  118. return err
  119. }
  120. func sqlCommonDeleteShare(share *Share, dbHandle *sql.DB) error {
  121. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  122. defer cancel()
  123. q := getDeleteShareQuery()
  124. stmt, err := dbHandle.PrepareContext(ctx, q)
  125. if err != nil {
  126. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  127. return err
  128. }
  129. defer stmt.Close()
  130. _, err = stmt.ExecContext(ctx, share.ShareID)
  131. return err
  132. }
  133. func sqlCommonGetShares(limit, offset int, order, username string, dbHandle sqlQuerier) ([]Share, error) {
  134. shares := make([]Share, 0, limit)
  135. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  136. defer cancel()
  137. q := getSharesQuery(order)
  138. stmt, err := dbHandle.PrepareContext(ctx, q)
  139. if err != nil {
  140. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  141. return nil, err
  142. }
  143. defer stmt.Close()
  144. rows, err := stmt.QueryContext(ctx, username, limit, offset)
  145. if err != nil {
  146. return shares, err
  147. }
  148. defer rows.Close()
  149. for rows.Next() {
  150. s, err := getShareFromDbRow(rows)
  151. if err != nil {
  152. return shares, err
  153. }
  154. s.HideConfidentialData()
  155. shares = append(shares, s)
  156. }
  157. return shares, rows.Err()
  158. }
  159. func sqlCommonDumpShares(dbHandle sqlQuerier) ([]Share, error) {
  160. shares := make([]Share, 0, 30)
  161. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  162. defer cancel()
  163. q := getDumpSharesQuery()
  164. stmt, err := dbHandle.PrepareContext(ctx, q)
  165. if err != nil {
  166. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  167. return nil, err
  168. }
  169. defer stmt.Close()
  170. rows, err := stmt.QueryContext(ctx)
  171. if err != nil {
  172. return shares, err
  173. }
  174. defer rows.Close()
  175. for rows.Next() {
  176. s, err := getShareFromDbRow(rows)
  177. if err != nil {
  178. return shares, err
  179. }
  180. shares = append(shares, s)
  181. }
  182. return shares, rows.Err()
  183. }
  184. func sqlCommonGetAPIKeyByID(keyID string, dbHandle sqlQuerier) (APIKey, error) {
  185. var apiKey APIKey
  186. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  187. defer cancel()
  188. q := getAPIKeyByIDQuery()
  189. stmt, err := dbHandle.PrepareContext(ctx, q)
  190. if err != nil {
  191. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  192. return apiKey, err
  193. }
  194. defer stmt.Close()
  195. row := stmt.QueryRowContext(ctx, keyID)
  196. apiKey, err = getAPIKeyFromDbRow(row)
  197. if err != nil {
  198. return apiKey, err
  199. }
  200. return getAPIKeyWithRelatedFields(ctx, apiKey, dbHandle)
  201. }
  202. func sqlCommonAddAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
  203. err := apiKey.validate()
  204. if err != nil {
  205. return err
  206. }
  207. userID, adminID, err := sqlCommonGetAPIKeyRelatedIDs(apiKey)
  208. if err != nil {
  209. return err
  210. }
  211. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  212. defer cancel()
  213. q := getAddAPIKeyQuery()
  214. stmt, err := dbHandle.PrepareContext(ctx, q)
  215. if err != nil {
  216. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  217. return err
  218. }
  219. defer stmt.Close()
  220. _, err = stmt.ExecContext(ctx, apiKey.KeyID, apiKey.Name, apiKey.Key, apiKey.Scope, util.GetTimeAsMsSinceEpoch(time.Now()),
  221. util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.LastUseAt, apiKey.ExpiresAt, apiKey.Description,
  222. userID, adminID)
  223. return err
  224. }
  225. func sqlCommonUpdateAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
  226. err := apiKey.validate()
  227. if err != nil {
  228. return err
  229. }
  230. userID, adminID, err := sqlCommonGetAPIKeyRelatedIDs(apiKey)
  231. if err != nil {
  232. return err
  233. }
  234. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  235. defer cancel()
  236. q := getUpdateAPIKeyQuery()
  237. stmt, err := dbHandle.PrepareContext(ctx, q)
  238. if err != nil {
  239. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  240. return err
  241. }
  242. defer stmt.Close()
  243. _, err = stmt.ExecContext(ctx, apiKey.Name, apiKey.Scope, apiKey.ExpiresAt, userID, adminID,
  244. apiKey.Description, util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.KeyID)
  245. return err
  246. }
  247. func sqlCommonDeleteAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
  248. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  249. defer cancel()
  250. q := getDeleteAPIKeyQuery()
  251. stmt, err := dbHandle.PrepareContext(ctx, q)
  252. if err != nil {
  253. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  254. return err
  255. }
  256. defer stmt.Close()
  257. _, err = stmt.ExecContext(ctx, apiKey.KeyID)
  258. return err
  259. }
  260. func sqlCommonGetAPIKeys(limit, offset int, order string, dbHandle sqlQuerier) ([]APIKey, error) {
  261. apiKeys := make([]APIKey, 0, limit)
  262. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  263. defer cancel()
  264. q := getAPIKeysQuery(order)
  265. stmt, err := dbHandle.PrepareContext(ctx, q)
  266. if err != nil {
  267. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  268. return nil, err
  269. }
  270. defer stmt.Close()
  271. rows, err := stmt.QueryContext(ctx, limit, offset)
  272. if err != nil {
  273. return apiKeys, err
  274. }
  275. defer rows.Close()
  276. for rows.Next() {
  277. k, err := getAPIKeyFromDbRow(rows)
  278. if err != nil {
  279. return apiKeys, err
  280. }
  281. k.HideConfidentialData()
  282. apiKeys = append(apiKeys, k)
  283. }
  284. err = rows.Err()
  285. if err != nil {
  286. return apiKeys, err
  287. }
  288. apiKeys, err = getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeAdmin)
  289. if err != nil {
  290. return apiKeys, err
  291. }
  292. return getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeUser)
  293. }
  294. func sqlCommonDumpAPIKeys(dbHandle sqlQuerier) ([]APIKey, error) {
  295. apiKeys := make([]APIKey, 0, 30)
  296. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  297. defer cancel()
  298. q := getDumpAPIKeysQuery()
  299. stmt, err := dbHandle.PrepareContext(ctx, q)
  300. if err != nil {
  301. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  302. return nil, err
  303. }
  304. defer stmt.Close()
  305. rows, err := stmt.QueryContext(ctx)
  306. if err != nil {
  307. return apiKeys, err
  308. }
  309. defer rows.Close()
  310. for rows.Next() {
  311. k, err := getAPIKeyFromDbRow(rows)
  312. if err != nil {
  313. return apiKeys, err
  314. }
  315. apiKeys = append(apiKeys, k)
  316. }
  317. err = rows.Err()
  318. if err != nil {
  319. return apiKeys, err
  320. }
  321. apiKeys, err = getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeAdmin)
  322. if err != nil {
  323. return apiKeys, err
  324. }
  325. return getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeUser)
  326. }
  327. func sqlCommonGetAdminByUsername(username string, dbHandle sqlQuerier) (Admin, error) {
  328. var admin Admin
  329. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  330. defer cancel()
  331. q := getAdminByUsernameQuery()
  332. stmt, err := dbHandle.PrepareContext(ctx, q)
  333. if err != nil {
  334. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  335. return admin, err
  336. }
  337. defer stmt.Close()
  338. row := stmt.QueryRowContext(ctx, username)
  339. return getAdminFromDbRow(row)
  340. }
  341. func sqlCommonValidateAdminAndPass(username, password, ip string, dbHandle *sql.DB) (Admin, error) {
  342. admin, err := sqlCommonGetAdminByUsername(username, dbHandle)
  343. if err != nil {
  344. providerLog(logger.LevelWarn, "error authenticating admin %#v: %v", username, err)
  345. return admin, ErrInvalidCredentials
  346. }
  347. err = admin.checkUserAndPass(password, ip)
  348. return admin, err
  349. }
  350. func sqlCommonAddAdmin(admin *Admin, dbHandle *sql.DB) error {
  351. err := admin.validate()
  352. if err != nil {
  353. return err
  354. }
  355. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  356. defer cancel()
  357. q := getAddAdminQuery()
  358. stmt, err := dbHandle.PrepareContext(ctx, q)
  359. if err != nil {
  360. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  361. return err
  362. }
  363. defer stmt.Close()
  364. perms, err := json.Marshal(admin.Permissions)
  365. if err != nil {
  366. return err
  367. }
  368. filters, err := json.Marshal(admin.Filters)
  369. if err != nil {
  370. return err
  371. }
  372. _, err = stmt.ExecContext(ctx, admin.Username, admin.Password, admin.Status, admin.Email, string(perms),
  373. string(filters), admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
  374. util.GetTimeAsMsSinceEpoch(time.Now()))
  375. return err
  376. }
  377. func sqlCommonUpdateAdmin(admin *Admin, dbHandle *sql.DB) error {
  378. err := admin.validate()
  379. if err != nil {
  380. return err
  381. }
  382. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  383. defer cancel()
  384. q := getUpdateAdminQuery()
  385. stmt, err := dbHandle.PrepareContext(ctx, q)
  386. if err != nil {
  387. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  388. return err
  389. }
  390. defer stmt.Close()
  391. perms, err := json.Marshal(admin.Permissions)
  392. if err != nil {
  393. return err
  394. }
  395. filters, err := json.Marshal(admin.Filters)
  396. if err != nil {
  397. return err
  398. }
  399. _, err = stmt.ExecContext(ctx, admin.Password, admin.Status, admin.Email, string(perms), string(filters),
  400. admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()), admin.Username)
  401. return err
  402. }
  403. func sqlCommonDeleteAdmin(admin *Admin, dbHandle *sql.DB) error {
  404. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  405. defer cancel()
  406. q := getDeleteAdminQuery()
  407. stmt, err := dbHandle.PrepareContext(ctx, q)
  408. if err != nil {
  409. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  410. return err
  411. }
  412. defer stmt.Close()
  413. _, err = stmt.ExecContext(ctx, admin.Username)
  414. return err
  415. }
  416. func sqlCommonGetAdmins(limit, offset int, order string, dbHandle sqlQuerier) ([]Admin, error) {
  417. admins := make([]Admin, 0, limit)
  418. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  419. defer cancel()
  420. q := getAdminsQuery(order)
  421. stmt, err := dbHandle.PrepareContext(ctx, q)
  422. if err != nil {
  423. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  424. return nil, err
  425. }
  426. defer stmt.Close()
  427. rows, err := stmt.QueryContext(ctx, limit, offset)
  428. if err != nil {
  429. return admins, err
  430. }
  431. defer rows.Close()
  432. for rows.Next() {
  433. a, err := getAdminFromDbRow(rows)
  434. if err != nil {
  435. return admins, err
  436. }
  437. a.HideConfidentialData()
  438. admins = append(admins, a)
  439. }
  440. return admins, rows.Err()
  441. }
  442. func sqlCommonDumpAdmins(dbHandle sqlQuerier) ([]Admin, error) {
  443. admins := make([]Admin, 0, 30)
  444. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  445. defer cancel()
  446. q := getDumpAdminsQuery()
  447. stmt, err := dbHandle.PrepareContext(ctx, q)
  448. if err != nil {
  449. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  450. return nil, err
  451. }
  452. defer stmt.Close()
  453. rows, err := stmt.QueryContext(ctx)
  454. if err != nil {
  455. return admins, err
  456. }
  457. defer rows.Close()
  458. for rows.Next() {
  459. a, err := getAdminFromDbRow(rows)
  460. if err != nil {
  461. return admins, err
  462. }
  463. admins = append(admins, a)
  464. }
  465. return admins, rows.Err()
  466. }
  467. func sqlCommonGetUserByUsername(username string, dbHandle sqlQuerier) (User, error) {
  468. var user User
  469. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  470. defer cancel()
  471. q := getUserByUsernameQuery()
  472. stmt, err := dbHandle.PrepareContext(ctx, q)
  473. if err != nil {
  474. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  475. return user, err
  476. }
  477. defer stmt.Close()
  478. row := stmt.QueryRowContext(ctx, username)
  479. user, err = getUserFromDbRow(row)
  480. if err != nil {
  481. return user, err
  482. }
  483. return getUserWithVirtualFolders(ctx, user, dbHandle)
  484. }
  485. func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHandle *sql.DB) (User, error) {
  486. var user User
  487. if password == "" {
  488. return user, errors.New("credentials cannot be null or empty")
  489. }
  490. user, err := sqlCommonGetUserByUsername(username, dbHandle)
  491. if err != nil {
  492. providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
  493. return user, err
  494. }
  495. return checkUserAndPass(&user, password, ip, protocol)
  496. }
  497. func sqlCommonValidateUserAndTLSCertificate(username, protocol string, tlsCert *x509.Certificate, dbHandle *sql.DB) (User, error) {
  498. var user User
  499. if tlsCert == nil {
  500. return user, errors.New("TLS certificate cannot be null or empty")
  501. }
  502. user, err := sqlCommonGetUserByUsername(username, dbHandle)
  503. if err != nil {
  504. providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
  505. return user, err
  506. }
  507. return checkUserAndTLSCertificate(&user, protocol, tlsCert)
  508. }
  509. func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, dbHandle *sql.DB) (User, string, error) {
  510. var user User
  511. if len(pubKey) == 0 {
  512. return user, "", errors.New("credentials cannot be null or empty")
  513. }
  514. user, err := sqlCommonGetUserByUsername(username, dbHandle)
  515. if err != nil {
  516. providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
  517. return user, "", err
  518. }
  519. return checkUserAndPubKey(&user, pubKey)
  520. }
  521. func sqlCommonCheckAvailability(dbHandle *sql.DB) error {
  522. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  523. defer cancel()
  524. return dbHandle.PingContext(ctx)
  525. }
  526. func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
  527. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  528. defer cancel()
  529. q := getUpdateQuotaQuery(reset)
  530. stmt, err := dbHandle.PrepareContext(ctx, q)
  531. if err != nil {
  532. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  533. return err
  534. }
  535. defer stmt.Close()
  536. _, err = stmt.ExecContext(ctx, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  537. if err == nil {
  538. providerLog(logger.LevelDebug, "quota updated for user %#v, files increment: %v size increment: %v is reset? %v",
  539. username, filesAdd, sizeAdd, reset)
  540. } else {
  541. providerLog(logger.LevelWarn, "error updating quota for user %#v: %v", username, err)
  542. }
  543. return err
  544. }
  545. func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, error) {
  546. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  547. defer cancel()
  548. q := getQuotaQuery()
  549. stmt, err := dbHandle.PrepareContext(ctx, q)
  550. if err != nil {
  551. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  552. return 0, 0, err
  553. }
  554. defer stmt.Close()
  555. var usedFiles int
  556. var usedSize int64
  557. err = stmt.QueryRowContext(ctx, username).Scan(&usedSize, &usedFiles)
  558. if err != nil {
  559. providerLog(logger.LevelWarn, "error getting quota for user: %v, error: %v", username, err)
  560. return 0, 0, err
  561. }
  562. return usedFiles, usedSize, err
  563. }
  564. func sqlCommonUpdateShareLastUse(shareID string, numTokens int, dbHandle *sql.DB) error {
  565. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  566. defer cancel()
  567. q := getUpdateShareLastUseQuery()
  568. stmt, err := dbHandle.PrepareContext(ctx, q)
  569. if err != nil {
  570. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  571. return err
  572. }
  573. defer stmt.Close()
  574. _, err = stmt.ExecContext(ctx, util.GetTimeAsMsSinceEpoch(time.Now()), numTokens, shareID)
  575. if err == nil {
  576. providerLog(logger.LevelDebug, "last use updated for shared object %#v", shareID)
  577. } else {
  578. providerLog(logger.LevelWarn, "error updating last use for shared object %#v: %v", shareID, err)
  579. }
  580. return err
  581. }
  582. func sqlCommonUpdateAPIKeyLastUse(keyID string, dbHandle *sql.DB) error {
  583. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  584. defer cancel()
  585. q := getUpdateAPIKeyLastUseQuery()
  586. stmt, err := dbHandle.PrepareContext(ctx, q)
  587. if err != nil {
  588. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  589. return err
  590. }
  591. defer stmt.Close()
  592. _, err = stmt.ExecContext(ctx, util.GetTimeAsMsSinceEpoch(time.Now()), keyID)
  593. if err == nil {
  594. providerLog(logger.LevelDebug, "last use updated for key %#v", keyID)
  595. } else {
  596. providerLog(logger.LevelWarn, "error updating last use for key %#v: %v", keyID, err)
  597. }
  598. return err
  599. }
  600. func sqlCommonUpdateAdminLastLogin(username string, dbHandle *sql.DB) error {
  601. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  602. defer cancel()
  603. q := getUpdateAdminLastLoginQuery()
  604. stmt, err := dbHandle.PrepareContext(ctx, q)
  605. if err != nil {
  606. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  607. return err
  608. }
  609. defer stmt.Close()
  610. _, err = stmt.ExecContext(ctx, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  611. if err == nil {
  612. providerLog(logger.LevelDebug, "last login updated for admin %#v", username)
  613. } else {
  614. providerLog(logger.LevelWarn, "error updating last login for admin %#v: %v", username, err)
  615. }
  616. return err
  617. }
  618. func sqlCommonSetUpdatedAt(username string, dbHandle *sql.DB) {
  619. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  620. defer cancel()
  621. q := getSetUpdateAtQuery()
  622. stmt, err := dbHandle.PrepareContext(ctx, q)
  623. if err != nil {
  624. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  625. return
  626. }
  627. defer stmt.Close()
  628. _, err = stmt.ExecContext(ctx, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  629. if err == nil {
  630. providerLog(logger.LevelDebug, "updated_at set for user %#v", username)
  631. } else {
  632. providerLog(logger.LevelWarn, "error setting updated_at for user %#v: %v", username, err)
  633. }
  634. }
  635. func sqlCommonUpdateLastLogin(username string, dbHandle *sql.DB) error {
  636. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  637. defer cancel()
  638. q := getUpdateLastLoginQuery()
  639. stmt, err := dbHandle.PrepareContext(ctx, q)
  640. if err != nil {
  641. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  642. return err
  643. }
  644. defer stmt.Close()
  645. _, err = stmt.ExecContext(ctx, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  646. if err == nil {
  647. providerLog(logger.LevelDebug, "last login updated for user %#v", username)
  648. } else {
  649. providerLog(logger.LevelWarn, "error updating last login for user %#v: %v", username, err)
  650. }
  651. return err
  652. }
  653. func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
  654. err := ValidateUser(user)
  655. if err != nil {
  656. return err
  657. }
  658. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  659. defer cancel()
  660. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  661. q := getAddUserQuery()
  662. stmt, err := tx.PrepareContext(ctx, q)
  663. if err != nil {
  664. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  665. return err
  666. }
  667. defer stmt.Close()
  668. permissions, err := user.GetPermissionsAsJSON()
  669. if err != nil {
  670. return err
  671. }
  672. publicKeys, err := user.GetPublicKeysAsJSON()
  673. if err != nil {
  674. return err
  675. }
  676. filters, err := user.GetFiltersAsJSON()
  677. if err != nil {
  678. return err
  679. }
  680. fsConfig, err := user.GetFsConfigAsJSON()
  681. if err != nil {
  682. return err
  683. }
  684. _, err = stmt.ExecContext(ctx, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
  685. user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters),
  686. string(fsConfig), user.AdditionalInfo, user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()),
  687. util.GetTimeAsMsSinceEpoch(time.Now()))
  688. if err != nil {
  689. return err
  690. }
  691. return generateVirtualFoldersMapping(ctx, user, tx)
  692. })
  693. }
  694. func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
  695. err := ValidateUser(user)
  696. if err != nil {
  697. return err
  698. }
  699. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  700. defer cancel()
  701. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  702. q := getUpdateUserQuery()
  703. stmt, err := tx.PrepareContext(ctx, q)
  704. if err != nil {
  705. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  706. return err
  707. }
  708. defer stmt.Close()
  709. permissions, err := user.GetPermissionsAsJSON()
  710. if err != nil {
  711. return err
  712. }
  713. publicKeys, err := user.GetPublicKeysAsJSON()
  714. if err != nil {
  715. return err
  716. }
  717. filters, err := user.GetFiltersAsJSON()
  718. if err != nil {
  719. return err
  720. }
  721. fsConfig, err := user.GetFsConfigAsJSON()
  722. if err != nil {
  723. return err
  724. }
  725. _, err = stmt.ExecContext(ctx, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
  726. user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate,
  727. string(filters), string(fsConfig), user.AdditionalInfo, user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()),
  728. user.ID)
  729. if err != nil {
  730. return err
  731. }
  732. return generateVirtualFoldersMapping(ctx, user, tx)
  733. })
  734. }
  735. func sqlCommonDeleteUser(user *User, dbHandle *sql.DB) error {
  736. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  737. defer cancel()
  738. q := getDeleteUserQuery()
  739. stmt, err := dbHandle.PrepareContext(ctx, q)
  740. if err != nil {
  741. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  742. return err
  743. }
  744. defer stmt.Close()
  745. _, err = stmt.ExecContext(ctx, user.ID)
  746. return err
  747. }
  748. func sqlCommonDumpUsers(dbHandle sqlQuerier) ([]User, error) {
  749. users := make([]User, 0, 100)
  750. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  751. defer cancel()
  752. q := getDumpUsersQuery()
  753. stmt, err := dbHandle.PrepareContext(ctx, q)
  754. if err != nil {
  755. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  756. return nil, err
  757. }
  758. defer stmt.Close()
  759. rows, err := stmt.QueryContext(ctx)
  760. if err != nil {
  761. return users, err
  762. }
  763. defer rows.Close()
  764. for rows.Next() {
  765. u, err := getUserFromDbRow(rows)
  766. if err != nil {
  767. return users, err
  768. }
  769. err = addCredentialsToUser(&u)
  770. if err != nil {
  771. return users, err
  772. }
  773. users = append(users, u)
  774. }
  775. err = rows.Err()
  776. if err != nil {
  777. return users, err
  778. }
  779. return getUsersWithVirtualFolders(ctx, users, dbHandle)
  780. }
  781. func sqlCommonGetRecentlyUpdatedUsers(after int64, dbHandle sqlQuerier) ([]User, error) {
  782. users := make([]User, 0, 10)
  783. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  784. defer cancel()
  785. q := getRecentlyUpdatedUsersQuery()
  786. stmt, err := dbHandle.PrepareContext(ctx, q)
  787. if err != nil {
  788. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  789. return nil, err
  790. }
  791. defer stmt.Close()
  792. rows, err := stmt.QueryContext(ctx, after)
  793. if err == nil {
  794. defer rows.Close()
  795. for rows.Next() {
  796. u, err := getUserFromDbRow(rows)
  797. if err != nil {
  798. return users, err
  799. }
  800. users = append(users, u)
  801. }
  802. }
  803. err = rows.Err()
  804. if err != nil {
  805. return users, err
  806. }
  807. return getUsersWithVirtualFolders(ctx, users, dbHandle)
  808. }
  809. func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier) ([]User, error) {
  810. users := make([]User, 0, limit)
  811. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  812. defer cancel()
  813. q := getUsersQuery(order)
  814. stmt, err := dbHandle.PrepareContext(ctx, q)
  815. if err != nil {
  816. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  817. return nil, err
  818. }
  819. defer stmt.Close()
  820. rows, err := stmt.QueryContext(ctx, limit, offset)
  821. if err == nil {
  822. defer rows.Close()
  823. for rows.Next() {
  824. u, err := getUserFromDbRow(rows)
  825. if err != nil {
  826. return users, err
  827. }
  828. u.PrepareForRendering()
  829. users = append(users, u)
  830. }
  831. }
  832. err = rows.Err()
  833. if err != nil {
  834. return users, err
  835. }
  836. return getUsersWithVirtualFolders(ctx, users, dbHandle)
  837. }
  838. func getShareFromDbRow(row sqlScanner) (Share, error) {
  839. var share Share
  840. var description, password, allowFrom, paths sql.NullString
  841. err := row.Scan(&share.ShareID, &share.Name, &description, &share.Scope,
  842. &paths, &share.Username, &share.CreatedAt, &share.UpdatedAt,
  843. &share.LastUseAt, &share.ExpiresAt, &password, &share.MaxTokens,
  844. &share.UsedTokens, &allowFrom)
  845. if err != nil {
  846. if errors.Is(err, sql.ErrNoRows) {
  847. return share, util.NewRecordNotFoundError(err.Error())
  848. }
  849. return share, err
  850. }
  851. if paths.Valid {
  852. var list []string
  853. err = json.Unmarshal([]byte(paths.String), &list)
  854. if err != nil {
  855. return share, err
  856. }
  857. share.Paths = list
  858. } else {
  859. return share, errors.New("unable to decode shared paths")
  860. }
  861. if description.Valid {
  862. share.Description = description.String
  863. }
  864. if password.Valid {
  865. share.Password = password.String
  866. }
  867. if allowFrom.Valid {
  868. var list []string
  869. err = json.Unmarshal([]byte(allowFrom.String), &list)
  870. if err == nil {
  871. share.AllowFrom = list
  872. }
  873. }
  874. return share, nil
  875. }
  876. func getAPIKeyFromDbRow(row sqlScanner) (APIKey, error) {
  877. var apiKey APIKey
  878. var userID, adminID sql.NullInt64
  879. var description sql.NullString
  880. err := row.Scan(&apiKey.KeyID, &apiKey.Name, &apiKey.Key, &apiKey.Scope, &apiKey.CreatedAt, &apiKey.UpdatedAt,
  881. &apiKey.LastUseAt, &apiKey.ExpiresAt, &description, &userID, &adminID)
  882. if err != nil {
  883. if errors.Is(err, sql.ErrNoRows) {
  884. return apiKey, util.NewRecordNotFoundError(err.Error())
  885. }
  886. return apiKey, err
  887. }
  888. if userID.Valid {
  889. apiKey.userID = userID.Int64
  890. }
  891. if adminID.Valid {
  892. apiKey.adminID = adminID.Int64
  893. }
  894. if description.Valid {
  895. apiKey.Description = description.String
  896. }
  897. return apiKey, nil
  898. }
  899. func getAdminFromDbRow(row sqlScanner) (Admin, error) {
  900. var admin Admin
  901. var email, filters, additionalInfo, permissions, description sql.NullString
  902. err := row.Scan(&admin.ID, &admin.Username, &admin.Password, &admin.Status, &email, &permissions,
  903. &filters, &additionalInfo, &description, &admin.CreatedAt, &admin.UpdatedAt, &admin.LastLogin)
  904. if err != nil {
  905. if errors.Is(err, sql.ErrNoRows) {
  906. return admin, util.NewRecordNotFoundError(err.Error())
  907. }
  908. return admin, err
  909. }
  910. if permissions.Valid {
  911. var perms []string
  912. err = json.Unmarshal([]byte(permissions.String), &perms)
  913. if err != nil {
  914. return admin, err
  915. }
  916. admin.Permissions = perms
  917. }
  918. if email.Valid {
  919. admin.Email = email.String
  920. }
  921. if filters.Valid {
  922. var adminFilters AdminFilters
  923. err = json.Unmarshal([]byte(filters.String), &adminFilters)
  924. if err == nil {
  925. admin.Filters = adminFilters
  926. }
  927. }
  928. if additionalInfo.Valid {
  929. admin.AdditionalInfo = additionalInfo.String
  930. }
  931. if description.Valid {
  932. admin.Description = description.String
  933. }
  934. admin.SetEmptySecretsIfNil()
  935. return admin, nil
  936. }
  937. func getUserFromDbRow(row sqlScanner) (User, error) {
  938. var user User
  939. var permissions sql.NullString
  940. var password sql.NullString
  941. var publicKey sql.NullString
  942. var filters sql.NullString
  943. var fsConfig sql.NullString
  944. var additionalInfo, description, email sql.NullString
  945. err := row.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions,
  946. &user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
  947. &user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig,
  948. &additionalInfo, &description, &email, &user.CreatedAt, &user.UpdatedAt)
  949. if err != nil {
  950. if errors.Is(err, sql.ErrNoRows) {
  951. return user, util.NewRecordNotFoundError(err.Error())
  952. }
  953. return user, err
  954. }
  955. if password.Valid {
  956. user.Password = password.String
  957. }
  958. // we can have a empty string or an invalid json in null string
  959. // so we do a relaxed test if the field is optional, for example we
  960. // populate public keys only if unmarshal does not return an error
  961. if publicKey.Valid {
  962. var list []string
  963. err = json.Unmarshal([]byte(publicKey.String), &list)
  964. if err == nil {
  965. user.PublicKeys = list
  966. }
  967. }
  968. if permissions.Valid {
  969. perms := make(map[string][]string)
  970. err = json.Unmarshal([]byte(permissions.String), &perms)
  971. if err != nil {
  972. providerLog(logger.LevelWarn, "unable to deserialize permissions for user %#v: %v", user.Username, err)
  973. return user, fmt.Errorf("unable to deserialize permissions for user %#v: %v", user.Username, err)
  974. }
  975. user.Permissions = perms
  976. }
  977. if filters.Valid {
  978. var userFilters sdk.UserFilters
  979. err = json.Unmarshal([]byte(filters.String), &userFilters)
  980. if err == nil {
  981. user.Filters = userFilters
  982. }
  983. }
  984. if fsConfig.Valid {
  985. var fs vfs.Filesystem
  986. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  987. if err == nil {
  988. user.FsConfig = fs
  989. }
  990. }
  991. if additionalInfo.Valid {
  992. user.AdditionalInfo = additionalInfo.String
  993. }
  994. if description.Valid {
  995. user.Description = description.String
  996. }
  997. if email.Valid {
  998. user.Email = email.String
  999. }
  1000. user.SetEmptySecretsIfNil()
  1001. return user, nil
  1002. }
  1003. func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQuerier) error {
  1004. var folderName string
  1005. q := checkFolderNameQuery()
  1006. stmt, err := dbHandle.PrepareContext(ctx, q)
  1007. if err != nil {
  1008. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1009. return err
  1010. }
  1011. defer stmt.Close()
  1012. row := stmt.QueryRowContext(ctx, name)
  1013. return row.Scan(&folderName)
  1014. }
  1015. func sqlCommonGetFolder(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
  1016. var folder vfs.BaseVirtualFolder
  1017. q := getFolderByNameQuery()
  1018. stmt, err := dbHandle.PrepareContext(ctx, q)
  1019. if err != nil {
  1020. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1021. return folder, err
  1022. }
  1023. defer stmt.Close()
  1024. row := stmt.QueryRowContext(ctx, name)
  1025. var mappedPath, description, fsConfig sql.NullString
  1026. err = row.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, &folder.LastQuotaUpdate,
  1027. &folder.Name, &description, &fsConfig)
  1028. if err != nil {
  1029. if errors.Is(err, sql.ErrNoRows) {
  1030. return folder, util.NewRecordNotFoundError(err.Error())
  1031. }
  1032. return folder, err
  1033. }
  1034. if mappedPath.Valid {
  1035. folder.MappedPath = mappedPath.String
  1036. }
  1037. if description.Valid {
  1038. folder.Description = description.String
  1039. }
  1040. if fsConfig.Valid {
  1041. var fs vfs.Filesystem
  1042. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  1043. if err == nil {
  1044. folder.FsConfig = fs
  1045. }
  1046. }
  1047. return folder, err
  1048. }
  1049. func sqlCommonGetFolderByName(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
  1050. folder, err := sqlCommonGetFolder(ctx, name, dbHandle)
  1051. if err != nil {
  1052. return folder, err
  1053. }
  1054. folders, err := getVirtualFoldersWithUsers([]vfs.BaseVirtualFolder{folder}, dbHandle)
  1055. if err != nil {
  1056. return folder, err
  1057. }
  1058. if len(folders) != 1 {
  1059. return folder, fmt.Errorf("unable to associate users with folder %#v", name)
  1060. }
  1061. return folders[0], nil
  1062. }
  1063. func sqlCommonAddOrUpdateFolder(ctx context.Context, baseFolder *vfs.BaseVirtualFolder, usedQuotaSize int64,
  1064. usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
  1065. var folder vfs.BaseVirtualFolder
  1066. // FIXME: we could use an UPSERT here, this SELECT could be racy
  1067. err := sqlCommonCheckFolderExists(ctx, baseFolder.Name, dbHandle)
  1068. switch err {
  1069. case nil:
  1070. err = sqlCommonUpdateFolder(baseFolder, dbHandle)
  1071. if err != nil {
  1072. return folder, err
  1073. }
  1074. case sql.ErrNoRows:
  1075. baseFolder.UsedQuotaFiles = usedQuotaFiles
  1076. baseFolder.UsedQuotaSize = usedQuotaSize
  1077. baseFolder.LastQuotaUpdate = lastQuotaUpdate
  1078. err = sqlCommonAddFolder(baseFolder, dbHandle)
  1079. if err != nil {
  1080. return folder, err
  1081. }
  1082. default:
  1083. return folder, err
  1084. }
  1085. return sqlCommonGetFolder(ctx, baseFolder.Name, dbHandle)
  1086. }
  1087. func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
  1088. err := ValidateFolder(folder)
  1089. if err != nil {
  1090. return err
  1091. }
  1092. fsConfig, err := json.Marshal(folder.FsConfig)
  1093. if err != nil {
  1094. return err
  1095. }
  1096. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1097. defer cancel()
  1098. q := getAddFolderQuery()
  1099. stmt, err := dbHandle.PrepareContext(ctx, q)
  1100. if err != nil {
  1101. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1102. return err
  1103. }
  1104. defer stmt.Close()
  1105. _, err = stmt.ExecContext(ctx, folder.MappedPath, folder.UsedQuotaSize, folder.UsedQuotaFiles,
  1106. folder.LastQuotaUpdate, folder.Name, folder.Description, string(fsConfig))
  1107. return err
  1108. }
  1109. func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
  1110. err := ValidateFolder(folder)
  1111. if err != nil {
  1112. return err
  1113. }
  1114. fsConfig, err := json.Marshal(folder.FsConfig)
  1115. if err != nil {
  1116. return err
  1117. }
  1118. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1119. defer cancel()
  1120. q := getUpdateFolderQuery()
  1121. stmt, err := dbHandle.PrepareContext(ctx, q)
  1122. if err != nil {
  1123. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1124. return err
  1125. }
  1126. defer stmt.Close()
  1127. _, err = stmt.ExecContext(ctx, folder.MappedPath, folder.Description, string(fsConfig), folder.Name)
  1128. return err
  1129. }
  1130. func sqlCommonDeleteFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
  1131. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1132. defer cancel()
  1133. q := getDeleteFolderQuery()
  1134. stmt, err := dbHandle.PrepareContext(ctx, q)
  1135. if err != nil {
  1136. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1137. return err
  1138. }
  1139. defer stmt.Close()
  1140. _, err = stmt.ExecContext(ctx, folder.ID)
  1141. return err
  1142. }
  1143. func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  1144. folders := make([]vfs.BaseVirtualFolder, 0, 50)
  1145. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  1146. defer cancel()
  1147. q := getDumpFoldersQuery()
  1148. stmt, err := dbHandle.PrepareContext(ctx, q)
  1149. if err != nil {
  1150. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1151. return nil, err
  1152. }
  1153. defer stmt.Close()
  1154. rows, err := stmt.QueryContext(ctx)
  1155. if err != nil {
  1156. return folders, err
  1157. }
  1158. defer rows.Close()
  1159. for rows.Next() {
  1160. var folder vfs.BaseVirtualFolder
  1161. var mappedPath, description, fsConfig sql.NullString
  1162. err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
  1163. &folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
  1164. if err != nil {
  1165. return folders, err
  1166. }
  1167. if mappedPath.Valid {
  1168. folder.MappedPath = mappedPath.String
  1169. }
  1170. if description.Valid {
  1171. folder.Description = description.String
  1172. }
  1173. if fsConfig.Valid {
  1174. var fs vfs.Filesystem
  1175. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  1176. if err == nil {
  1177. folder.FsConfig = fs
  1178. }
  1179. }
  1180. folders = append(folders, folder)
  1181. }
  1182. err = rows.Err()
  1183. if err != nil {
  1184. return folders, err
  1185. }
  1186. return getVirtualFoldersWithUsers(folders, dbHandle)
  1187. }
  1188. func sqlCommonGetFolders(limit, offset int, order string, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  1189. folders := make([]vfs.BaseVirtualFolder, 0, limit)
  1190. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1191. defer cancel()
  1192. q := getFoldersQuery(order)
  1193. stmt, err := dbHandle.PrepareContext(ctx, q)
  1194. if err != nil {
  1195. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1196. return nil, err
  1197. }
  1198. defer stmt.Close()
  1199. rows, err := stmt.QueryContext(ctx, limit, offset)
  1200. if err != nil {
  1201. return folders, err
  1202. }
  1203. defer rows.Close()
  1204. for rows.Next() {
  1205. var folder vfs.BaseVirtualFolder
  1206. var mappedPath, description, fsConfig sql.NullString
  1207. err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
  1208. &folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
  1209. if err != nil {
  1210. return folders, err
  1211. }
  1212. if mappedPath.Valid {
  1213. folder.MappedPath = mappedPath.String
  1214. }
  1215. if description.Valid {
  1216. folder.Description = description.String
  1217. }
  1218. if fsConfig.Valid {
  1219. var fs vfs.Filesystem
  1220. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  1221. if err == nil {
  1222. folder.FsConfig = fs
  1223. }
  1224. }
  1225. folder.PrepareForRendering()
  1226. folders = append(folders, folder)
  1227. }
  1228. err = rows.Err()
  1229. if err != nil {
  1230. return folders, err
  1231. }
  1232. return getVirtualFoldersWithUsers(folders, dbHandle)
  1233. }
  1234. func sqlCommonClearFolderMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
  1235. q := getClearFolderMappingQuery()
  1236. stmt, err := dbHandle.PrepareContext(ctx, q)
  1237. if err != nil {
  1238. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1239. return err
  1240. }
  1241. defer stmt.Close()
  1242. _, err = stmt.ExecContext(ctx, user.Username)
  1243. return err
  1244. }
  1245. func sqlCommonAddFolderMapping(ctx context.Context, user *User, folder *vfs.VirtualFolder, dbHandle sqlQuerier) error {
  1246. q := getAddFolderMappingQuery()
  1247. stmt, err := dbHandle.PrepareContext(ctx, q)
  1248. if err != nil {
  1249. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1250. return err
  1251. }
  1252. defer stmt.Close()
  1253. _, err = stmt.ExecContext(ctx, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.ID, user.Username)
  1254. return err
  1255. }
  1256. func generateVirtualFoldersMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
  1257. err := sqlCommonClearFolderMapping(ctx, user, dbHandle)
  1258. if err != nil {
  1259. return err
  1260. }
  1261. for idx := range user.VirtualFolders {
  1262. vfolder := &user.VirtualFolders[idx]
  1263. f, err := sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle)
  1264. if err != nil {
  1265. return err
  1266. }
  1267. vfolder.BaseVirtualFolder = f
  1268. err = sqlCommonAddFolderMapping(ctx, user, vfolder, dbHandle)
  1269. if err != nil {
  1270. return err
  1271. }
  1272. }
  1273. return err
  1274. }
  1275. func getUserWithVirtualFolders(ctx context.Context, user User, dbHandle sqlQuerier) (User, error) {
  1276. users, err := getUsersWithVirtualFolders(ctx, []User{user}, dbHandle)
  1277. if err != nil {
  1278. return user, err
  1279. }
  1280. if len(users) == 0 {
  1281. return user, errSQLFoldersAssosaction
  1282. }
  1283. return users[0], err
  1284. }
  1285. func getUsersWithVirtualFolders(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
  1286. if len(users) == 0 {
  1287. return users, nil
  1288. }
  1289. var err error
  1290. usersVirtualFolders := make(map[int64][]vfs.VirtualFolder)
  1291. q := getRelatedFoldersForUsersQuery(users)
  1292. stmt, err := dbHandle.PrepareContext(ctx, q)
  1293. if err != nil {
  1294. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1295. return nil, err
  1296. }
  1297. defer stmt.Close()
  1298. rows, err := stmt.QueryContext(ctx)
  1299. if err != nil {
  1300. return nil, err
  1301. }
  1302. defer rows.Close()
  1303. for rows.Next() {
  1304. var folder vfs.VirtualFolder
  1305. var userID int64
  1306. var mappedPath, fsConfig, description sql.NullString
  1307. err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
  1308. &folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &userID, &fsConfig,
  1309. &description)
  1310. if err != nil {
  1311. return users, err
  1312. }
  1313. if mappedPath.Valid {
  1314. folder.MappedPath = mappedPath.String
  1315. }
  1316. if description.Valid {
  1317. folder.Description = description.String
  1318. }
  1319. if fsConfig.Valid {
  1320. var fs vfs.Filesystem
  1321. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  1322. if err == nil {
  1323. folder.FsConfig = fs
  1324. }
  1325. }
  1326. usersVirtualFolders[userID] = append(usersVirtualFolders[userID], folder)
  1327. }
  1328. err = rows.Err()
  1329. if err != nil {
  1330. return users, err
  1331. }
  1332. if len(usersVirtualFolders) == 0 {
  1333. return users, err
  1334. }
  1335. for idx := range users {
  1336. ref := &users[idx]
  1337. ref.VirtualFolders = usersVirtualFolders[ref.ID]
  1338. }
  1339. return users, err
  1340. }
  1341. func getVirtualFoldersWithUsers(folders []vfs.BaseVirtualFolder, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  1342. if len(folders) == 0 {
  1343. return folders, nil
  1344. }
  1345. var err error
  1346. vFoldersUsers := make(map[int64][]string)
  1347. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1348. defer cancel()
  1349. q := getRelatedUsersForFoldersQuery(folders)
  1350. stmt, err := dbHandle.PrepareContext(ctx, q)
  1351. if err != nil {
  1352. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1353. return nil, err
  1354. }
  1355. defer stmt.Close()
  1356. rows, err := stmt.QueryContext(ctx)
  1357. if err != nil {
  1358. return nil, err
  1359. }
  1360. defer rows.Close()
  1361. for rows.Next() {
  1362. var username string
  1363. var folderID int64
  1364. err = rows.Scan(&folderID, &username)
  1365. if err != nil {
  1366. return folders, err
  1367. }
  1368. vFoldersUsers[folderID] = append(vFoldersUsers[folderID], username)
  1369. }
  1370. err = rows.Err()
  1371. if err != nil {
  1372. return folders, err
  1373. }
  1374. if len(vFoldersUsers) == 0 {
  1375. return folders, err
  1376. }
  1377. for idx := range folders {
  1378. ref := &folders[idx]
  1379. ref.Users = vFoldersUsers[ref.ID]
  1380. }
  1381. return folders, err
  1382. }
  1383. func sqlCommonUpdateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
  1384. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1385. defer cancel()
  1386. q := getUpdateFolderQuotaQuery(reset)
  1387. stmt, err := dbHandle.PrepareContext(ctx, q)
  1388. if err != nil {
  1389. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1390. return err
  1391. }
  1392. defer stmt.Close()
  1393. _, err = stmt.ExecContext(ctx, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), name)
  1394. if err == nil {
  1395. providerLog(logger.LevelDebug, "quota updated for folder %#v, files increment: %v size increment: %v is reset? %v",
  1396. name, filesAdd, sizeAdd, reset)
  1397. } else {
  1398. providerLog(logger.LevelWarn, "error updating quota for folder %#v: %v", name, err)
  1399. }
  1400. return err
  1401. }
  1402. func sqlCommonGetFolderUsedQuota(mappedPath string, dbHandle *sql.DB) (int, int64, error) {
  1403. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1404. defer cancel()
  1405. q := getQuotaFolderQuery()
  1406. stmt, err := dbHandle.PrepareContext(ctx, q)
  1407. if err != nil {
  1408. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1409. return 0, 0, err
  1410. }
  1411. defer stmt.Close()
  1412. var usedFiles int
  1413. var usedSize int64
  1414. err = stmt.QueryRowContext(ctx, mappedPath).Scan(&usedSize, &usedFiles)
  1415. if err != nil {
  1416. providerLog(logger.LevelWarn, "error getting quota for folder: %v, error: %v", mappedPath, err)
  1417. return 0, 0, err
  1418. }
  1419. return usedFiles, usedSize, err
  1420. }
  1421. func getAPIKeyWithRelatedFields(ctx context.Context, apiKey APIKey, dbHandle sqlQuerier) (APIKey, error) {
  1422. var apiKeys []APIKey
  1423. var err error
  1424. scope := APIKeyScopeAdmin
  1425. if apiKey.userID > 0 {
  1426. scope = APIKeyScopeUser
  1427. }
  1428. apiKeys, err = getRelatedValuesForAPIKeys(ctx, []APIKey{apiKey}, dbHandle, scope)
  1429. if err != nil {
  1430. return apiKey, err
  1431. }
  1432. if len(apiKeys) > 0 {
  1433. apiKey = apiKeys[0]
  1434. }
  1435. return apiKey, nil
  1436. }
  1437. func getRelatedValuesForAPIKeys(ctx context.Context, apiKeys []APIKey, dbHandle sqlQuerier, scope APIKeyScope) ([]APIKey, error) {
  1438. if len(apiKeys) == 0 {
  1439. return apiKeys, nil
  1440. }
  1441. values := make(map[int64]string)
  1442. var q string
  1443. if scope == APIKeyScopeUser {
  1444. q = getRelatedUsersForAPIKeysQuery(apiKeys)
  1445. } else {
  1446. q = getRelatedAdminsForAPIKeysQuery(apiKeys)
  1447. }
  1448. stmt, err := dbHandle.PrepareContext(ctx, q)
  1449. if err != nil {
  1450. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1451. return nil, err
  1452. }
  1453. defer stmt.Close()
  1454. rows, err := stmt.QueryContext(ctx)
  1455. if err != nil {
  1456. return nil, err
  1457. }
  1458. defer rows.Close()
  1459. for rows.Next() {
  1460. var valueID int64
  1461. var valueName string
  1462. err = rows.Scan(&valueID, &valueName)
  1463. if err != nil {
  1464. return apiKeys, err
  1465. }
  1466. values[valueID] = valueName
  1467. }
  1468. err = rows.Err()
  1469. if err != nil {
  1470. return apiKeys, err
  1471. }
  1472. if len(values) == 0 {
  1473. return apiKeys, nil
  1474. }
  1475. for idx := range apiKeys {
  1476. ref := &apiKeys[idx]
  1477. if scope == APIKeyScopeUser {
  1478. ref.User = values[ref.userID]
  1479. } else {
  1480. ref.Admin = values[ref.adminID]
  1481. }
  1482. }
  1483. return apiKeys, nil
  1484. }
  1485. func sqlCommonGetAPIKeyRelatedIDs(apiKey *APIKey) (sql.NullInt64, sql.NullInt64, error) {
  1486. var userID, adminID sql.NullInt64
  1487. if apiKey.User != "" {
  1488. u, err := provider.userExists(apiKey.User)
  1489. if err != nil {
  1490. return userID, adminID, util.NewValidationError(fmt.Sprintf("unable to validate user %v", apiKey.User))
  1491. }
  1492. userID.Valid = true
  1493. userID.Int64 = u.ID
  1494. }
  1495. if apiKey.Admin != "" {
  1496. a, err := provider.adminExists(apiKey.Admin)
  1497. if err != nil {
  1498. return userID, adminID, util.NewValidationError(fmt.Sprintf("unable to validate admin %v", apiKey.Admin))
  1499. }
  1500. adminID.Valid = true
  1501. adminID.Int64 = a.ID
  1502. }
  1503. return userID, adminID, nil
  1504. }
  1505. func sqlCommonGetDatabaseVersion(dbHandle *sql.DB, showInitWarn bool) (schemaVersion, error) {
  1506. var result schemaVersion
  1507. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1508. defer cancel()
  1509. q := getDatabaseVersionQuery()
  1510. stmt, err := dbHandle.PrepareContext(ctx, q)
  1511. if err != nil {
  1512. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1513. if showInitWarn && strings.Contains(err.Error(), sqlTableSchemaVersion) {
  1514. logger.WarnToConsole("database query error, did you forgot to run the \"initprovider\" command?")
  1515. }
  1516. return result, err
  1517. }
  1518. defer stmt.Close()
  1519. row := stmt.QueryRowContext(ctx)
  1520. err = row.Scan(&result.Version)
  1521. return result, err
  1522. }
  1523. func sqlCommonUpdateDatabaseVersion(ctx context.Context, dbHandle sqlQuerier, version int) error {
  1524. q := getUpdateDBVersionQuery()
  1525. stmt, err := dbHandle.PrepareContext(ctx, q)
  1526. if err != nil {
  1527. providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
  1528. return err
  1529. }
  1530. defer stmt.Close()
  1531. _, err = stmt.ExecContext(ctx, version)
  1532. return err
  1533. }
  1534. func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, newVersion int) error {
  1535. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  1536. defer cancel()
  1537. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  1538. for _, q := range sqlQueries {
  1539. if strings.TrimSpace(q) == "" {
  1540. continue
  1541. }
  1542. _, err := tx.ExecContext(ctx, q)
  1543. if err != nil {
  1544. return err
  1545. }
  1546. }
  1547. if newVersion == 0 {
  1548. return nil
  1549. }
  1550. return sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion)
  1551. })
  1552. }
  1553. func sqlCommonExecuteTx(ctx context.Context, dbHandle *sql.DB, txFn func(*sql.Tx) error) error {
  1554. if config.Driver == CockroachDataProviderName {
  1555. return crdb.ExecuteTx(ctx, dbHandle, nil, txFn)
  1556. }
  1557. tx, err := dbHandle.BeginTx(ctx, nil)
  1558. if err != nil {
  1559. return err
  1560. }
  1561. err = txFn(tx)
  1562. if err != nil {
  1563. // we don't change the returned error
  1564. tx.Rollback() //nolint:errcheck
  1565. return err
  1566. }
  1567. return tx.Commit()
  1568. }