transferschecker.go

// Copyright (C) 2019-2022 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package common

import (
	"errors"
	"sync"
	"time"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
)
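
// overquotaTransfer identifies an active transfer that exceeds a quota limit;
// the transfer is referenced by connection ID, transfer ID and transfer type.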
type overquotaTransfer struct {
	ConnID       string
	TransferID   int64
	TransferType int
}
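
// uploadAggregationKey groups ongoing uploads by username and virtual folder name.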
type uploadAggregationKey struct {
	Username   string
	FolderName string
}

// TransfersChecker defines the interface that transfer checkers must implement.
// A transfer checker ensures that multiple concurrent transfers do not exceed
// the remaining user quota.
type TransfersChecker interface {
	AddTransfer(transfer dataprovider.ActiveTransfer)
	RemoveTransfer(ID int64, connectionID string)
	UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string)
	GetOverquotaTransfers() []overquotaTransfer
}
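
// getTransfersChecker returns the transfer checker to use: the provider-backed
// implementation when the data provider is configured as shared (isShared == 1),
// the in-memory implementation otherwise.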
func getTransfersChecker(isShared int) TransfersChecker {
	if isShared == 1 {
		logger.Info(logSender, "", "using provider transfer checker")
		return &transfersCheckerDB{}
	}
	logger.Info(logSender, "", "using memory transfer checker")
	return &transfersCheckerMem{}
}
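
// baseTransferChecker holds a snapshot of the active transfers and implements
// the aggregation and quota check logic shared by the memory and provider
// based checkers.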
type baseTransferChecker struct {
	transfers []dataprovider.ActiveTransfer
}
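
// isDataTransferExceeded reports whether the given transfer, together with the
// aggregated upload (ulSize) and download (dlSize) sizes for its user, exceeds
// the user's total, download or upload data transfer limits as resolved for the
// transfer's IP address. A transfer is reported only if it has actually moved
// some data in the direction(s) being limited.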
func (t *baseTransferChecker) isDataTransferExceeded(user dataprovider.User, transfer dataprovider.ActiveTransfer, ulSize,
	dlSize int64,
) bool {
	ulQuota, dlQuota, totalQuota := user.GetDataTransferLimits(transfer.IP)
	if totalQuota > 0 {
		allowedSize := totalQuota - (user.UsedUploadDataTransfer + user.UsedDownloadDataTransfer)
		if ulSize+dlSize > allowedSize {
			return transfer.CurrentDLSize > 0 || transfer.CurrentULSize > 0
		}
	}
	if dlQuota > 0 {
		allowedSize := dlQuota - user.UsedDownloadDataTransfer
		if dlSize > allowedSize {
			return transfer.CurrentDLSize > 0
		}
	}
	if ulQuota > 0 {
		allowedSize := ulQuota - user.UsedUploadDataTransfer
		if ulSize > allowedSize {
			return transfer.CurrentULSize > 0
		}
	}
	return false
}
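
// getRemainingDiskQuota returns the remaining disk quota, in bytes, for the
// given virtual folder if folderName is set, or for the user otherwise.
// An error is returned if no quota limit is defined.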
func (t *baseTransferChecker) getRemainingDiskQuota(user dataprovider.User, folderName string) (int64, error) {
	var result int64

	if folderName != "" {
		for _, folder := range user.VirtualFolders {
			if folder.Name == folderName {
				if folder.QuotaSize > 0 {
					return folder.QuotaSize - folder.UsedQuotaSize, nil
				}
			}
		}
	} else {
		if user.QuotaSize > 0 {
			return user.QuotaSize - user.UsedQuotaSize, nil
		}
	}

	return result, errors.New("no quota limit defined")
}
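
// aggregateTransfersByUser groups all the active transfers by username and
// adds users with more than one concurrent transfer to the usersToFetch map,
// preserving any value already set for them.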
func (t *baseTransferChecker) aggregateTransfersByUser(usersToFetch map[string]bool,
) (map[string]bool, map[string][]dataprovider.ActiveTransfer) {
	aggregations := make(map[string][]dataprovider.ActiveTransfer)
	for _, transfer := range t.transfers {
		aggregations[transfer.Username] = append(aggregations[transfer.Username], transfer)
		if len(aggregations[transfer.Username]) > 1 {
			if _, ok := usersToFetch[transfer.Username]; !ok {
				usersToFetch[transfer.Username] = false
			}
		}
	}

	return usersToFetch, aggregations
}
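
// aggregateUploadTransfers groups the active uploads by username and virtual
// folder name. Users with more than one concurrent upload are added to the
// usersToFetch map; the value is set to true when the uploads target a virtual
// folder, signalling that the folder quotas must be fetched too.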
func (t *baseTransferChecker) aggregateUploadTransfers() (map[string]bool, map[int][]dataprovider.ActiveTransfer) {
	usersToFetch := make(map[string]bool)
	aggregations := make(map[int][]dataprovider.ActiveTransfer)
	var keys []uploadAggregationKey

	for _, transfer := range t.transfers {
		if transfer.Type != TransferUpload {
			continue
		}
		key := -1
		for idx, k := range keys {
			if k.Username == transfer.Username && k.FolderName == transfer.FolderName {
				key = idx
				break
			}
		}
		if key == -1 {
			key = len(keys)
			keys = append(keys, uploadAggregationKey{
				Username:   transfer.Username,
				FolderName: transfer.FolderName,
			})
		}
		aggregations[key] = append(aggregations[key], transfer)
		if len(aggregations[key]) > 1 {
			if transfer.FolderName != "" {
				usersToFetch[transfer.Username] = true
			} else {
				if _, ok := usersToFetch[transfer.Username]; !ok {
					usersToFetch[transfer.Username] = false
				}
			}
		}
	}

	return usersToFetch, aggregations
}
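
// getUsersToCheck loads the users to check from the data provider and returns
// them in a map keyed by username.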
func (t *baseTransferChecker) getUsersToCheck(usersToFetch map[string]bool) (map[string]dataprovider.User, error) {
	users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch)
	if err != nil {
		return nil, err
	}

	usersMap := make(map[string]dataprovider.User)
	for _, user := range users {
		usersMap[user.Username] = user
	}

	return usersMap, nil
}
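
// getOverquotaTransfers returns the active transfers exceeding the disk quota,
// computed from the uploads aggregated per user/folder, and those exceeding
// the data transfer limits, computed from the transfers aggregated per user.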
func (t *baseTransferChecker) getOverquotaTransfers(usersToFetch map[string]bool,
	uploadAggregations map[int][]dataprovider.ActiveTransfer,
	userAggregations map[string][]dataprovider.ActiveTransfer,
) []overquotaTransfer {
	if len(usersToFetch) == 0 {
		return nil
	}

	usersMap, err := t.getUsersToCheck(usersToFetch)
	if err != nil {
		logger.Warn(logSender, "", "unable to check transfers, error getting users quota: %v", err)
		return nil
	}

	var overquotaTransfers []overquotaTransfer

	for _, transfers := range uploadAggregations {
		username := transfers[0].Username
		folderName := transfers[0].FolderName
		remainingDiskQuota, err := t.getRemainingDiskQuota(usersMap[username], folderName)
		if err != nil {
			continue
		}
		var usedDiskQuota int64
		for _, tr := range transfers {
			// We optimistically assume that a cloud transfer that replaces an existing
			// file will be successful
			usedDiskQuota += tr.CurrentULSize - tr.TruncatedSize
		}
		logger.Debug(logSender, "", "username %#v, folder %#v, concurrent transfers: %v, remaining disk quota (bytes): %v, disk quota used in ongoing transfers (bytes): %v",
			username, folderName, len(transfers), remainingDiskQuota, usedDiskQuota)
		if usedDiskQuota > remainingDiskQuota {
			for _, tr := range transfers {
				if tr.CurrentULSize > tr.TruncatedSize {
					overquotaTransfers = append(overquotaTransfers, overquotaTransfer{
						ConnID:       tr.ConnID,
						TransferID:   tr.ID,
						TransferType: tr.Type,
					})
				}
			}
		}
	}

	for username, transfers := range userAggregations {
		var ulSize, dlSize int64
		for _, tr := range transfers {
			ulSize += tr.CurrentULSize
			dlSize += tr.CurrentDLSize
		}
		logger.Debug(logSender, "", "username %#v, concurrent transfers: %v, quota (bytes) used in ongoing transfers, ul: %v, dl: %v",
			username, len(transfers), ulSize, dlSize)
		for _, tr := range transfers {
			if t.isDataTransferExceeded(usersMap[username], tr, ulSize, dlSize) {
				overquotaTransfers = append(overquotaTransfers, overquotaTransfer{
					ConnID:       tr.ConnID,
					TransferID:   tr.ID,
					TransferType: tr.Type,
				})
			}
		}
	}

	return overquotaTransfers
}
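
// transfersCheckerMem keeps the active transfers in memory, protected by a
// read-write mutex; it is selected when the data provider is not shared.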
type transfersCheckerMem struct {
	sync.RWMutex
	baseTransferChecker
}

func (t *transfersCheckerMem) AddTransfer(transfer dataprovider.ActiveTransfer) {
	t.Lock()
	defer t.Unlock()

	t.transfers = append(t.transfers, transfer)
}

func (t *transfersCheckerMem) RemoveTransfer(ID int64, connectionID string) {
	t.Lock()
	defer t.Unlock()

	for idx, transfer := range t.transfers {
		if transfer.ID == ID && transfer.ConnID == connectionID {
			lastIdx := len(t.transfers) - 1
			t.transfers[idx] = t.transfers[lastIdx]
			t.transfers = t.transfers[:lastIdx]
			return
		}
	}
}

func (t *transfersCheckerMem) UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) {
	t.Lock()
	defer t.Unlock()

	for idx := range t.transfers {
		if t.transfers[idx].ID == ID && t.transfers[idx].ConnID == connectionID {
			t.transfers[idx].CurrentDLSize = dlSize
			t.transfers[idx].CurrentULSize = ulSize
			t.transfers[idx].UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
			return
		}
	}
}
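
// GetOverquotaTransfers aggregates the in-memory transfers while holding the
// read lock, then runs the quota checks outside the lock.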
func (t *transfersCheckerMem) GetOverquotaTransfers() []overquotaTransfer {
	t.RLock()

	usersToFetch, uploadAggregations := t.aggregateUploadTransfers()
	usersToFetch, userAggregations := t.aggregateTransfersByUser(usersToFetch)

	t.RUnlock()

	return t.getOverquotaTransfers(usersToFetch, uploadAggregations, userAggregations)
}
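
// transfersCheckerDB stores and reads the active transfers through the data
// provider; it is selected when the provider is configured as shared
// (isShared == 1), typically for multi-instance deployments.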
type transfersCheckerDB struct {
	baseTransferChecker
	lastCleanup time.Time
}

func (t *transfersCheckerDB) AddTransfer(transfer dataprovider.ActiveTransfer) {
	dataprovider.AddActiveTransfer(transfer)
}

func (t *transfersCheckerDB) RemoveTransfer(ID int64, connectionID string) {
	dataprovider.RemoveActiveTransfer(ID, connectionID)
}

func (t *transfersCheckerDB) UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) {
	dataprovider.UpdateActiveTransferSizes(ulSize, dlSize, ID, connectionID)
}
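
// GetOverquotaTransfers periodically removes stale transfers from the data
// provider, loads the recently updated active transfers and runs the quota
// checks on them.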
func (t *transfersCheckerDB) GetOverquotaTransfers() []overquotaTransfer {
	if t.lastCleanup.IsZero() || t.lastCleanup.Add(periodicTimeoutCheckInterval*15).Before(time.Now()) {
		before := time.Now().Add(-periodicTimeoutCheckInterval * 5)
		err := dataprovider.CleanupActiveTransfers(before)
		logger.Debug(logSender, "", "cleanup active transfers completed, err: %v", err)
		if err == nil {
			t.lastCleanup = time.Now()
		}
	}

	var err error
	from := time.Now().Add(-periodicTimeoutCheckInterval * 2)
	t.transfers, err = dataprovider.GetActiveTransfers(from)
	if err != nil {
		logger.Error(logSender, "", "unable to check overquota transfers, error getting active transfers: %v", err)
		return nil
	}

	usersToFetch, uploadAggregations := t.aggregateUploadTransfers()
	usersToFetch, userAggregations := t.aggregateTransfersByUser(usersToFetch)

	return t.getOverquotaTransfers(usersToFetch, uploadAggregations, userAggregations)
}