sqlcommon.go 89 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762277227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538253925402541254225432544254525462547254825492550255125522553255425552556255725582559256025612562256325642565256625672568256925702571257225732574257525762577257825792580258125822583258425852586258725882589259025912592259325942595259625972598259926002601260226032604260526062607260826092610261126122613261426152616261726182619262026212622262326242625262626272628262926302631263226332634263526362637263826392640264126422643264426452646264726482649265026512652265326542655265626572658265926602661266226632664266526662667266826692670267126722673267426752676267726782679268026812682268326842685268626872688268926902691269226932694269526962697269826992700270127022703270427052706270727082709271027112712271327142715271627172718271927202721272227232724272527262727272827292730273127322733273427352736273727382739274027412742274327442745274627472748274927502751275227532754275527562757275827592760276127622763276427652766276727682769277027712772277327742775277627772778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216
  1. // Copyright (C) 2019-2022 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package dataprovider
  15. import (
  16. "context"
  17. "crypto/x509"
  18. "database/sql"
  19. "encoding/json"
  20. "errors"
  21. "fmt"
  22. "runtime/debug"
  23. "strings"
  24. "time"
  25. "github.com/cockroachdb/cockroach-go/v2/crdb"
  26. "github.com/sftpgo/sdk"
  27. "github.com/drakkan/sftpgo/v2/logger"
  28. "github.com/drakkan/sftpgo/v2/util"
  29. "github.com/drakkan/sftpgo/v2/vfs"
  30. )
  31. const (
  32. sqlDatabaseVersion = 20
  33. defaultSQLQueryTimeout = 10 * time.Second
  34. longSQLQueryTimeout = 60 * time.Second
  35. )
  36. var (
  37. errSQLFoldersAssociation = errors.New("unable to associate virtual folders to user")
  38. errSQLGroupsAssociation = errors.New("unable to associate groups to user")
  39. errSQLUsersAssociation = errors.New("unable to associate users to group")
  40. errSchemaVersionEmpty = errors.New("we can't determine schema version because the schema_migration table is empty. The SFTPGo database might be corrupted. Consider using the \"resetprovider\" sub-command")
  41. )
  42. type sqlQuerier interface {
  43. QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
  44. QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
  45. ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
  46. PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
  47. }
  48. type sqlScanner interface {
  49. Scan(dest ...any) error
  50. }
  51. func sqlReplaceAll(sql string) string {
  52. sql = strings.ReplaceAll(sql, "{{schema_version}}", sqlTableSchemaVersion)
  53. sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins)
  54. sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
  55. sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
  56. sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
  57. sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping)
  58. sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping)
  59. sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping)
  60. sql = strings.ReplaceAll(sql, "{{api_keys}}", sqlTableAPIKeys)
  61. sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares)
  62. sql = strings.ReplaceAll(sql, "{{defender_events}}", sqlTableDefenderEvents)
  63. sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts)
  64. sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers)
  65. sql = strings.ReplaceAll(sql, "{{shared_sessions}}", sqlTableSharedSessions)
  66. sql = strings.ReplaceAll(sql, "{{events_actions}}", sqlTableEventsActions)
  67. sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
  68. sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
  69. sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
  70. sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
  71. return sql
  72. }
  73. func sqlCommonGetShareByID(shareID, username string, dbHandle sqlQuerier) (Share, error) {
  74. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  75. defer cancel()
  76. filterUser := username != ""
  77. q := getShareByIDQuery(filterUser)
  78. var row *sql.Row
  79. if filterUser {
  80. row = dbHandle.QueryRowContext(ctx, q, shareID, username)
  81. } else {
  82. row = dbHandle.QueryRowContext(ctx, q, shareID)
  83. }
  84. return getShareFromDbRow(row)
  85. }
  86. func sqlCommonAddShare(share *Share, dbHandle *sql.DB) error {
  87. err := share.validate()
  88. if err != nil {
  89. return err
  90. }
  91. user, err := provider.userExists(share.Username)
  92. if err != nil {
  93. return util.NewValidationError(fmt.Sprintf("unable to validate user %#v", share.Username))
  94. }
  95. paths, err := json.Marshal(share.Paths)
  96. if err != nil {
  97. return err
  98. }
  99. allowFrom := ""
  100. if len(share.AllowFrom) > 0 {
  101. res, err := json.Marshal(share.AllowFrom)
  102. if err == nil {
  103. allowFrom = string(res)
  104. }
  105. }
  106. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  107. defer cancel()
  108. q := getAddShareQuery()
  109. usedTokens := 0
  110. createdAt := util.GetTimeAsMsSinceEpoch(time.Now())
  111. updatedAt := createdAt
  112. lastUseAt := int64(0)
  113. if share.IsRestore {
  114. usedTokens = share.UsedTokens
  115. if share.CreatedAt > 0 {
  116. createdAt = share.CreatedAt
  117. }
  118. if share.UpdatedAt > 0 {
  119. updatedAt = share.UpdatedAt
  120. }
  121. lastUseAt = share.LastUseAt
  122. }
  123. _, err = dbHandle.ExecContext(ctx, q, share.ShareID, share.Name, share.Description, share.Scope,
  124. string(paths), createdAt, updatedAt, lastUseAt, share.ExpiresAt, share.Password,
  125. share.MaxTokens, usedTokens, allowFrom, user.ID)
  126. return err
  127. }
  128. func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error {
  129. err := share.validate()
  130. if err != nil {
  131. return err
  132. }
  133. paths, err := json.Marshal(share.Paths)
  134. if err != nil {
  135. return err
  136. }
  137. allowFrom := ""
  138. if len(share.AllowFrom) > 0 {
  139. res, err := json.Marshal(share.AllowFrom)
  140. if err == nil {
  141. allowFrom = string(res)
  142. }
  143. }
  144. user, err := provider.userExists(share.Username)
  145. if err != nil {
  146. return util.NewValidationError(fmt.Sprintf("unable to validate user %#v", share.Username))
  147. }
  148. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  149. defer cancel()
  150. var q string
  151. if share.IsRestore {
  152. q = getUpdateShareRestoreQuery()
  153. } else {
  154. q = getUpdateShareQuery()
  155. }
  156. if share.IsRestore {
  157. if share.CreatedAt == 0 {
  158. share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
  159. }
  160. if share.UpdatedAt == 0 {
  161. share.UpdatedAt = share.CreatedAt
  162. }
  163. _, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, string(paths),
  164. share.CreatedAt, share.UpdatedAt, share.LastUseAt, share.ExpiresAt, share.Password, share.MaxTokens,
  165. share.UsedTokens, allowFrom, user.ID, share.ShareID)
  166. } else {
  167. _, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, string(paths),
  168. util.GetTimeAsMsSinceEpoch(time.Now()), share.ExpiresAt, share.Password, share.MaxTokens,
  169. allowFrom, user.ID, share.ShareID)
  170. }
  171. return err
  172. }
  173. func sqlCommonDeleteShare(share Share, dbHandle *sql.DB) error {
  174. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  175. defer cancel()
  176. q := getDeleteShareQuery()
  177. res, err := dbHandle.ExecContext(ctx, q, share.ShareID)
  178. if err != nil {
  179. return err
  180. }
  181. return sqlCommonRequireRowAffected(res)
  182. }
  183. func sqlCommonGetShares(limit, offset int, order, username string, dbHandle sqlQuerier) ([]Share, error) {
  184. shares := make([]Share, 0, limit)
  185. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  186. defer cancel()
  187. q := getSharesQuery(order)
  188. rows, err := dbHandle.QueryContext(ctx, q, username, limit, offset)
  189. if err != nil {
  190. return shares, err
  191. }
  192. defer rows.Close()
  193. for rows.Next() {
  194. s, err := getShareFromDbRow(rows)
  195. if err != nil {
  196. return shares, err
  197. }
  198. s.HideConfidentialData()
  199. shares = append(shares, s)
  200. }
  201. return shares, rows.Err()
  202. }
  203. func sqlCommonDumpShares(dbHandle sqlQuerier) ([]Share, error) {
  204. shares := make([]Share, 0, 30)
  205. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  206. defer cancel()
  207. q := getDumpSharesQuery()
  208. rows, err := dbHandle.QueryContext(ctx, q)
  209. if err != nil {
  210. return shares, err
  211. }
  212. defer rows.Close()
  213. for rows.Next() {
  214. s, err := getShareFromDbRow(rows)
  215. if err != nil {
  216. return shares, err
  217. }
  218. shares = append(shares, s)
  219. }
  220. return shares, rows.Err()
  221. }
  222. func sqlCommonGetAPIKeyByID(keyID string, dbHandle sqlQuerier) (APIKey, error) {
  223. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  224. defer cancel()
  225. q := getAPIKeyByIDQuery()
  226. row := dbHandle.QueryRowContext(ctx, q, keyID)
  227. apiKey, err := getAPIKeyFromDbRow(row)
  228. if err != nil {
  229. return apiKey, err
  230. }
  231. return getAPIKeyWithRelatedFields(ctx, apiKey, dbHandle)
  232. }
  233. func sqlCommonAddAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
  234. err := apiKey.validate()
  235. if err != nil {
  236. return err
  237. }
  238. userID, adminID, err := sqlCommonGetAPIKeyRelatedIDs(apiKey)
  239. if err != nil {
  240. return err
  241. }
  242. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  243. defer cancel()
  244. q := getAddAPIKeyQuery()
  245. _, err = dbHandle.ExecContext(ctx, q, apiKey.KeyID, apiKey.Name, apiKey.Key, apiKey.Scope,
  246. util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.LastUseAt,
  247. apiKey.ExpiresAt, apiKey.Description, userID, adminID)
  248. return err
  249. }
  250. func sqlCommonUpdateAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
  251. err := apiKey.validate()
  252. if err != nil {
  253. return err
  254. }
  255. userID, adminID, err := sqlCommonGetAPIKeyRelatedIDs(apiKey)
  256. if err != nil {
  257. return err
  258. }
  259. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  260. defer cancel()
  261. q := getUpdateAPIKeyQuery()
  262. _, err = dbHandle.ExecContext(ctx, q, apiKey.Name, apiKey.Scope, apiKey.ExpiresAt, userID, adminID,
  263. apiKey.Description, util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.KeyID)
  264. return err
  265. }
  266. func sqlCommonDeleteAPIKey(apiKey APIKey, dbHandle *sql.DB) error {
  267. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  268. defer cancel()
  269. q := getDeleteAPIKeyQuery()
  270. res, err := dbHandle.ExecContext(ctx, q, apiKey.KeyID)
  271. if err != nil {
  272. return err
  273. }
  274. return sqlCommonRequireRowAffected(res)
  275. }
  276. func sqlCommonGetAPIKeys(limit, offset int, order string, dbHandle sqlQuerier) ([]APIKey, error) {
  277. apiKeys := make([]APIKey, 0, limit)
  278. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  279. defer cancel()
  280. q := getAPIKeysQuery(order)
  281. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  282. if err != nil {
  283. return apiKeys, err
  284. }
  285. defer rows.Close()
  286. for rows.Next() {
  287. k, err := getAPIKeyFromDbRow(rows)
  288. if err != nil {
  289. return apiKeys, err
  290. }
  291. k.HideConfidentialData()
  292. apiKeys = append(apiKeys, k)
  293. }
  294. err = rows.Err()
  295. if err != nil {
  296. return apiKeys, err
  297. }
  298. apiKeys, err = getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeAdmin)
  299. if err != nil {
  300. return apiKeys, err
  301. }
  302. return getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeUser)
  303. }
  304. func sqlCommonDumpAPIKeys(dbHandle sqlQuerier) ([]APIKey, error) {
  305. apiKeys := make([]APIKey, 0, 30)
  306. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  307. defer cancel()
  308. q := getDumpAPIKeysQuery()
  309. rows, err := dbHandle.QueryContext(ctx, q)
  310. if err != nil {
  311. return apiKeys, err
  312. }
  313. defer rows.Close()
  314. for rows.Next() {
  315. k, err := getAPIKeyFromDbRow(rows)
  316. if err != nil {
  317. return apiKeys, err
  318. }
  319. apiKeys = append(apiKeys, k)
  320. }
  321. err = rows.Err()
  322. if err != nil {
  323. return apiKeys, err
  324. }
  325. apiKeys, err = getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeAdmin)
  326. if err != nil {
  327. return apiKeys, err
  328. }
  329. return getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeUser)
  330. }
  331. func sqlCommonGetAdminByUsername(username string, dbHandle sqlQuerier) (Admin, error) {
  332. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  333. defer cancel()
  334. q := getAdminByUsernameQuery()
  335. row := dbHandle.QueryRowContext(ctx, q, username)
  336. return getAdminFromDbRow(row)
  337. }
  338. func sqlCommonValidateAdminAndPass(username, password, ip string, dbHandle *sql.DB) (Admin, error) {
  339. admin, err := sqlCommonGetAdminByUsername(username, dbHandle)
  340. if err != nil {
  341. providerLog(logger.LevelWarn, "error authenticating admin %#v: %v", username, err)
  342. return admin, ErrInvalidCredentials
  343. }
  344. err = admin.checkUserAndPass(password, ip)
  345. return admin, err
  346. }
  347. func sqlCommonAddAdmin(admin *Admin, dbHandle *sql.DB) error {
  348. err := admin.validate()
  349. if err != nil {
  350. return err
  351. }
  352. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  353. defer cancel()
  354. q := getAddAdminQuery()
  355. perms, err := json.Marshal(admin.Permissions)
  356. if err != nil {
  357. return err
  358. }
  359. filters, err := json.Marshal(admin.Filters)
  360. if err != nil {
  361. return err
  362. }
  363. _, err = dbHandle.ExecContext(ctx, q, admin.Username, admin.Password, admin.Status, admin.Email, string(perms),
  364. string(filters), admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
  365. util.GetTimeAsMsSinceEpoch(time.Now()))
  366. return err
  367. }
  368. func sqlCommonUpdateAdmin(admin *Admin, dbHandle *sql.DB) error {
  369. err := admin.validate()
  370. if err != nil {
  371. return err
  372. }
  373. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  374. defer cancel()
  375. q := getUpdateAdminQuery()
  376. perms, err := json.Marshal(admin.Permissions)
  377. if err != nil {
  378. return err
  379. }
  380. filters, err := json.Marshal(admin.Filters)
  381. if err != nil {
  382. return err
  383. }
  384. _, err = dbHandle.ExecContext(ctx, q, admin.Password, admin.Status, admin.Email, string(perms), string(filters),
  385. admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()), admin.Username)
  386. return err
  387. }
  388. func sqlCommonDeleteAdmin(admin Admin, dbHandle *sql.DB) error {
  389. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  390. defer cancel()
  391. q := getDeleteAdminQuery()
  392. res, err := dbHandle.ExecContext(ctx, q, admin.Username)
  393. if err != nil {
  394. return err
  395. }
  396. return sqlCommonRequireRowAffected(res)
  397. }
  398. func sqlCommonGetAdmins(limit, offset int, order string, dbHandle sqlQuerier) ([]Admin, error) {
  399. admins := make([]Admin, 0, limit)
  400. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  401. defer cancel()
  402. q := getAdminsQuery(order)
  403. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  404. if err != nil {
  405. return admins, err
  406. }
  407. defer rows.Close()
  408. for rows.Next() {
  409. a, err := getAdminFromDbRow(rows)
  410. if err != nil {
  411. return admins, err
  412. }
  413. a.HideConfidentialData()
  414. admins = append(admins, a)
  415. }
  416. return admins, rows.Err()
  417. }
  418. func sqlCommonDumpAdmins(dbHandle sqlQuerier) ([]Admin, error) {
  419. admins := make([]Admin, 0, 30)
  420. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  421. defer cancel()
  422. q := getDumpAdminsQuery()
  423. rows, err := dbHandle.QueryContext(ctx, q)
  424. if err != nil {
  425. return admins, err
  426. }
  427. defer rows.Close()
  428. for rows.Next() {
  429. a, err := getAdminFromDbRow(rows)
  430. if err != nil {
  431. return admins, err
  432. }
  433. admins = append(admins, a)
  434. }
  435. return admins, rows.Err()
  436. }
  437. func sqlCommonGetGroupByName(name string, dbHandle sqlQuerier) (Group, error) {
  438. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  439. defer cancel()
  440. q := getGroupByNameQuery()
  441. row := dbHandle.QueryRowContext(ctx, q, name)
  442. group, err := getGroupFromDbRow(row)
  443. if err != nil {
  444. return group, err
  445. }
  446. group, err = getGroupWithVirtualFolders(ctx, group, dbHandle)
  447. if err != nil {
  448. return group, err
  449. }
  450. return getGroupWithUsers(ctx, group, dbHandle)
  451. }
  452. func sqlCommonDumpGroups(dbHandle sqlQuerier) ([]Group, error) {
  453. groups := make([]Group, 0, 50)
  454. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  455. defer cancel()
  456. q := getDumpGroupsQuery()
  457. rows, err := dbHandle.QueryContext(ctx, q)
  458. if err != nil {
  459. return groups, err
  460. }
  461. defer rows.Close()
  462. for rows.Next() {
  463. group, err := getGroupFromDbRow(rows)
  464. if err != nil {
  465. return groups, err
  466. }
  467. groups = append(groups, group)
  468. }
  469. err = rows.Err()
  470. if err != nil {
  471. return groups, err
  472. }
  473. return getGroupsWithVirtualFolders(ctx, groups, dbHandle)
  474. }
  475. func sqlCommonGetUsersInGroups(names []string, dbHandle sqlQuerier) ([]string, error) {
  476. if len(names) == 0 {
  477. return nil, nil
  478. }
  479. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  480. defer cancel()
  481. q := getUsersInGroupsQuery(len(names))
  482. args := make([]any, 0, len(names))
  483. for _, name := range names {
  484. args = append(args, name)
  485. }
  486. usernames := make([]string, 0, len(names))
  487. rows, err := dbHandle.QueryContext(ctx, q, args...)
  488. if err != nil {
  489. return nil, err
  490. }
  491. defer rows.Close()
  492. for rows.Next() {
  493. var username string
  494. err = rows.Scan(&username)
  495. if err != nil {
  496. return usernames, err
  497. }
  498. usernames = append(usernames, username)
  499. }
  500. return usernames, rows.Err()
  501. }
  502. func sqlCommonGetGroupsWithNames(names []string, dbHandle sqlQuerier) ([]Group, error) {
  503. if len(names) == 0 {
  504. return nil, nil
  505. }
  506. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  507. defer cancel()
  508. q := getGroupsWithNamesQuery(len(names))
  509. args := make([]any, 0, len(names))
  510. for _, name := range names {
  511. args = append(args, name)
  512. }
  513. groups := make([]Group, 0, len(names))
  514. rows, err := dbHandle.QueryContext(ctx, q, args...)
  515. if err != nil {
  516. return groups, err
  517. }
  518. defer rows.Close()
  519. for rows.Next() {
  520. group, err := getGroupFromDbRow(rows)
  521. if err != nil {
  522. return groups, err
  523. }
  524. groups = append(groups, group)
  525. }
  526. err = rows.Err()
  527. if err != nil {
  528. return groups, err
  529. }
  530. return getGroupsWithVirtualFolders(ctx, groups, dbHandle)
  531. }
  532. func sqlCommonGetGroups(limit int, offset int, order string, minimal bool, dbHandle sqlQuerier) ([]Group, error) {
  533. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  534. defer cancel()
  535. q := getGroupsQuery(order, minimal)
  536. groups := make([]Group, 0, limit)
  537. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  538. if err != nil {
  539. return groups, err
  540. }
  541. defer rows.Close()
  542. for rows.Next() {
  543. var group Group
  544. if minimal {
  545. err = rows.Scan(&group.ID, &group.Name)
  546. } else {
  547. group, err = getGroupFromDbRow(rows)
  548. }
  549. if err != nil {
  550. return groups, err
  551. }
  552. groups = append(groups, group)
  553. }
  554. err = rows.Err()
  555. if err != nil {
  556. return groups, err
  557. }
  558. if minimal {
  559. return groups, nil
  560. }
  561. groups, err = getGroupsWithVirtualFolders(ctx, groups, dbHandle)
  562. if err != nil {
  563. return groups, err
  564. }
  565. groups, err = getGroupsWithUsers(ctx, groups, dbHandle)
  566. if err != nil {
  567. return groups, err
  568. }
  569. for idx := range groups {
  570. groups[idx].PrepareForRendering()
  571. }
  572. return groups, nil
  573. }
  574. func sqlCommonAddGroup(group *Group, dbHandle *sql.DB) error {
  575. if err := group.validate(); err != nil {
  576. return err
  577. }
  578. settings, err := json.Marshal(group.UserSettings)
  579. if err != nil {
  580. return err
  581. }
  582. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  583. defer cancel()
  584. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  585. q := getAddGroupQuery()
  586. _, err := tx.ExecContext(ctx, q, group.Name, group.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
  587. util.GetTimeAsMsSinceEpoch(time.Now()), string(settings))
  588. if err != nil {
  589. return err
  590. }
  591. return generateGroupVirtualFoldersMapping(ctx, group, tx)
  592. })
  593. }
  594. func sqlCommonUpdateGroup(group *Group, dbHandle *sql.DB) error {
  595. if err := group.validate(); err != nil {
  596. return err
  597. }
  598. settings, err := json.Marshal(group.UserSettings)
  599. if err != nil {
  600. return err
  601. }
  602. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  603. defer cancel()
  604. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  605. q := getUpdateGroupQuery()
  606. _, err := tx.ExecContext(ctx, q, group.Description, settings, util.GetTimeAsMsSinceEpoch(time.Now()), group.Name)
  607. if err != nil {
  608. return err
  609. }
  610. return generateGroupVirtualFoldersMapping(ctx, group, tx)
  611. })
  612. }
  613. func sqlCommonDeleteGroup(group Group, dbHandle *sql.DB) error {
  614. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  615. defer cancel()
  616. q := getDeleteGroupQuery()
  617. res, err := dbHandle.ExecContext(ctx, q, group.Name)
  618. if err != nil {
  619. return err
  620. }
  621. return sqlCommonRequireRowAffected(res)
  622. }
  623. func sqlCommonGetUserByUsername(username string, dbHandle sqlQuerier) (User, error) {
  624. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  625. defer cancel()
  626. q := getUserByUsernameQuery()
  627. row := dbHandle.QueryRowContext(ctx, q, username)
  628. user, err := getUserFromDbRow(row)
  629. if err != nil {
  630. return user, err
  631. }
  632. user, err = getUserWithVirtualFolders(ctx, user, dbHandle)
  633. if err != nil {
  634. return user, err
  635. }
  636. return getUserWithGroups(ctx, user, dbHandle)
  637. }
  638. func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHandle *sql.DB) (User, error) {
  639. var user User
  640. if password == "" {
  641. return user, errors.New("credentials cannot be null or empty")
  642. }
  643. user, err := sqlCommonGetUserByUsername(username, dbHandle)
  644. if err != nil {
  645. providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
  646. return user, err
  647. }
  648. return checkUserAndPass(&user, password, ip, protocol)
  649. }
  650. func sqlCommonValidateUserAndTLSCertificate(username, protocol string, tlsCert *x509.Certificate, dbHandle *sql.DB) (User, error) {
  651. var user User
  652. if tlsCert == nil {
  653. return user, errors.New("TLS certificate cannot be null or empty")
  654. }
  655. user, err := sqlCommonGetUserByUsername(username, dbHandle)
  656. if err != nil {
  657. providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
  658. return user, err
  659. }
  660. return checkUserAndTLSCertificate(&user, protocol, tlsCert)
  661. }
  662. func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, isSSHCert bool, dbHandle *sql.DB) (User, string, error) {
  663. var user User
  664. if len(pubKey) == 0 {
  665. return user, "", errors.New("credentials cannot be null or empty")
  666. }
  667. user, err := sqlCommonGetUserByUsername(username, dbHandle)
  668. if err != nil {
  669. providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
  670. return user, "", err
  671. }
  672. return checkUserAndPubKey(&user, pubKey, isSSHCert)
  673. }
  674. func sqlCommonCheckAvailability(dbHandle *sql.DB) (err error) {
  675. defer func() {
  676. if r := recover(); r != nil {
  677. providerLog(logger.LevelError, "panic in check provider availability, stack trace: %v", string(debug.Stack()))
  678. err = errors.New("unable to check provider status")
  679. }
  680. }()
  681. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  682. defer cancel()
  683. err = dbHandle.PingContext(ctx)
  684. return
  685. }
  686. func sqlCommonUpdateTransferQuota(username string, uploadSize, downloadSize int64, reset bool, dbHandle *sql.DB) error {
  687. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  688. defer cancel()
  689. q := getUpdateTransferQuotaQuery(reset)
  690. _, err := dbHandle.ExecContext(ctx, q, uploadSize, downloadSize, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  691. if err == nil {
  692. providerLog(logger.LevelDebug, "transfer quota updated for user %#v, ul increment: %v dl increment: %v is reset? %v",
  693. username, uploadSize, downloadSize, reset)
  694. } else {
  695. providerLog(logger.LevelError, "error updating quota for user %#v: %v", username, err)
  696. }
  697. return err
  698. }
  699. func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
  700. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  701. defer cancel()
  702. q := getUpdateQuotaQuery(reset)
  703. _, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  704. if err == nil {
  705. providerLog(logger.LevelDebug, "quota updated for user %#v, files increment: %v size increment: %v is reset? %v",
  706. username, filesAdd, sizeAdd, reset)
  707. } else {
  708. providerLog(logger.LevelError, "error updating quota for user %#v: %v", username, err)
  709. }
  710. return err
  711. }
  712. func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, int64, int64, error) {
  713. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  714. defer cancel()
  715. q := getQuotaQuery()
  716. var usedFiles int
  717. var usedSize, usedUploadSize, usedDownloadSize int64
  718. err := dbHandle.QueryRowContext(ctx, q, username).Scan(&usedSize, &usedFiles, &usedUploadSize, &usedDownloadSize)
  719. if err != nil {
  720. providerLog(logger.LevelError, "error getting quota for user: %v, error: %v", username, err)
  721. return 0, 0, 0, 0, err
  722. }
  723. return usedFiles, usedSize, usedUploadSize, usedDownloadSize, err
  724. }
  725. func sqlCommonUpdateShareLastUse(shareID string, numTokens int, dbHandle *sql.DB) error {
  726. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  727. defer cancel()
  728. q := getUpdateShareLastUseQuery()
  729. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), numTokens, shareID)
  730. if err == nil {
  731. providerLog(logger.LevelDebug, "last use updated for shared object %#v", shareID)
  732. } else {
  733. providerLog(logger.LevelWarn, "error updating last use for shared object %#v: %v", shareID, err)
  734. }
  735. return err
  736. }
  737. func sqlCommonUpdateAPIKeyLastUse(keyID string, dbHandle *sql.DB) error {
  738. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  739. defer cancel()
  740. q := getUpdateAPIKeyLastUseQuery()
  741. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), keyID)
  742. if err == nil {
  743. providerLog(logger.LevelDebug, "last use updated for key %#v", keyID)
  744. } else {
  745. providerLog(logger.LevelWarn, "error updating last use for key %#v: %v", keyID, err)
  746. }
  747. return err
  748. }
  749. func sqlCommonUpdateAdminLastLogin(username string, dbHandle *sql.DB) error {
  750. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  751. defer cancel()
  752. q := getUpdateAdminLastLoginQuery()
  753. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  754. if err == nil {
  755. providerLog(logger.LevelDebug, "last login updated for admin %#v", username)
  756. } else {
  757. providerLog(logger.LevelWarn, "error updating last login for admin %#v: %v", username, err)
  758. }
  759. return err
  760. }
  761. func sqlCommonSetUpdatedAt(username string, dbHandle *sql.DB) {
  762. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  763. defer cancel()
  764. q := getSetUpdateAtQuery()
  765. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  766. if err == nil {
  767. providerLog(logger.LevelDebug, "updated_at set for user %#v", username)
  768. } else {
  769. providerLog(logger.LevelWarn, "error setting updated_at for user %#v: %v", username, err)
  770. }
  771. }
  772. func sqlCommonUpdateLastLogin(username string, dbHandle *sql.DB) error {
  773. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  774. defer cancel()
  775. q := getUpdateLastLoginQuery()
  776. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  777. if err == nil {
  778. providerLog(logger.LevelDebug, "last login updated for user %#v", username)
  779. } else {
  780. providerLog(logger.LevelWarn, "error updating last login for user %#v: %v", username, err)
  781. }
  782. return err
  783. }
  784. func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
  785. err := ValidateUser(user)
  786. if err != nil {
  787. return err
  788. }
  789. permissions, err := user.GetPermissionsAsJSON()
  790. if err != nil {
  791. return err
  792. }
  793. publicKeys, err := user.GetPublicKeysAsJSON()
  794. if err != nil {
  795. return err
  796. }
  797. filters, err := user.GetFiltersAsJSON()
  798. if err != nil {
  799. return err
  800. }
  801. fsConfig, err := user.GetFsConfigAsJSON()
  802. if err != nil {
  803. return err
  804. }
  805. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  806. defer cancel()
  807. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  808. q := getAddUserQuery()
  809. _, err := tx.ExecContext(ctx, q, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID,
  810. user.MaxSessions, user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth,
  811. user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo,
  812. user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()),
  813. user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer)
  814. if err != nil {
  815. return err
  816. }
  817. if err := generateUserVirtualFoldersMapping(ctx, user, tx); err != nil {
  818. return err
  819. }
  820. return generateUserGroupMapping(ctx, user, tx)
  821. })
  822. }
  823. func sqlCommonUpdateUserPassword(username, password string, dbHandle *sql.DB) error {
  824. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  825. defer cancel()
  826. q := getUpdateUserPasswordQuery()
  827. _, err := dbHandle.ExecContext(ctx, q, password, username)
  828. return err
  829. }
  830. func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
  831. err := ValidateUser(user)
  832. if err != nil {
  833. return err
  834. }
  835. permissions, err := user.GetPermissionsAsJSON()
  836. if err != nil {
  837. return err
  838. }
  839. publicKeys, err := user.GetPublicKeysAsJSON()
  840. if err != nil {
  841. return err
  842. }
  843. filters, err := user.GetFiltersAsJSON()
  844. if err != nil {
  845. return err
  846. }
  847. fsConfig, err := user.GetFsConfigAsJSON()
  848. if err != nil {
  849. return err
  850. }
  851. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  852. defer cancel()
  853. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  854. q := getUpdateUserQuery()
  855. _, err := tx.ExecContext(ctx, q, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions,
  856. user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status,
  857. user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo, user.Description, user.Email,
  858. util.GetTimeAsMsSinceEpoch(time.Now()), user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer,
  859. user.ID)
  860. if err != nil {
  861. return err
  862. }
  863. if err := generateUserVirtualFoldersMapping(ctx, user, tx); err != nil {
  864. return err
  865. }
  866. return generateUserGroupMapping(ctx, user, tx)
  867. })
  868. }
  869. func sqlCommonDeleteUser(user User, softDelete bool, dbHandle *sql.DB) error {
  870. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  871. defer cancel()
  872. q := getDeleteUserQuery(softDelete)
  873. if softDelete {
  874. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  875. if err := sqlCommonClearUserFolderMapping(ctx, &user, tx); err != nil {
  876. return err
  877. }
  878. if err := sqlCommonClearUserGroupMapping(ctx, &user, tx); err != nil {
  879. return err
  880. }
  881. ts := util.GetTimeAsMsSinceEpoch(time.Now())
  882. res, err := tx.ExecContext(ctx, q, ts, ts, user.Username)
  883. if err != nil {
  884. return err
  885. }
  886. return sqlCommonRequireRowAffected(res)
  887. })
  888. }
  889. res, err := dbHandle.ExecContext(ctx, q, user.ID)
  890. if err != nil {
  891. return err
  892. }
  893. return sqlCommonRequireRowAffected(res)
  894. }
  895. func sqlCommonDumpUsers(dbHandle sqlQuerier) ([]User, error) {
  896. users := make([]User, 0, 100)
  897. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  898. defer cancel()
  899. q := getDumpUsersQuery()
  900. rows, err := dbHandle.QueryContext(ctx, q)
  901. if err != nil {
  902. return users, err
  903. }
  904. defer rows.Close()
  905. for rows.Next() {
  906. u, err := getUserFromDbRow(rows)
  907. if err != nil {
  908. return users, err
  909. }
  910. users = append(users, u)
  911. }
  912. err = rows.Err()
  913. if err != nil {
  914. return users, err
  915. }
  916. users, err = getUsersWithVirtualFolders(ctx, users, dbHandle)
  917. if err != nil {
  918. return users, err
  919. }
  920. return getUsersWithGroups(ctx, users, dbHandle)
  921. }
  922. func sqlCommonGetRecentlyUpdatedUsers(after int64, dbHandle sqlQuerier) ([]User, error) {
  923. users := make([]User, 0, 10)
  924. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  925. defer cancel()
  926. q := getRecentlyUpdatedUsersQuery()
  927. rows, err := dbHandle.QueryContext(ctx, q, after)
  928. if err != nil {
  929. return users, err
  930. }
  931. defer rows.Close()
  932. for rows.Next() {
  933. u, err := getUserFromDbRow(rows)
  934. if err != nil {
  935. return users, err
  936. }
  937. users = append(users, u)
  938. }
  939. err = rows.Err()
  940. if err != nil {
  941. return users, err
  942. }
  943. users, err = getUsersWithVirtualFolders(ctx, users, dbHandle)
  944. if err != nil {
  945. return users, err
  946. }
  947. users, err = getUsersWithGroups(ctx, users, dbHandle)
  948. if err != nil {
  949. return users, err
  950. }
  951. var groupNames []string
  952. for _, u := range users {
  953. for _, g := range u.Groups {
  954. groupNames = append(groupNames, g.Name)
  955. }
  956. }
  957. groupNames = util.RemoveDuplicates(groupNames, false)
  958. groups, err := sqlCommonGetGroupsWithNames(groupNames, dbHandle)
  959. if err != nil {
  960. return users, err
  961. }
  962. if len(groups) == 0 {
  963. return users, nil
  964. }
  965. groupsMapping := make(map[string]Group)
  966. for idx := range groups {
  967. groupsMapping[groups[idx].Name] = groups[idx]
  968. }
  969. for idx := range users {
  970. ref := &users[idx]
  971. ref.applyGroupSettings(groupsMapping)
  972. }
  973. return users, nil
  974. }
  975. func sqlCommonGetUsersForQuotaCheck(toFetch map[string]bool, dbHandle sqlQuerier) ([]User, error) {
  976. users := make([]User, 0, 30)
  977. usernames := make([]string, 0, len(toFetch))
  978. for k := range toFetch {
  979. usernames = append(usernames, k)
  980. }
  981. maxUsers := 30
  982. for len(usernames) > 0 {
  983. if maxUsers > len(usernames) {
  984. maxUsers = len(usernames)
  985. }
  986. usersRange, err := sqlCommonGetUsersRangeForQuotaCheck(usernames[:maxUsers], dbHandle)
  987. if err != nil {
  988. return users, err
  989. }
  990. users = append(users, usersRange...)
  991. usernames = usernames[maxUsers:]
  992. }
  993. var usersWithFolders []User
  994. validIdx := 0
  995. for _, user := range users {
  996. if toFetch[user.Username] {
  997. usersWithFolders = append(usersWithFolders, user)
  998. } else {
  999. users[validIdx] = user
  1000. validIdx++
  1001. }
  1002. }
  1003. users = users[:validIdx]
  1004. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1005. defer cancel()
  1006. usersWithFolders, err := getUsersWithVirtualFolders(ctx, usersWithFolders, dbHandle)
  1007. if err != nil {
  1008. return users, err
  1009. }
  1010. users = append(users, usersWithFolders...)
  1011. users, err = getUsersWithGroups(ctx, users, dbHandle)
  1012. if err != nil {
  1013. return users, err
  1014. }
  1015. var groupNames []string
  1016. for _, u := range users {
  1017. for _, g := range u.Groups {
  1018. groupNames = append(groupNames, g.Name)
  1019. }
  1020. }
  1021. groupNames = util.RemoveDuplicates(groupNames, false)
  1022. if len(groupNames) == 0 {
  1023. return users, nil
  1024. }
  1025. groups, err := sqlCommonGetGroupsWithNames(groupNames, dbHandle)
  1026. if err != nil {
  1027. return users, err
  1028. }
  1029. groupsMapping := make(map[string]Group)
  1030. for idx := range groups {
  1031. groupsMapping[groups[idx].Name] = groups[idx]
  1032. }
  1033. for idx := range users {
  1034. ref := &users[idx]
  1035. ref.applyGroupSettings(groupsMapping)
  1036. }
  1037. return users, nil
  1038. }
  1039. func sqlCommonGetUsersRangeForQuotaCheck(usernames []string, dbHandle sqlQuerier) ([]User, error) {
  1040. users := make([]User, 0, len(usernames))
  1041. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1042. defer cancel()
  1043. q := getUsersForQuotaCheckQuery(len(usernames))
  1044. queryArgs := make([]any, 0, len(usernames))
  1045. for idx := range usernames {
  1046. queryArgs = append(queryArgs, usernames[idx])
  1047. }
  1048. rows, err := dbHandle.QueryContext(ctx, q, queryArgs...)
  1049. if err != nil {
  1050. return nil, err
  1051. }
  1052. defer rows.Close()
  1053. for rows.Next() {
  1054. var user User
  1055. var filters sql.NullString
  1056. err = rows.Scan(&user.ID, &user.Username, &user.QuotaSize, &user.UsedQuotaSize, &user.TotalDataTransfer,
  1057. &user.UploadDataTransfer, &user.DownloadDataTransfer, &user.UsedUploadDataTransfer,
  1058. &user.UsedDownloadDataTransfer, &filters)
  1059. if err != nil {
  1060. return users, err
  1061. }
  1062. if filters.Valid {
  1063. var userFilters UserFilters
  1064. err = json.Unmarshal([]byte(filters.String), &userFilters)
  1065. if err == nil {
  1066. user.Filters = userFilters
  1067. }
  1068. }
  1069. users = append(users, user)
  1070. }
  1071. return users, rows.Err()
  1072. }
  1073. func sqlCommonAddActiveTransfer(transfer ActiveTransfer, dbHandle *sql.DB) error {
  1074. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1075. defer cancel()
  1076. q := getAddActiveTransferQuery()
  1077. now := util.GetTimeAsMsSinceEpoch(time.Now())
  1078. _, err := dbHandle.ExecContext(ctx, q, transfer.ID, transfer.ConnID, transfer.Type, transfer.Username,
  1079. transfer.FolderName, transfer.IP, transfer.TruncatedSize, transfer.CurrentULSize, transfer.CurrentDLSize,
  1080. now, now)
  1081. return err
  1082. }
  1083. func sqlCommonUpdateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string, dbHandle *sql.DB) error {
  1084. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1085. defer cancel()
  1086. q := getUpdateActiveTransferSizesQuery()
  1087. _, err := dbHandle.ExecContext(ctx, q, ulSize, dlSize, util.GetTimeAsMsSinceEpoch(time.Now()), connectionID, transferID)
  1088. return err
  1089. }
  1090. func sqlCommonRemoveActiveTransfer(transferID int64, connectionID string, dbHandle *sql.DB) error {
  1091. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1092. defer cancel()
  1093. q := getRemoveActiveTransferQuery()
  1094. _, err := dbHandle.ExecContext(ctx, q, connectionID, transferID)
  1095. return err
  1096. }
  1097. func sqlCommonCleanupActiveTransfers(before time.Time, dbHandle *sql.DB) error {
  1098. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1099. defer cancel()
  1100. q := getCleanupActiveTransfersQuery()
  1101. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(before))
  1102. return err
  1103. }
  1104. func sqlCommonGetActiveTransfers(from time.Time, dbHandle sqlQuerier) ([]ActiveTransfer, error) {
  1105. transfers := make([]ActiveTransfer, 0, 30)
  1106. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  1107. defer cancel()
  1108. q := getActiveTransfersQuery()
  1109. rows, err := dbHandle.QueryContext(ctx, q, util.GetTimeAsMsSinceEpoch(from))
  1110. if err != nil {
  1111. return nil, err
  1112. }
  1113. defer rows.Close()
  1114. for rows.Next() {
  1115. var transfer ActiveTransfer
  1116. var folderName sql.NullString
  1117. err = rows.Scan(&transfer.ID, &transfer.ConnID, &transfer.Type, &transfer.Username, &folderName, &transfer.IP,
  1118. &transfer.TruncatedSize, &transfer.CurrentULSize, &transfer.CurrentDLSize, &transfer.CreatedAt,
  1119. &transfer.UpdatedAt)
  1120. if err != nil {
  1121. return transfers, err
  1122. }
  1123. if folderName.Valid {
  1124. transfer.FolderName = folderName.String
  1125. }
  1126. transfers = append(transfers, transfer)
  1127. }
  1128. return transfers, rows.Err()
  1129. }
  1130. func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier) ([]User, error) {
  1131. users := make([]User, 0, limit)
  1132. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1133. defer cancel()
  1134. q := getUsersQuery(order)
  1135. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  1136. if err != nil {
  1137. return users, err
  1138. }
  1139. defer rows.Close()
  1140. for rows.Next() {
  1141. u, err := getUserFromDbRow(rows)
  1142. if err != nil {
  1143. return users, err
  1144. }
  1145. users = append(users, u)
  1146. }
  1147. err = rows.Err()
  1148. if err != nil {
  1149. return users, err
  1150. }
  1151. users, err = getUsersWithVirtualFolders(ctx, users, dbHandle)
  1152. if err != nil {
  1153. return users, err
  1154. }
  1155. users, err = getUsersWithGroups(ctx, users, dbHandle)
  1156. if err != nil {
  1157. return users, err
  1158. }
  1159. for idx := range users {
  1160. users[idx].PrepareForRendering()
  1161. }
  1162. return users, nil
  1163. }
  1164. func sqlCommonGetDefenderHosts(from int64, limit int, dbHandle sqlQuerier) ([]DefenderEntry, error) {
  1165. hosts := make([]DefenderEntry, 0, 100)
  1166. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1167. defer cancel()
  1168. q := getDefenderHostsQuery()
  1169. rows, err := dbHandle.QueryContext(ctx, q, from, limit)
  1170. if err != nil {
  1171. providerLog(logger.LevelError, "unable to get defender hosts: %v", err)
  1172. return hosts, err
  1173. }
  1174. defer rows.Close()
  1175. var idForScores []int64
  1176. for rows.Next() {
  1177. var banTime sql.NullInt64
  1178. host := DefenderEntry{}
  1179. err = rows.Scan(&host.ID, &host.IP, &banTime)
  1180. if err != nil {
  1181. providerLog(logger.LevelError, "unable to scan defender host row: %v", err)
  1182. return hosts, err
  1183. }
  1184. var hostBanTime time.Time
  1185. if banTime.Valid && banTime.Int64 > 0 {
  1186. hostBanTime = util.GetTimeFromMsecSinceEpoch(banTime.Int64)
  1187. }
  1188. if hostBanTime.IsZero() || hostBanTime.Before(time.Now()) {
  1189. idForScores = append(idForScores, host.ID)
  1190. } else {
  1191. host.BanTime = hostBanTime
  1192. }
  1193. hosts = append(hosts, host)
  1194. }
  1195. err = rows.Err()
  1196. if err != nil {
  1197. providerLog(logger.LevelError, "unable to iterate over defender host rows: %v", err)
  1198. return hosts, err
  1199. }
  1200. return getDefenderHostsWithScores(ctx, hosts, from, idForScores, dbHandle)
  1201. }
  1202. func sqlCommonIsDefenderHostBanned(ip string, dbHandle sqlQuerier) (DefenderEntry, error) {
  1203. var host DefenderEntry
  1204. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1205. defer cancel()
  1206. q := getDefenderIsHostBannedQuery()
  1207. row := dbHandle.QueryRowContext(ctx, q, ip, util.GetTimeAsMsSinceEpoch(time.Now()))
  1208. err := row.Scan(&host.ID)
  1209. if err != nil {
  1210. if errors.Is(err, sql.ErrNoRows) {
  1211. return host, util.NewRecordNotFoundError("host not found")
  1212. }
  1213. providerLog(logger.LevelError, "unable to check ban status for host %#v: %v", ip, err)
  1214. return host, err
  1215. }
  1216. return host, nil
  1217. }
  1218. func sqlCommonGetDefenderHostByIP(ip string, from int64, dbHandle sqlQuerier) (DefenderEntry, error) {
  1219. var host DefenderEntry
  1220. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1221. defer cancel()
  1222. q := getDefenderHostQuery()
  1223. row := dbHandle.QueryRowContext(ctx, q, ip, from)
  1224. var banTime sql.NullInt64
  1225. err := row.Scan(&host.ID, &host.IP, &banTime)
  1226. if err != nil {
  1227. if errors.Is(err, sql.ErrNoRows) {
  1228. return host, util.NewRecordNotFoundError("host not found")
  1229. }
  1230. providerLog(logger.LevelError, "unable to get host for ip %#v: %v", ip, err)
  1231. return host, err
  1232. }
  1233. if banTime.Valid && banTime.Int64 > 0 {
  1234. hostBanTime := util.GetTimeFromMsecSinceEpoch(banTime.Int64)
  1235. if !hostBanTime.IsZero() && hostBanTime.After(time.Now()) {
  1236. host.BanTime = hostBanTime
  1237. return host, nil
  1238. }
  1239. }
  1240. hosts, err := getDefenderHostsWithScores(ctx, []DefenderEntry{host}, from, []int64{host.ID}, dbHandle)
  1241. if err != nil {
  1242. return host, err
  1243. }
  1244. if len(hosts) == 0 {
  1245. return host, util.NewRecordNotFoundError("host not found")
  1246. }
  1247. return hosts[0], nil
  1248. }
  1249. func sqlCommonDefenderIncrementBanTime(ip string, minutesToAdd int, dbHandle *sql.DB) error {
  1250. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1251. defer cancel()
  1252. q := getDefenderIncrementBanTimeQuery()
  1253. _, err := dbHandle.ExecContext(ctx, q, minutesToAdd*60000, ip)
  1254. if err == nil {
  1255. providerLog(logger.LevelDebug, "ban time updated for ip %#v, increment (minutes): %v",
  1256. ip, minutesToAdd)
  1257. } else {
  1258. providerLog(logger.LevelError, "error updating ban time for ip %#v: %v", ip, err)
  1259. }
  1260. return err
  1261. }
  1262. func sqlCommonSetDefenderBanTime(ip string, banTime int64, dbHandle *sql.DB) error {
  1263. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1264. defer cancel()
  1265. q := getDefenderSetBanTimeQuery()
  1266. _, err := dbHandle.ExecContext(ctx, q, banTime, ip)
  1267. if err == nil {
  1268. providerLog(logger.LevelDebug, "ip %#v banned until %v", ip, util.GetTimeFromMsecSinceEpoch(banTime))
  1269. } else {
  1270. providerLog(logger.LevelError, "error setting ban time for ip %#v: %v", ip, err)
  1271. }
  1272. return err
  1273. }
  1274. func sqlCommonDeleteDefenderHost(ip string, dbHandle sqlQuerier) error {
  1275. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1276. defer cancel()
  1277. q := getDeleteDefenderHostQuery()
  1278. res, err := dbHandle.ExecContext(ctx, q, ip)
  1279. if err != nil {
  1280. providerLog(logger.LevelError, "unable to delete defender host %#v: %v", ip, err)
  1281. return err
  1282. }
  1283. return sqlCommonRequireRowAffected(res)
  1284. }
  1285. func sqlCommonAddDefenderHostAndEvent(ip string, score int, dbHandle *sql.DB) error {
  1286. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1287. defer cancel()
  1288. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  1289. if err := sqlCommonAddDefenderHost(ctx, ip, tx); err != nil {
  1290. return err
  1291. }
  1292. return sqlCommonAddDefenderEvent(ctx, ip, score, tx)
  1293. })
  1294. }
  1295. func sqlCommonDefenderCleanup(from int64, dbHandler *sql.DB) error {
  1296. if err := sqlCommonCleanupDefenderEvents(from, dbHandler); err != nil {
  1297. return err
  1298. }
  1299. return sqlCommonCleanupDefenderHosts(from, dbHandler)
  1300. }
  1301. func sqlCommonAddDefenderHost(ctx context.Context, ip string, tx *sql.Tx) error {
  1302. q := getAddDefenderHostQuery()
  1303. _, err := tx.ExecContext(ctx, q, ip, util.GetTimeAsMsSinceEpoch(time.Now()))
  1304. if err != nil {
  1305. providerLog(logger.LevelError, "unable to add defender host %#v: %v", ip, err)
  1306. }
  1307. return err
  1308. }
  1309. func sqlCommonAddDefenderEvent(ctx context.Context, ip string, score int, tx *sql.Tx) error {
  1310. q := getAddDefenderEventQuery()
  1311. _, err := tx.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), score, ip)
  1312. if err != nil {
  1313. providerLog(logger.LevelError, "unable to add defender event for %#v: %v", ip, err)
  1314. }
  1315. return err
  1316. }
  1317. func sqlCommonCleanupDefenderHosts(from int64, dbHandle *sql.DB) error {
  1318. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1319. defer cancel()
  1320. q := getDefenderHostsCleanupQuery()
  1321. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), from)
  1322. if err != nil {
  1323. providerLog(logger.LevelError, "unable to cleanup defender hosts: %v", err)
  1324. }
  1325. return err
  1326. }
  1327. func sqlCommonCleanupDefenderEvents(from int64, dbHandle *sql.DB) error {
  1328. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1329. defer cancel()
  1330. q := getDefenderEventsCleanupQuery()
  1331. _, err := dbHandle.ExecContext(ctx, q, from)
  1332. if err != nil {
  1333. providerLog(logger.LevelError, "unable to cleanup defender events: %v", err)
  1334. }
  1335. return err
  1336. }
  1337. func getShareFromDbRow(row sqlScanner) (Share, error) {
  1338. var share Share
  1339. var description, password, allowFrom, paths sql.NullString
  1340. err := row.Scan(&share.ShareID, &share.Name, &description, &share.Scope,
  1341. &paths, &share.Username, &share.CreatedAt, &share.UpdatedAt,
  1342. &share.LastUseAt, &share.ExpiresAt, &password, &share.MaxTokens,
  1343. &share.UsedTokens, &allowFrom)
  1344. if err != nil {
  1345. if errors.Is(err, sql.ErrNoRows) {
  1346. return share, util.NewRecordNotFoundError(err.Error())
  1347. }
  1348. return share, err
  1349. }
  1350. if paths.Valid {
  1351. var list []string
  1352. err = json.Unmarshal([]byte(paths.String), &list)
  1353. if err != nil {
  1354. return share, err
  1355. }
  1356. share.Paths = list
  1357. } else {
  1358. return share, errors.New("unable to decode shared paths")
  1359. }
  1360. if description.Valid {
  1361. share.Description = description.String
  1362. }
  1363. if password.Valid {
  1364. share.Password = password.String
  1365. }
  1366. if allowFrom.Valid {
  1367. var list []string
  1368. err = json.Unmarshal([]byte(allowFrom.String), &list)
  1369. if err == nil {
  1370. share.AllowFrom = list
  1371. }
  1372. }
  1373. return share, nil
  1374. }
  1375. func getAPIKeyFromDbRow(row sqlScanner) (APIKey, error) {
  1376. var apiKey APIKey
  1377. var userID, adminID sql.NullInt64
  1378. var description sql.NullString
  1379. err := row.Scan(&apiKey.KeyID, &apiKey.Name, &apiKey.Key, &apiKey.Scope, &apiKey.CreatedAt, &apiKey.UpdatedAt,
  1380. &apiKey.LastUseAt, &apiKey.ExpiresAt, &description, &userID, &adminID)
  1381. if err != nil {
  1382. if errors.Is(err, sql.ErrNoRows) {
  1383. return apiKey, util.NewRecordNotFoundError(err.Error())
  1384. }
  1385. return apiKey, err
  1386. }
  1387. if userID.Valid {
  1388. apiKey.userID = userID.Int64
  1389. }
  1390. if adminID.Valid {
  1391. apiKey.adminID = adminID.Int64
  1392. }
  1393. if description.Valid {
  1394. apiKey.Description = description.String
  1395. }
  1396. return apiKey, nil
  1397. }
  1398. func getAdminFromDbRow(row sqlScanner) (Admin, error) {
  1399. var admin Admin
  1400. var email, filters, additionalInfo, permissions, description sql.NullString
  1401. err := row.Scan(&admin.ID, &admin.Username, &admin.Password, &admin.Status, &email, &permissions,
  1402. &filters, &additionalInfo, &description, &admin.CreatedAt, &admin.UpdatedAt, &admin.LastLogin)
  1403. if err != nil {
  1404. if errors.Is(err, sql.ErrNoRows) {
  1405. return admin, util.NewRecordNotFoundError(err.Error())
  1406. }
  1407. return admin, err
  1408. }
  1409. if permissions.Valid {
  1410. var perms []string
  1411. err = json.Unmarshal([]byte(permissions.String), &perms)
  1412. if err != nil {
  1413. return admin, err
  1414. }
  1415. admin.Permissions = perms
  1416. }
  1417. if email.Valid {
  1418. admin.Email = email.String
  1419. }
  1420. if filters.Valid {
  1421. var adminFilters AdminFilters
  1422. err = json.Unmarshal([]byte(filters.String), &adminFilters)
  1423. if err == nil {
  1424. admin.Filters = adminFilters
  1425. }
  1426. }
  1427. if additionalInfo.Valid {
  1428. admin.AdditionalInfo = additionalInfo.String
  1429. }
  1430. if description.Valid {
  1431. admin.Description = description.String
  1432. }
  1433. admin.SetEmptySecretsIfNil()
  1434. return admin, nil
  1435. }
  1436. func getEventActionFromDbRow(row sqlScanner) (BaseEventAction, error) {
  1437. var action BaseEventAction
  1438. var description sql.NullString
  1439. var options []byte
  1440. err := row.Scan(&action.ID, &action.Name, &description, &action.Type, &options)
  1441. if err != nil {
  1442. if errors.Is(err, sql.ErrNoRows) {
  1443. return action, util.NewRecordNotFoundError(err.Error())
  1444. }
  1445. return action, err
  1446. }
  1447. if description.Valid {
  1448. action.Description = description.String
  1449. }
  1450. if len(options) > 0 {
  1451. err = json.Unmarshal(options, &action.Options)
  1452. if err != nil {
  1453. return action, err
  1454. }
  1455. }
  1456. return action, nil
  1457. }
  1458. func getEventRuleFromDbRow(row sqlScanner) (EventRule, error) {
  1459. var rule EventRule
  1460. var description sql.NullString
  1461. var conditions []byte
  1462. err := row.Scan(&rule.ID, &rule.Name, &description, &rule.CreatedAt, &rule.UpdatedAt, &rule.Trigger,
  1463. &conditions, &rule.DeletedAt)
  1464. if err != nil {
  1465. if errors.Is(err, sql.ErrNoRows) {
  1466. return rule, util.NewRecordNotFoundError(err.Error())
  1467. }
  1468. return rule, err
  1469. }
  1470. if len(conditions) > 0 {
  1471. err = json.Unmarshal(conditions, &rule.Conditions)
  1472. if err != nil {
  1473. return rule, err
  1474. }
  1475. }
  1476. if description.Valid {
  1477. rule.Description = description.String
  1478. }
  1479. return rule, nil
  1480. }
  1481. func getGroupFromDbRow(row sqlScanner) (Group, error) {
  1482. var group Group
  1483. var userSettings, description sql.NullString
  1484. err := row.Scan(&group.ID, &group.Name, &description, &group.CreatedAt, &group.UpdatedAt, &userSettings)
  1485. if err != nil {
  1486. if errors.Is(err, sql.ErrNoRows) {
  1487. return group, util.NewRecordNotFoundError(err.Error())
  1488. }
  1489. return group, err
  1490. }
  1491. if description.Valid {
  1492. group.Description = description.String
  1493. }
  1494. if userSettings.Valid {
  1495. var settings GroupUserSettings
  1496. err = json.Unmarshal([]byte(userSettings.String), &settings)
  1497. if err == nil {
  1498. group.UserSettings = settings
  1499. }
  1500. }
  1501. return group, nil
  1502. }
  1503. func getUserFromDbRow(row sqlScanner) (User, error) {
  1504. var user User
  1505. var permissions sql.NullString
  1506. var password sql.NullString
  1507. var publicKey sql.NullString
  1508. var filters sql.NullString
  1509. var fsConfig sql.NullString
  1510. var additionalInfo, description, email sql.NullString
  1511. err := row.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions,
  1512. &user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
  1513. &user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig,
  1514. &additionalInfo, &description, &email, &user.CreatedAt, &user.UpdatedAt, &user.UploadDataTransfer, &user.DownloadDataTransfer,
  1515. &user.TotalDataTransfer, &user.UsedUploadDataTransfer, &user.UsedDownloadDataTransfer, &user.DeletedAt)
  1516. if err != nil {
  1517. if errors.Is(err, sql.ErrNoRows) {
  1518. return user, util.NewRecordNotFoundError(err.Error())
  1519. }
  1520. return user, err
  1521. }
  1522. if password.Valid {
  1523. user.Password = password.String
  1524. }
  1525. // we can have a empty string or an invalid json in null string
  1526. // so we do a relaxed test if the field is optional, for example we
  1527. // populate public keys only if unmarshal does not return an error
  1528. if publicKey.Valid {
  1529. var list []string
  1530. err = json.Unmarshal([]byte(publicKey.String), &list)
  1531. if err == nil {
  1532. user.PublicKeys = list
  1533. }
  1534. }
  1535. if permissions.Valid {
  1536. perms := make(map[string][]string)
  1537. err = json.Unmarshal([]byte(permissions.String), &perms)
  1538. if err != nil {
  1539. providerLog(logger.LevelError, "unable to deserialize permissions for user %#v: %v", user.Username, err)
  1540. return user, fmt.Errorf("unable to deserialize permissions for user %#v: %v", user.Username, err)
  1541. }
  1542. user.Permissions = perms
  1543. }
  1544. if filters.Valid {
  1545. var userFilters UserFilters
  1546. err = json.Unmarshal([]byte(filters.String), &userFilters)
  1547. if err == nil {
  1548. user.Filters = userFilters
  1549. }
  1550. }
  1551. if fsConfig.Valid {
  1552. var fs vfs.Filesystem
  1553. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  1554. if err == nil {
  1555. user.FsConfig = fs
  1556. }
  1557. }
  1558. if additionalInfo.Valid {
  1559. user.AdditionalInfo = additionalInfo.String
  1560. }
  1561. if description.Valid {
  1562. user.Description = description.String
  1563. }
  1564. if email.Valid {
  1565. user.Email = email.String
  1566. }
  1567. user.SetEmptySecretsIfNil()
  1568. return user, nil
  1569. }
  1570. func sqlCommonGetFolder(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
  1571. var folder vfs.BaseVirtualFolder
  1572. q := getFolderByNameQuery()
  1573. row := dbHandle.QueryRowContext(ctx, q, name)
  1574. var mappedPath, description, fsConfig sql.NullString
  1575. err := row.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, &folder.LastQuotaUpdate,
  1576. &folder.Name, &description, &fsConfig)
  1577. if err != nil {
  1578. if errors.Is(err, sql.ErrNoRows) {
  1579. return folder, util.NewRecordNotFoundError(err.Error())
  1580. }
  1581. return folder, err
  1582. }
  1583. if mappedPath.Valid {
  1584. folder.MappedPath = mappedPath.String
  1585. }
  1586. if description.Valid {
  1587. folder.Description = description.String
  1588. }
  1589. if fsConfig.Valid {
  1590. var fs vfs.Filesystem
  1591. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  1592. if err == nil {
  1593. folder.FsConfig = fs
  1594. }
  1595. }
  1596. return folder, err
  1597. }
  1598. func sqlCommonGetFolderByName(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
  1599. folder, err := sqlCommonGetFolder(ctx, name, dbHandle)
  1600. if err != nil {
  1601. return folder, err
  1602. }
  1603. folders, err := getVirtualFoldersWithUsers([]vfs.BaseVirtualFolder{folder}, dbHandle)
  1604. if err != nil {
  1605. return folder, err
  1606. }
  1607. if len(folders) != 1 {
  1608. return folder, fmt.Errorf("unable to associate users with folder %#v", name)
  1609. }
  1610. folders, err = getVirtualFoldersWithGroups([]vfs.BaseVirtualFolder{folders[0]}, dbHandle)
  1611. if err != nil {
  1612. return folder, err
  1613. }
  1614. if len(folders) != 1 {
  1615. return folder, fmt.Errorf("unable to associate groups with folder %#v", name)
  1616. }
  1617. return folders[0], nil
  1618. }
  1619. func sqlCommonAddOrUpdateFolder(ctx context.Context, baseFolder *vfs.BaseVirtualFolder, usedQuotaSize int64,
  1620. usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier,
  1621. ) error {
  1622. fsConfig, err := json.Marshal(baseFolder.FsConfig)
  1623. if err != nil {
  1624. return err
  1625. }
  1626. q := getUpsertFolderQuery()
  1627. _, err = dbHandle.ExecContext(ctx, q, baseFolder.MappedPath, usedQuotaSize, usedQuotaFiles,
  1628. lastQuotaUpdate, baseFolder.Name, baseFolder.Description, string(fsConfig))
  1629. return err
  1630. }
  1631. func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
  1632. err := ValidateFolder(folder)
  1633. if err != nil {
  1634. return err
  1635. }
  1636. fsConfig, err := json.Marshal(folder.FsConfig)
  1637. if err != nil {
  1638. return err
  1639. }
  1640. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1641. defer cancel()
  1642. q := getAddFolderQuery()
  1643. _, err = dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.UsedQuotaSize, folder.UsedQuotaFiles,
  1644. folder.LastQuotaUpdate, folder.Name, folder.Description, string(fsConfig))
  1645. return err
  1646. }
  1647. func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
  1648. err := ValidateFolder(folder)
  1649. if err != nil {
  1650. return err
  1651. }
  1652. fsConfig, err := json.Marshal(folder.FsConfig)
  1653. if err != nil {
  1654. return err
  1655. }
  1656. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1657. defer cancel()
  1658. q := getUpdateFolderQuery()
  1659. _, err = dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.Description, string(fsConfig), folder.Name)
  1660. return err
  1661. }
  1662. func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
  1663. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1664. defer cancel()
  1665. q := getDeleteFolderQuery()
  1666. res, err := dbHandle.ExecContext(ctx, q, folder.ID)
  1667. if err != nil {
  1668. return err
  1669. }
  1670. return sqlCommonRequireRowAffected(res)
  1671. }
  1672. func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  1673. folders := make([]vfs.BaseVirtualFolder, 0, 50)
  1674. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  1675. defer cancel()
  1676. q := getDumpFoldersQuery()
  1677. rows, err := dbHandle.QueryContext(ctx, q)
  1678. if err != nil {
  1679. return folders, err
  1680. }
  1681. defer rows.Close()
  1682. for rows.Next() {
  1683. var folder vfs.BaseVirtualFolder
  1684. var mappedPath, description, fsConfig sql.NullString
  1685. err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
  1686. &folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
  1687. if err != nil {
  1688. return folders, err
  1689. }
  1690. if mappedPath.Valid {
  1691. folder.MappedPath = mappedPath.String
  1692. }
  1693. if description.Valid {
  1694. folder.Description = description.String
  1695. }
  1696. if fsConfig.Valid {
  1697. var fs vfs.Filesystem
  1698. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  1699. if err == nil {
  1700. folder.FsConfig = fs
  1701. }
  1702. }
  1703. folders = append(folders, folder)
  1704. }
  1705. return folders, rows.Err()
  1706. }
  1707. func sqlCommonGetFolders(limit, offset int, order string, minimal bool, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  1708. folders := make([]vfs.BaseVirtualFolder, 0, limit)
  1709. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1710. defer cancel()
  1711. q := getFoldersQuery(order, minimal)
  1712. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  1713. if err != nil {
  1714. return folders, err
  1715. }
  1716. defer rows.Close()
  1717. for rows.Next() {
  1718. var folder vfs.BaseVirtualFolder
  1719. if minimal {
  1720. err = rows.Scan(&folder.ID, &folder.Name)
  1721. if err != nil {
  1722. return folders, err
  1723. }
  1724. } else {
  1725. var mappedPath, description, fsConfig sql.NullString
  1726. err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
  1727. &folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
  1728. if err != nil {
  1729. return folders, err
  1730. }
  1731. if mappedPath.Valid {
  1732. folder.MappedPath = mappedPath.String
  1733. }
  1734. if description.Valid {
  1735. folder.Description = description.String
  1736. }
  1737. if fsConfig.Valid {
  1738. var fs vfs.Filesystem
  1739. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  1740. if err == nil {
  1741. folder.FsConfig = fs
  1742. }
  1743. }
  1744. }
  1745. folder.PrepareForRendering()
  1746. folders = append(folders, folder)
  1747. }
  1748. err = rows.Err()
  1749. if err != nil {
  1750. return folders, err
  1751. }
  1752. if minimal {
  1753. return folders, nil
  1754. }
  1755. folders, err = getVirtualFoldersWithUsers(folders, dbHandle)
  1756. if err != nil {
  1757. return folders, err
  1758. }
  1759. return getVirtualFoldersWithGroups(folders, dbHandle)
  1760. }
  1761. func sqlCommonClearUserFolderMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
  1762. q := getClearUserFolderMappingQuery()
  1763. _, err := dbHandle.ExecContext(ctx, q, user.Username)
  1764. return err
  1765. }
  1766. func sqlCommonClearGroupFolderMapping(ctx context.Context, group *Group, dbHandle sqlQuerier) error {
  1767. q := getClearGroupFolderMappingQuery()
  1768. _, err := dbHandle.ExecContext(ctx, q, group.Name)
  1769. return err
  1770. }
  1771. func sqlCommonClearUserGroupMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
  1772. q := getClearUserGroupMappingQuery()
  1773. _, err := dbHandle.ExecContext(ctx, q, user.Username)
  1774. return err
  1775. }
  1776. func sqlCommonAddUserFolderMapping(ctx context.Context, user *User, folder *vfs.VirtualFolder, dbHandle sqlQuerier) error {
  1777. q := getAddUserFolderMappingQuery()
  1778. _, err := dbHandle.ExecContext(ctx, q, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name, user.Username)
  1779. return err
  1780. }
  1781. func sqlCommonAddGroupFolderMapping(ctx context.Context, group *Group, folder *vfs.VirtualFolder, dbHandle sqlQuerier) error {
  1782. q := getAddGroupFolderMappingQuery()
  1783. _, err := dbHandle.ExecContext(ctx, q, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name, group.Name)
  1784. return err
  1785. }
  1786. func sqlCommonAddUserGroupMapping(ctx context.Context, username, groupName string, groupType int, dbHandle sqlQuerier) error {
  1787. q := getAddUserGroupMappingQuery()
  1788. _, err := dbHandle.ExecContext(ctx, q, username, groupName, groupType)
  1789. return err
  1790. }
  1791. func generateGroupVirtualFoldersMapping(ctx context.Context, group *Group, dbHandle sqlQuerier) error {
  1792. err := sqlCommonClearGroupFolderMapping(ctx, group, dbHandle)
  1793. if err != nil {
  1794. return err
  1795. }
  1796. for idx := range group.VirtualFolders {
  1797. vfolder := &group.VirtualFolders[idx]
  1798. err = sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle)
  1799. if err != nil {
  1800. return err
  1801. }
  1802. err = sqlCommonAddGroupFolderMapping(ctx, group, vfolder, dbHandle)
  1803. if err != nil {
  1804. return err
  1805. }
  1806. }
  1807. return err
  1808. }
  1809. func generateUserVirtualFoldersMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
  1810. err := sqlCommonClearUserFolderMapping(ctx, user, dbHandle)
  1811. if err != nil {
  1812. return err
  1813. }
  1814. for idx := range user.VirtualFolders {
  1815. vfolder := &user.VirtualFolders[idx]
  1816. err := sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle)
  1817. if err != nil {
  1818. return err
  1819. }
  1820. err = sqlCommonAddUserFolderMapping(ctx, user, vfolder, dbHandle)
  1821. if err != nil {
  1822. return err
  1823. }
  1824. }
  1825. return err
  1826. }
  1827. func generateUserGroupMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
  1828. err := sqlCommonClearUserGroupMapping(ctx, user, dbHandle)
  1829. if err != nil {
  1830. return err
  1831. }
  1832. for _, group := range user.Groups {
  1833. err = sqlCommonAddUserGroupMapping(ctx, user.Username, group.Name, group.Type, dbHandle)
  1834. if err != nil {
  1835. return err
  1836. }
  1837. }
  1838. return err
  1839. }
  1840. func getDefenderHostsWithScores(ctx context.Context, hosts []DefenderEntry, from int64, idForScores []int64,
  1841. dbHandle sqlQuerier) (
  1842. []DefenderEntry,
  1843. error,
  1844. ) {
  1845. if len(idForScores) == 0 {
  1846. return hosts, nil
  1847. }
  1848. hostsWithScores := make(map[int64]int)
  1849. q := getDefenderEventsQuery(idForScores)
  1850. rows, err := dbHandle.QueryContext(ctx, q, from)
  1851. if err != nil {
  1852. providerLog(logger.LevelError, "unable to get score for hosts with id %+v: %v", idForScores, err)
  1853. return nil, err
  1854. }
  1855. defer rows.Close()
  1856. for rows.Next() {
  1857. var hostID int64
  1858. var score int
  1859. err = rows.Scan(&hostID, &score)
  1860. if err != nil {
  1861. providerLog(logger.LevelError, "error scanning host score row: %v", err)
  1862. return hosts, err
  1863. }
  1864. if score > 0 {
  1865. hostsWithScores[hostID] = score
  1866. }
  1867. }
  1868. err = rows.Err()
  1869. if err != nil {
  1870. return hosts, err
  1871. }
  1872. result := make([]DefenderEntry, 0, len(hosts))
  1873. for idx := range hosts {
  1874. hosts[idx].Score = hostsWithScores[hosts[idx].ID]
  1875. if hosts[idx].Score > 0 || !hosts[idx].BanTime.IsZero() {
  1876. result = append(result, hosts[idx])
  1877. }
  1878. }
  1879. return result, nil
  1880. }
  1881. func getUserWithVirtualFolders(ctx context.Context, user User, dbHandle sqlQuerier) (User, error) {
  1882. users, err := getUsersWithVirtualFolders(ctx, []User{user}, dbHandle)
  1883. if err != nil {
  1884. return user, err
  1885. }
  1886. if len(users) == 0 {
  1887. return user, errSQLFoldersAssociation
  1888. }
  1889. return users[0], err
  1890. }
  1891. func getUsersWithVirtualFolders(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
  1892. if len(users) == 0 {
  1893. return users, nil
  1894. }
  1895. usersVirtualFolders := make(map[int64][]vfs.VirtualFolder)
  1896. q := getRelatedFoldersForUsersQuery(users)
  1897. rows, err := dbHandle.QueryContext(ctx, q)
  1898. if err != nil {
  1899. return nil, err
  1900. }
  1901. defer rows.Close()
  1902. for rows.Next() {
  1903. var folder vfs.VirtualFolder
  1904. var userID int64
  1905. var mappedPath, fsConfig, description sql.NullString
  1906. err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
  1907. &folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &userID, &fsConfig,
  1908. &description)
  1909. if err != nil {
  1910. return users, err
  1911. }
  1912. if mappedPath.Valid {
  1913. folder.MappedPath = mappedPath.String
  1914. }
  1915. if description.Valid {
  1916. folder.Description = description.String
  1917. }
  1918. if fsConfig.Valid {
  1919. var fs vfs.Filesystem
  1920. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  1921. if err == nil {
  1922. folder.FsConfig = fs
  1923. }
  1924. }
  1925. usersVirtualFolders[userID] = append(usersVirtualFolders[userID], folder)
  1926. }
  1927. err = rows.Err()
  1928. if err != nil {
  1929. return users, err
  1930. }
  1931. if len(usersVirtualFolders) == 0 {
  1932. return users, err
  1933. }
  1934. for idx := range users {
  1935. ref := &users[idx]
  1936. ref.VirtualFolders = usersVirtualFolders[ref.ID]
  1937. }
  1938. return users, err
  1939. }
  1940. func getUserWithGroups(ctx context.Context, user User, dbHandle sqlQuerier) (User, error) {
  1941. users, err := getUsersWithGroups(ctx, []User{user}, dbHandle)
  1942. if err != nil {
  1943. return user, err
  1944. }
  1945. if len(users) == 0 {
  1946. return user, errSQLGroupsAssociation
  1947. }
  1948. return users[0], err
  1949. }
  1950. func getUsersWithGroups(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
  1951. if len(users) == 0 {
  1952. return users, nil
  1953. }
  1954. usersGroups := make(map[int64][]sdk.GroupMapping)
  1955. q := getRelatedGroupsForUsersQuery(users)
  1956. rows, err := dbHandle.QueryContext(ctx, q)
  1957. if err != nil {
  1958. return nil, err
  1959. }
  1960. defer rows.Close()
  1961. for rows.Next() {
  1962. var group sdk.GroupMapping
  1963. var userID int64
  1964. err = rows.Scan(&group.Name, &group.Type, &userID)
  1965. if err != nil {
  1966. return users, err
  1967. }
  1968. usersGroups[userID] = append(usersGroups[userID], group)
  1969. }
  1970. err = rows.Err()
  1971. if err != nil {
  1972. return users, err
  1973. }
  1974. if len(usersGroups) == 0 {
  1975. return users, err
  1976. }
  1977. for idx := range users {
  1978. ref := &users[idx]
  1979. ref.Groups = usersGroups[ref.ID]
  1980. }
  1981. return users, err
  1982. }
  1983. func getGroupWithUsers(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
  1984. groups, err := getGroupsWithUsers(ctx, []Group{group}, dbHandle)
  1985. if err != nil {
  1986. return group, err
  1987. }
  1988. if len(groups) == 0 {
  1989. return group, errSQLUsersAssociation
  1990. }
  1991. return groups[0], err
  1992. }
  1993. func getGroupWithVirtualFolders(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
  1994. groups, err := getGroupsWithVirtualFolders(ctx, []Group{group}, dbHandle)
  1995. if err != nil {
  1996. return group, err
  1997. }
  1998. if len(groups) == 0 {
  1999. return group, errSQLFoldersAssociation
  2000. }
  2001. return groups[0], err
  2002. }
  2003. func getGroupsWithVirtualFolders(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
  2004. if len(groups) == 0 {
  2005. return groups, nil
  2006. }
  2007. q := getRelatedFoldersForGroupsQuery(groups)
  2008. rows, err := dbHandle.QueryContext(ctx, q)
  2009. if err != nil {
  2010. return nil, err
  2011. }
  2012. defer rows.Close()
  2013. groupsVirtualFolders := make(map[int64][]vfs.VirtualFolder)
  2014. for rows.Next() {
  2015. var groupID int64
  2016. var folder vfs.VirtualFolder
  2017. var mappedPath, fsConfig, description sql.NullString
  2018. err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
  2019. &folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &groupID, &fsConfig,
  2020. &description)
  2021. if err != nil {
  2022. return groups, err
  2023. }
  2024. if mappedPath.Valid {
  2025. folder.MappedPath = mappedPath.String
  2026. }
  2027. if description.Valid {
  2028. folder.Description = description.String
  2029. }
  2030. if fsConfig.Valid {
  2031. var fs vfs.Filesystem
  2032. err = json.Unmarshal([]byte(fsConfig.String), &fs)
  2033. if err == nil {
  2034. folder.FsConfig = fs
  2035. }
  2036. }
  2037. groupsVirtualFolders[groupID] = append(groupsVirtualFolders[groupID], folder)
  2038. }
  2039. err = rows.Err()
  2040. if err != nil {
  2041. return groups, err
  2042. }
  2043. if len(groupsVirtualFolders) == 0 {
  2044. return groups, err
  2045. }
  2046. for idx := range groups {
  2047. ref := &groups[idx]
  2048. ref.VirtualFolders = groupsVirtualFolders[ref.ID]
  2049. }
  2050. return groups, err
  2051. }
  2052. func getGroupsWithUsers(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
  2053. if len(groups) == 0 {
  2054. return groups, nil
  2055. }
  2056. q := getRelatedUsersForGroupsQuery(groups)
  2057. rows, err := dbHandle.QueryContext(ctx, q)
  2058. if err != nil {
  2059. return nil, err
  2060. }
  2061. defer rows.Close()
  2062. groupsUsers := make(map[int64][]string)
  2063. for rows.Next() {
  2064. var username string
  2065. var groupID int64
  2066. err = rows.Scan(&groupID, &username)
  2067. if err != nil {
  2068. return groups, err
  2069. }
  2070. groupsUsers[groupID] = append(groupsUsers[groupID], username)
  2071. }
  2072. err = rows.Err()
  2073. if err != nil {
  2074. return groups, err
  2075. }
  2076. if len(groupsUsers) == 0 {
  2077. return groups, err
  2078. }
  2079. for idx := range groups {
  2080. ref := &groups[idx]
  2081. ref.Users = groupsUsers[ref.ID]
  2082. }
  2083. return groups, err
  2084. }
  2085. func getVirtualFoldersWithGroups(folders []vfs.BaseVirtualFolder, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  2086. if len(folders) == 0 {
  2087. return folders, nil
  2088. }
  2089. vFoldersGroups := make(map[int64][]string)
  2090. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2091. defer cancel()
  2092. q := getRelatedGroupsForFoldersQuery(folders)
  2093. rows, err := dbHandle.QueryContext(ctx, q)
  2094. if err != nil {
  2095. return nil, err
  2096. }
  2097. defer rows.Close()
  2098. for rows.Next() {
  2099. var name string
  2100. var folderID int64
  2101. err = rows.Scan(&folderID, &name)
  2102. if err != nil {
  2103. return folders, err
  2104. }
  2105. vFoldersGroups[folderID] = append(vFoldersGroups[folderID], name)
  2106. }
  2107. err = rows.Err()
  2108. if err != nil {
  2109. return folders, err
  2110. }
  2111. if len(vFoldersGroups) == 0 {
  2112. return folders, err
  2113. }
  2114. for idx := range folders {
  2115. ref := &folders[idx]
  2116. ref.Groups = vFoldersGroups[ref.ID]
  2117. }
  2118. return folders, err
  2119. }
  2120. func getVirtualFoldersWithUsers(folders []vfs.BaseVirtualFolder, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  2121. if len(folders) == 0 {
  2122. return folders, nil
  2123. }
  2124. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2125. defer cancel()
  2126. q := getRelatedUsersForFoldersQuery(folders)
  2127. rows, err := dbHandle.QueryContext(ctx, q)
  2128. if err != nil {
  2129. return nil, err
  2130. }
  2131. defer rows.Close()
  2132. vFoldersUsers := make(map[int64][]string)
  2133. for rows.Next() {
  2134. var username string
  2135. var folderID int64
  2136. err = rows.Scan(&folderID, &username)
  2137. if err != nil {
  2138. return folders, err
  2139. }
  2140. vFoldersUsers[folderID] = append(vFoldersUsers[folderID], username)
  2141. }
  2142. err = rows.Err()
  2143. if err != nil {
  2144. return folders, err
  2145. }
  2146. if len(vFoldersUsers) == 0 {
  2147. return folders, err
  2148. }
  2149. for idx := range folders {
  2150. ref := &folders[idx]
  2151. ref.Users = vFoldersUsers[ref.ID]
  2152. }
  2153. return folders, err
  2154. }
  2155. func sqlCommonUpdateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
  2156. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2157. defer cancel()
  2158. q := getUpdateFolderQuotaQuery(reset)
  2159. _, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), name)
  2160. if err == nil {
  2161. providerLog(logger.LevelDebug, "quota updated for folder %#v, files increment: %v size increment: %v is reset? %v",
  2162. name, filesAdd, sizeAdd, reset)
  2163. } else {
  2164. providerLog(logger.LevelWarn, "error updating quota for folder %#v: %v", name, err)
  2165. }
  2166. return err
  2167. }
  2168. func sqlCommonGetFolderUsedQuota(mappedPath string, dbHandle *sql.DB) (int, int64, error) {
  2169. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2170. defer cancel()
  2171. q := getQuotaFolderQuery()
  2172. var usedFiles int
  2173. var usedSize int64
  2174. err := dbHandle.QueryRowContext(ctx, q, mappedPath).Scan(&usedSize, &usedFiles)
  2175. if err != nil {
  2176. providerLog(logger.LevelError, "error getting quota for folder: %v, error: %v", mappedPath, err)
  2177. return 0, 0, err
  2178. }
  2179. return usedFiles, usedSize, err
  2180. }
  2181. func getAPIKeyWithRelatedFields(ctx context.Context, apiKey APIKey, dbHandle sqlQuerier) (APIKey, error) {
  2182. var apiKeys []APIKey
  2183. var err error
  2184. scope := APIKeyScopeAdmin
  2185. if apiKey.userID > 0 {
  2186. scope = APIKeyScopeUser
  2187. }
  2188. apiKeys, err = getRelatedValuesForAPIKeys(ctx, []APIKey{apiKey}, dbHandle, scope)
  2189. if err != nil {
  2190. return apiKey, err
  2191. }
  2192. if len(apiKeys) > 0 {
  2193. apiKey = apiKeys[0]
  2194. }
  2195. return apiKey, nil
  2196. }
  2197. func getRelatedValuesForAPIKeys(ctx context.Context, apiKeys []APIKey, dbHandle sqlQuerier, scope APIKeyScope) ([]APIKey, error) {
  2198. if len(apiKeys) == 0 {
  2199. return apiKeys, nil
  2200. }
  2201. values := make(map[int64]string)
  2202. var q string
  2203. if scope == APIKeyScopeUser {
  2204. q = getRelatedUsersForAPIKeysQuery(apiKeys)
  2205. } else {
  2206. q = getRelatedAdminsForAPIKeysQuery(apiKeys)
  2207. }
  2208. rows, err := dbHandle.QueryContext(ctx, q)
  2209. if err != nil {
  2210. return nil, err
  2211. }
  2212. defer rows.Close()
  2213. for rows.Next() {
  2214. var valueID int64
  2215. var valueName string
  2216. err = rows.Scan(&valueID, &valueName)
  2217. if err != nil {
  2218. return apiKeys, err
  2219. }
  2220. values[valueID] = valueName
  2221. }
  2222. err = rows.Err()
  2223. if err != nil {
  2224. return apiKeys, err
  2225. }
  2226. if len(values) == 0 {
  2227. return apiKeys, nil
  2228. }
  2229. for idx := range apiKeys {
  2230. ref := &apiKeys[idx]
  2231. if scope == APIKeyScopeUser {
  2232. ref.User = values[ref.userID]
  2233. } else {
  2234. ref.Admin = values[ref.adminID]
  2235. }
  2236. }
  2237. return apiKeys, nil
  2238. }
  2239. func sqlCommonGetAPIKeyRelatedIDs(apiKey *APIKey) (sql.NullInt64, sql.NullInt64, error) {
  2240. var userID, adminID sql.NullInt64
  2241. if apiKey.User != "" {
  2242. u, err := provider.userExists(apiKey.User)
  2243. if err != nil {
  2244. return userID, adminID, util.NewValidationError(fmt.Sprintf("unable to validate user %v", apiKey.User))
  2245. }
  2246. userID.Valid = true
  2247. userID.Int64 = u.ID
  2248. }
  2249. if apiKey.Admin != "" {
  2250. a, err := provider.adminExists(apiKey.Admin)
  2251. if err != nil {
  2252. return userID, adminID, util.NewValidationError(fmt.Sprintf("unable to validate admin %v", apiKey.Admin))
  2253. }
  2254. adminID.Valid = true
  2255. adminID.Int64 = a.ID
  2256. }
  2257. return userID, adminID, nil
  2258. }
  2259. func sqlCommonAddSession(session Session, dbHandle *sql.DB) error {
  2260. if err := session.validate(); err != nil {
  2261. return err
  2262. }
  2263. data, err := json.Marshal(session.Data)
  2264. if err != nil {
  2265. return err
  2266. }
  2267. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2268. defer cancel()
  2269. q := getAddSessionQuery()
  2270. _, err = dbHandle.ExecContext(ctx, q, session.Key, data, session.Type, session.Timestamp)
  2271. return err
  2272. }
  2273. func sqlCommonGetSession(key string, dbHandle sqlQuerier) (Session, error) {
  2274. var session Session
  2275. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2276. defer cancel()
  2277. q := getSessionQuery()
  2278. var data []byte // type hint, some driver will use string instead of []byte if the type is any
  2279. err := dbHandle.QueryRowContext(ctx, q, key).Scan(&session.Key, &data, &session.Type, &session.Timestamp)
  2280. if err != nil {
  2281. return session, err
  2282. }
  2283. session.Data = data
  2284. return session, nil
  2285. }
  2286. func sqlCommonDeleteSession(key string, dbHandle *sql.DB) error {
  2287. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2288. defer cancel()
  2289. q := getDeleteSessionQuery()
  2290. res, err := dbHandle.ExecContext(ctx, q, key)
  2291. if err != nil {
  2292. return err
  2293. }
  2294. return sqlCommonRequireRowAffected(res)
  2295. }
  2296. func sqlCommonCleanupSessions(sessionType SessionType, before int64, dbHandle *sql.DB) error {
  2297. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2298. defer cancel()
  2299. q := getCleanupSessionsQuery()
  2300. _, err := dbHandle.ExecContext(ctx, q, sessionType, before)
  2301. return err
  2302. }
  2303. func getActionsWithRuleNames(ctx context.Context, actions []BaseEventAction, dbHandle sqlQuerier,
  2304. ) ([]BaseEventAction, error) {
  2305. if len(actions) == 0 {
  2306. return actions, nil
  2307. }
  2308. q := getRelatedRulesForActionsQuery(actions)
  2309. rows, err := dbHandle.QueryContext(ctx, q)
  2310. if err != nil {
  2311. return nil, err
  2312. }
  2313. defer rows.Close()
  2314. actionsRules := make(map[int64][]string)
  2315. for rows.Next() {
  2316. var name string
  2317. var actionID int64
  2318. if err = rows.Scan(&actionID, &name); err != nil {
  2319. return nil, err
  2320. }
  2321. actionsRules[actionID] = append(actionsRules[actionID], name)
  2322. }
  2323. err = rows.Err()
  2324. if err != nil {
  2325. return nil, err
  2326. }
  2327. if len(actionsRules) == 0 {
  2328. return actions, nil
  2329. }
  2330. for idx := range actions {
  2331. ref := &actions[idx]
  2332. ref.Rules = actionsRules[ref.ID]
  2333. }
  2334. return actions, nil
  2335. }
  2336. func getRulesWithActions(ctx context.Context, rules []EventRule, dbHandle sqlQuerier) ([]EventRule, error) {
  2337. if len(rules) == 0 {
  2338. return rules, nil
  2339. }
  2340. rulesActions := make(map[int64][]EventAction)
  2341. q := getRelatedActionsForRulesQuery(rules)
  2342. rows, err := dbHandle.QueryContext(ctx, q)
  2343. if err != nil {
  2344. return nil, err
  2345. }
  2346. defer rows.Close()
  2347. for rows.Next() {
  2348. var action EventAction
  2349. var ruleID int64
  2350. var description sql.NullString
  2351. var baseOptions, options []byte
  2352. err = rows.Scan(&action.ID, &action.Name, &description, &action.Type, &baseOptions, &options,
  2353. &action.Order, &ruleID)
  2354. if err != nil {
  2355. return rules, err
  2356. }
  2357. if len(baseOptions) > 0 {
  2358. err = json.Unmarshal(baseOptions, &action.BaseEventAction.Options)
  2359. if err != nil {
  2360. return rules, err
  2361. }
  2362. }
  2363. if len(options) > 0 {
  2364. err = json.Unmarshal(options, &action.Options)
  2365. if err != nil {
  2366. return rules, err
  2367. }
  2368. }
  2369. action.BaseEventAction.Options.SetEmptySecretsIfNil()
  2370. rulesActions[ruleID] = append(rulesActions[ruleID], action)
  2371. }
  2372. err = rows.Err()
  2373. if err != nil {
  2374. return rules, err
  2375. }
  2376. if len(rulesActions) == 0 {
  2377. return rules, nil
  2378. }
  2379. for idx := range rules {
  2380. ref := &rules[idx]
  2381. ref.Actions = rulesActions[ref.ID]
  2382. }
  2383. return rules, nil
  2384. }
  2385. func generateEventRuleActionsMapping(ctx context.Context, rule *EventRule, dbHandle sqlQuerier) error {
  2386. q := getClearRuleActionMappingQuery()
  2387. _, err := dbHandle.ExecContext(ctx, q, rule.Name)
  2388. if err != nil {
  2389. return err
  2390. }
  2391. for _, action := range rule.Actions {
  2392. options, err := json.Marshal(action.Options)
  2393. if err != nil {
  2394. return err
  2395. }
  2396. q = getAddRuleActionMappingQuery()
  2397. _, err = dbHandle.ExecContext(ctx, q, rule.Name, action.Name, action.Order, string(options))
  2398. if err != nil {
  2399. return err
  2400. }
  2401. }
  2402. return nil
  2403. }
  2404. func sqlCommonGetEventActionByName(name string, dbHandle sqlQuerier) (BaseEventAction, error) {
  2405. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2406. defer cancel()
  2407. q := getEventActionByNameQuery()
  2408. row := dbHandle.QueryRowContext(ctx, q, name)
  2409. action, err := getEventActionFromDbRow(row)
  2410. if err != nil {
  2411. return action, err
  2412. }
  2413. actions, err := getActionsWithRuleNames(ctx, []BaseEventAction{action}, dbHandle)
  2414. if err != nil {
  2415. return action, err
  2416. }
  2417. if len(actions) != 1 {
  2418. return action, fmt.Errorf("unable to associate rules with action %q", name)
  2419. }
  2420. return actions[0], nil
  2421. }
  2422. func sqlCommonDumpEventActions(dbHandle sqlQuerier) ([]BaseEventAction, error) {
  2423. actions := make([]BaseEventAction, 0, 10)
  2424. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  2425. defer cancel()
  2426. q := getDumpEventActionsQuery()
  2427. rows, err := dbHandle.QueryContext(ctx, q)
  2428. if err != nil {
  2429. return actions, err
  2430. }
  2431. defer rows.Close()
  2432. for rows.Next() {
  2433. action, err := getEventActionFromDbRow(rows)
  2434. if err != nil {
  2435. return actions, err
  2436. }
  2437. actions = append(actions, action)
  2438. }
  2439. return actions, rows.Err()
  2440. }
  2441. func sqlCommonGetEventActions(limit int, offset int, order string, minimal bool,
  2442. dbHandle sqlQuerier,
  2443. ) ([]BaseEventAction, error) {
  2444. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2445. defer cancel()
  2446. q := getEventsActionsQuery(order, minimal)
  2447. actions := make([]BaseEventAction, 0, limit)
  2448. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  2449. if err != nil {
  2450. return actions, err
  2451. }
  2452. defer rows.Close()
  2453. for rows.Next() {
  2454. var action BaseEventAction
  2455. if minimal {
  2456. err = rows.Scan(&action.ID, &action.Name)
  2457. } else {
  2458. action, err = getEventActionFromDbRow(rows)
  2459. }
  2460. if err != nil {
  2461. return actions, err
  2462. }
  2463. actions = append(actions, action)
  2464. }
  2465. err = rows.Err()
  2466. if err != nil {
  2467. return nil, err
  2468. }
  2469. if minimal {
  2470. return actions, nil
  2471. }
  2472. actions, err = getActionsWithRuleNames(ctx, actions, dbHandle)
  2473. if err != nil {
  2474. return nil, err
  2475. }
  2476. for idx := range actions {
  2477. actions[idx].PrepareForRendering()
  2478. }
  2479. return actions, nil
  2480. }
  2481. func sqlCommonAddEventAction(action *BaseEventAction, dbHandle *sql.DB) error {
  2482. if err := action.validate(); err != nil {
  2483. return err
  2484. }
  2485. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2486. defer cancel()
  2487. q := getAddEventActionQuery()
  2488. options, err := json.Marshal(action.Options)
  2489. if err != nil {
  2490. return err
  2491. }
  2492. _, err = dbHandle.ExecContext(ctx, q, action.Name, action.Description, action.Type, string(options))
  2493. return err
  2494. }
  2495. func sqlCommonUpdateEventAction(action *BaseEventAction, dbHandle *sql.DB) error {
  2496. if err := action.validate(); err != nil {
  2497. return err
  2498. }
  2499. options, err := json.Marshal(action.Options)
  2500. if err != nil {
  2501. return err
  2502. }
  2503. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2504. defer cancel()
  2505. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  2506. q := getUpdateEventActionQuery()
  2507. _, err = tx.ExecContext(ctx, q, action.Description, action.Type, string(options), action.Name)
  2508. if err != nil {
  2509. return err
  2510. }
  2511. q = getUpdateRulesTimestampQuery()
  2512. _, err = tx.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), action.ID)
  2513. return err
  2514. })
  2515. }
  2516. func sqlCommonDeleteEventAction(action BaseEventAction, dbHandle *sql.DB) error {
  2517. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2518. defer cancel()
  2519. q := getDeleteEventActionQuery()
  2520. res, err := dbHandle.ExecContext(ctx, q, action.Name)
  2521. if err != nil {
  2522. return err
  2523. }
  2524. return sqlCommonRequireRowAffected(res)
  2525. }
  2526. func sqlCommonGetEventRuleByName(name string, dbHandle sqlQuerier) (EventRule, error) {
  2527. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2528. defer cancel()
  2529. q := getEventRulesByNameQuery()
  2530. row := dbHandle.QueryRowContext(ctx, q, name)
  2531. rule, err := getEventRuleFromDbRow(row)
  2532. if err != nil {
  2533. return rule, err
  2534. }
  2535. rules, err := getRulesWithActions(ctx, []EventRule{rule}, dbHandle)
  2536. if err != nil {
  2537. return rule, err
  2538. }
  2539. if len(rules) != 1 {
  2540. return rule, fmt.Errorf("unable to associate rule %q with actions", name)
  2541. }
  2542. return rules[0], nil
  2543. }
  2544. func sqlCommonDumpEventRules(dbHandle sqlQuerier) ([]EventRule, error) {
  2545. rules := make([]EventRule, 0, 10)
  2546. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  2547. defer cancel()
  2548. q := getDumpEventRulesQuery()
  2549. rows, err := dbHandle.QueryContext(ctx, q)
  2550. if err != nil {
  2551. return rules, err
  2552. }
  2553. defer rows.Close()
  2554. for rows.Next() {
  2555. rule, err := getEventRuleFromDbRow(rows)
  2556. if err != nil {
  2557. return rules, err
  2558. }
  2559. rules = append(rules, rule)
  2560. }
  2561. err = rows.Err()
  2562. if err != nil {
  2563. return rules, err
  2564. }
  2565. return getRulesWithActions(ctx, rules, dbHandle)
  2566. }
  2567. func sqlCommonGetRecentlyUpdatedRules(after int64, dbHandle sqlQuerier) ([]EventRule, error) {
  2568. rules := make([]EventRule, 0, 10)
  2569. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  2570. defer cancel()
  2571. q := getRecentlyUpdatedRulesQuery()
  2572. rows, err := dbHandle.QueryContext(ctx, q, after)
  2573. if err != nil {
  2574. return rules, err
  2575. }
  2576. defer rows.Close()
  2577. for rows.Next() {
  2578. rule, err := getEventRuleFromDbRow(rows)
  2579. if err != nil {
  2580. return rules, err
  2581. }
  2582. rules = append(rules, rule)
  2583. }
  2584. err = rows.Err()
  2585. if err != nil {
  2586. return rules, err
  2587. }
  2588. return getRulesWithActions(ctx, rules, dbHandle)
  2589. }
  2590. func sqlCommonGetEventRules(limit int, offset int, order string, dbHandle sqlQuerier) ([]EventRule, error) {
  2591. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2592. defer cancel()
  2593. q := getEventRulesQuery(order)
  2594. rules := make([]EventRule, 0, limit)
  2595. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  2596. if err != nil {
  2597. return rules, err
  2598. }
  2599. defer rows.Close()
  2600. for rows.Next() {
  2601. rule, err := getEventRuleFromDbRow(rows)
  2602. if err != nil {
  2603. return rules, err
  2604. }
  2605. rules = append(rules, rule)
  2606. }
  2607. err = rows.Err()
  2608. if err != nil {
  2609. return rules, err
  2610. }
  2611. rules, err = getRulesWithActions(ctx, rules, dbHandle)
  2612. if err != nil {
  2613. return rules, err
  2614. }
  2615. for idx := range rules {
  2616. rules[idx].PrepareForRendering()
  2617. }
  2618. return rules, nil
  2619. }
  2620. func sqlCommonAddEventRule(rule *EventRule, dbHandle *sql.DB) error {
  2621. if err := rule.validate(); err != nil {
  2622. return err
  2623. }
  2624. conditions, err := json.Marshal(rule.Conditions)
  2625. if err != nil {
  2626. return err
  2627. }
  2628. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2629. defer cancel()
  2630. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  2631. q := getAddEventRuleQuery()
  2632. _, err := tx.ExecContext(ctx, q, rule.Name, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
  2633. util.GetTimeAsMsSinceEpoch(time.Now()), rule.Trigger, string(conditions))
  2634. if err != nil {
  2635. return err
  2636. }
  2637. return generateEventRuleActionsMapping(ctx, rule, tx)
  2638. })
  2639. }
  2640. func sqlCommonUpdateEventRule(rule *EventRule, dbHandle *sql.DB) error {
  2641. if err := rule.validate(); err != nil {
  2642. return err
  2643. }
  2644. conditions, err := json.Marshal(rule.Conditions)
  2645. if err != nil {
  2646. return err
  2647. }
  2648. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2649. defer cancel()
  2650. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  2651. q := getUpdateEventRuleQuery()
  2652. _, err := tx.ExecContext(ctx, q, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
  2653. rule.Trigger, string(conditions), rule.Name)
  2654. if err != nil {
  2655. return err
  2656. }
  2657. return generateEventRuleActionsMapping(ctx, rule, tx)
  2658. })
  2659. }
  2660. func sqlCommonDeleteEventRule(rule EventRule, softDelete bool, dbHandle *sql.DB) error {
  2661. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2662. defer cancel()
  2663. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  2664. if softDelete {
  2665. q := getClearRuleActionMappingQuery()
  2666. _, err := tx.ExecContext(ctx, q, rule.Name)
  2667. if err != nil {
  2668. return err
  2669. }
  2670. }
  2671. q := getDeleteEventRuleQuery(softDelete)
  2672. if softDelete {
  2673. ts := util.GetTimeAsMsSinceEpoch(time.Now())
  2674. res, err := tx.ExecContext(ctx, q, ts, ts, rule.Name)
  2675. if err != nil {
  2676. return err
  2677. }
  2678. return sqlCommonRequireRowAffected(res)
  2679. }
  2680. res, err := tx.ExecContext(ctx, q, rule.Name)
  2681. if err != nil {
  2682. return err
  2683. }
  2684. if err = sqlCommonRequireRowAffected(res); err != nil {
  2685. return err
  2686. }
  2687. return sqlCommonDeleteTask(rule.Name, tx)
  2688. })
  2689. }
  2690. func sqlCommonGetTaskByName(name string, dbHandle sqlQuerier) (Task, error) {
  2691. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2692. defer cancel()
  2693. task := Task{
  2694. Name: name,
  2695. }
  2696. q := getTaskByNameQuery()
  2697. row := dbHandle.QueryRowContext(ctx, q, name)
  2698. err := row.Scan(&task.UpdateAt, &task.Version)
  2699. if err != nil {
  2700. if errors.Is(err, sql.ErrNoRows) {
  2701. return task, util.NewRecordNotFoundError(err.Error())
  2702. }
  2703. }
  2704. return task, err
  2705. }
  2706. func sqlCommonAddTask(name string, dbHandle *sql.DB) error {
  2707. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2708. defer cancel()
  2709. q := getAddTaskQuery()
  2710. _, err := dbHandle.ExecContext(ctx, q, name, util.GetTimeAsMsSinceEpoch(time.Now()))
  2711. return err
  2712. }
  2713. func sqlCommonUpdateTask(name string, version int64, dbHandle *sql.DB) error {
  2714. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2715. defer cancel()
  2716. q := getUpdateTaskQuery()
  2717. res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), name, version)
  2718. if err != nil {
  2719. return err
  2720. }
  2721. return sqlCommonRequireRowAffected(res)
  2722. }
  2723. func sqlCommonUpdateTaskTimestamp(name string, dbHandle *sql.DB) error {
  2724. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2725. defer cancel()
  2726. q := getUpdateTaskTimestampQuery()
  2727. res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), name)
  2728. if err != nil {
  2729. return err
  2730. }
  2731. return sqlCommonRequireRowAffected(res)
  2732. }
  2733. func sqlCommonDeleteTask(name string, dbHandle sqlQuerier) error {
  2734. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2735. defer cancel()
  2736. q := getDeleteTaskQuery()
  2737. _, err := dbHandle.ExecContext(ctx, q, name)
  2738. return err
  2739. }
  2740. func sqlCommonGetDatabaseVersion(dbHandle sqlQuerier, showInitWarn bool) (schemaVersion, error) {
  2741. var result schemaVersion
  2742. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2743. defer cancel()
  2744. q := getDatabaseVersionQuery()
  2745. stmt, err := dbHandle.PrepareContext(ctx, q)
  2746. if err != nil {
  2747. providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
  2748. if showInitWarn && strings.Contains(err.Error(), sqlTableSchemaVersion) {
  2749. logger.WarnToConsole("database query error, did you forgot to run the \"initprovider\" command?")
  2750. }
  2751. return result, err
  2752. }
  2753. defer stmt.Close()
  2754. row := stmt.QueryRowContext(ctx)
  2755. err = row.Scan(&result.Version)
  2756. return result, err
  2757. }
  2758. func sqlCommonRequireRowAffected(res sql.Result) error {
  2759. // MariaDB/MySQL returns 0 rows affected for updates that don't change anything
  2760. // so we don't check rows affected for updates
  2761. affected, err := res.RowsAffected()
  2762. if err == nil && affected == 0 {
  2763. return util.NewRecordNotFoundError(sql.ErrNoRows.Error())
  2764. }
  2765. return nil
  2766. }
  2767. func sqlCommonUpdateDatabaseVersion(ctx context.Context, dbHandle sqlQuerier, version int) error {
  2768. q := getUpdateDBVersionQuery()
  2769. _, err := dbHandle.ExecContext(ctx, q, version)
  2770. return err
  2771. }
  2772. func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, newVersion int, isUp bool) error {
  2773. if err := sqlAcquireLock(dbHandle); err != nil {
  2774. return err
  2775. }
  2776. defer sqlReleaseLock(dbHandle)
  2777. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  2778. defer cancel()
  2779. if newVersion > 0 {
  2780. currentVersion, err := sqlCommonGetDatabaseVersion(dbHandle, false)
  2781. if err == nil {
  2782. if (isUp && currentVersion.Version >= newVersion) || (!isUp && currentVersion.Version <= newVersion) {
  2783. providerLog(logger.LevelInfo, "current schema version: %v, requested: %v, did you execute simultaneous migrations?",
  2784. currentVersion.Version, newVersion)
  2785. return nil
  2786. }
  2787. }
  2788. }
  2789. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  2790. for _, q := range sqlQueries {
  2791. if strings.TrimSpace(q) == "" {
  2792. continue
  2793. }
  2794. _, err := tx.ExecContext(ctx, q)
  2795. if err != nil {
  2796. return err
  2797. }
  2798. }
  2799. if newVersion == 0 {
  2800. return nil
  2801. }
  2802. return sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion)
  2803. })
  2804. }
  2805. func sqlAcquireLock(dbHandle *sql.DB) error {
  2806. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  2807. defer cancel()
  2808. switch config.Driver {
  2809. case PGSQLDataProviderName:
  2810. _, err := dbHandle.ExecContext(ctx, `SELECT pg_advisory_lock(101,1)`)
  2811. if err != nil {
  2812. return fmt.Errorf("unable to get advisory lock: %w", err)
  2813. }
  2814. providerLog(logger.LevelInfo, "acquired database lock")
  2815. case MySQLDataProviderName:
  2816. var lockResult sql.NullInt64
  2817. err := dbHandle.QueryRowContext(ctx, `SELECT GET_LOCK('sftpgo.migration',30)`).Scan(&lockResult)
  2818. if err != nil {
  2819. return fmt.Errorf("unable to get lock: %w", err)
  2820. }
  2821. if !lockResult.Valid {
  2822. return errors.New("unable to get lock: null value returned")
  2823. }
  2824. if lockResult.Int64 != 1 {
  2825. return fmt.Errorf("unable to get lock, result: %d", lockResult.Int64)
  2826. }
  2827. providerLog(logger.LevelInfo, "acquired database lock")
  2828. }
  2829. return nil
  2830. }
  2831. func sqlReleaseLock(dbHandle *sql.DB) {
  2832. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2833. defer cancel()
  2834. switch config.Driver {
  2835. case PGSQLDataProviderName:
  2836. _, err := dbHandle.ExecContext(ctx, `SELECT pg_advisory_unlock(101,1)`)
  2837. if err != nil {
  2838. providerLog(logger.LevelWarn, "unable to release lock: %v", err)
  2839. } else {
  2840. providerLog(logger.LevelInfo, "released database lock")
  2841. }
  2842. case MySQLDataProviderName:
  2843. _, err := dbHandle.ExecContext(ctx, `SELECT RELEASE_LOCK('sftpgo.migration')`)
  2844. if err != nil {
  2845. providerLog(logger.LevelWarn, "unable to release lock: %v", err)
  2846. } else {
  2847. providerLog(logger.LevelInfo, "released database lock")
  2848. }
  2849. }
  2850. }
  2851. func sqlCommonExecuteTx(ctx context.Context, dbHandle *sql.DB, txFn func(*sql.Tx) error) error {
  2852. if config.Driver == CockroachDataProviderName {
  2853. return crdb.ExecuteTx(ctx, dbHandle, nil, txFn)
  2854. }
  2855. tx, err := dbHandle.BeginTx(ctx, nil)
  2856. if err != nil {
  2857. return err
  2858. }
  2859. err = txFn(tx)
  2860. if err != nil {
  2861. // we don't change the returned error
  2862. tx.Rollback() //nolint:errcheck
  2863. return err
  2864. }
  2865. return tx.Commit()
  2866. }