transferschecker.go 9.0 KB


  1. package common
  2. import (
  3. "errors"
  4. "sync"
  5. "time"
  6. "github.com/drakkan/sftpgo/v2/dataprovider"
  7. "github.com/drakkan/sftpgo/v2/logger"
  8. "github.com/drakkan/sftpgo/v2/util"
  9. )
  // overquotaTransfer identifies an ongoing transfer that the checker flagged
  // as exceeding a quota.
  type overquotaTransfer struct {
  	ConnID       string // ID of the connection owning the transfer
  	TransferID   int64  // transfer ID; identifies the transfer together with ConnID
  	TransferType int    // copied from ActiveTransfer.Type (e.g. TransferUpload)
  }
  // uploadAggregationKey identifies the disk quota scope shared by a group of
  // upload transfers: a user and, optionally, one of the user's virtual folders.
  // An empty FolderName selects the user's own quota.
  type uploadAggregationKey struct {
  	Username, FolderName string
  }
  19. // TransfersChecker defines the interface that transfer checkers must implement.
  20. // A transfer checker ensure that multiple concurrent transfers does not exceeded
  21. // the remaining user quota
  22. type TransfersChecker interface {
  23. AddTransfer(transfer dataprovider.ActiveTransfer)
  24. RemoveTransfer(ID int64, connectionID string)
  25. UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string)
  26. GetOverquotaTransfers() []overquotaTransfer
  27. }
  28. func getTransfersChecker(isShared int) TransfersChecker {
  29. if isShared == 1 {
  30. logger.Info(logSender, "", "using provider transfer checker")
  31. return &transfersCheckerDB{}
  32. }
  33. logger.Info(logSender, "", "using memory transfer checker")
  34. return &transfersCheckerMem{}
  35. }
  36. type baseTransferChecker struct {
  37. transfers []dataprovider.ActiveTransfer
  38. }
  39. func (t *baseTransferChecker) isDataTransferExceeded(user dataprovider.User, transfer dataprovider.ActiveTransfer, ulSize,
  40. dlSize int64,
  41. ) bool {
  42. ulQuota, dlQuota, totalQuota := user.GetDataTransferLimits(transfer.IP)
  43. if totalQuota > 0 {
  44. allowedSize := totalQuota - (user.UsedUploadDataTransfer + user.UsedDownloadDataTransfer)
  45. if ulSize+dlSize > allowedSize {
  46. return transfer.CurrentDLSize > 0 || transfer.CurrentULSize > 0
  47. }
  48. }
  49. if dlQuota > 0 {
  50. allowedSize := dlQuota - user.UsedDownloadDataTransfer
  51. if dlSize > allowedSize {
  52. return transfer.CurrentDLSize > 0
  53. }
  54. }
  55. if ulQuota > 0 {
  56. allowedSize := ulQuota - user.UsedUploadDataTransfer
  57. if ulSize > allowedSize {
  58. return transfer.CurrentULSize > 0
  59. }
  60. }
  61. return false
  62. }
  63. func (t *baseTransferChecker) getRemainingDiskQuota(user dataprovider.User, folderName string) (int64, error) {
  64. var result int64
  65. if folderName != "" {
  66. for _, folder := range user.VirtualFolders {
  67. if folder.Name == folderName {
  68. if folder.QuotaSize > 0 {
  69. return folder.QuotaSize - folder.UsedQuotaSize, nil
  70. }
  71. }
  72. }
  73. } else {
  74. if user.QuotaSize > 0 {
  75. return user.QuotaSize - user.UsedQuotaSize, nil
  76. }
  77. }
  78. return result, errors.New("no quota limit defined")
  79. }
  80. func (t *baseTransferChecker) aggregateTransfersByUser(usersToFetch map[string]bool,
  81. ) (map[string]bool, map[string][]dataprovider.ActiveTransfer) {
  82. aggregations := make(map[string][]dataprovider.ActiveTransfer)
  83. for _, transfer := range t.transfers {
  84. aggregations[transfer.Username] = append(aggregations[transfer.Username], transfer)
  85. if len(aggregations[transfer.Username]) > 1 {
  86. if _, ok := usersToFetch[transfer.Username]; !ok {
  87. usersToFetch[transfer.Username] = false
  88. }
  89. }
  90. }
  91. return usersToFetch, aggregations
  92. }
  93. func (t *baseTransferChecker) aggregateUploadTransfers() (map[string]bool, map[int][]dataprovider.ActiveTransfer) {
  94. usersToFetch := make(map[string]bool)
  95. aggregations := make(map[int][]dataprovider.ActiveTransfer)
  96. var keys []uploadAggregationKey
  97. for _, transfer := range t.transfers {
  98. if transfer.Type != TransferUpload {
  99. continue
  100. }
  101. key := -1
  102. for idx, k := range keys {
  103. if k.Username == transfer.Username && k.FolderName == transfer.FolderName {
  104. key = idx
  105. break
  106. }
  107. }
  108. if key == -1 {
  109. key = len(keys)
  110. }
  111. keys = append(keys, uploadAggregationKey{
  112. Username: transfer.Username,
  113. FolderName: transfer.FolderName,
  114. })
  115. aggregations[key] = append(aggregations[key], transfer)
  116. if len(aggregations[key]) > 1 {
  117. if transfer.FolderName != "" {
  118. usersToFetch[transfer.Username] = true
  119. } else {
  120. if _, ok := usersToFetch[transfer.Username]; !ok {
  121. usersToFetch[transfer.Username] = false
  122. }
  123. }
  124. }
  125. }
  126. return usersToFetch, aggregations
  127. }
  128. func (t *baseTransferChecker) getUsersToCheck(usersToFetch map[string]bool) (map[string]dataprovider.User, error) {
  129. users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch)
  130. if err != nil {
  131. return nil, err
  132. }
  133. usersMap := make(map[string]dataprovider.User)
  134. for _, user := range users {
  135. usersMap[user.Username] = user
  136. }
  137. return usersMap, nil
  138. }
  139. func (t *baseTransferChecker) getOverquotaTransfers(usersToFetch map[string]bool,
  140. uploadAggregations map[int][]dataprovider.ActiveTransfer,
  141. userAggregations map[string][]dataprovider.ActiveTransfer,
  142. ) []overquotaTransfer {
  143. if len(usersToFetch) == 0 {
  144. return nil
  145. }
  146. usersMap, err := t.getUsersToCheck(usersToFetch)
  147. if err != nil {
  148. logger.Warn(logSender, "", "unable to check transfers, error getting users quota: %v", err)
  149. return nil
  150. }
  151. var overquotaTransfers []overquotaTransfer
  152. for _, transfers := range uploadAggregations {
  153. username := transfers[0].Username
  154. folderName := transfers[0].FolderName
  155. remaningDiskQuota, err := t.getRemainingDiskQuota(usersMap[username], folderName)
  156. if err != nil {
  157. continue
  158. }
  159. var usedDiskQuota int64
  160. for _, tr := range transfers {
  161. // We optimistically assume that a cloud transfer that replaces an existing
  162. // file will be successful
  163. usedDiskQuota += tr.CurrentULSize - tr.TruncatedSize
  164. }
  165. logger.Debug(logSender, "", "username %#v, folder %#v, concurrent transfers: %v, remaining disk quota (bytes): %v, disk quota used in ongoing transfers (bytes): %v",
  166. username, folderName, len(transfers), remaningDiskQuota, usedDiskQuota)
  167. if usedDiskQuota > remaningDiskQuota {
  168. for _, tr := range transfers {
  169. if tr.CurrentULSize > tr.TruncatedSize {
  170. overquotaTransfers = append(overquotaTransfers, overquotaTransfer{
  171. ConnID: tr.ConnID,
  172. TransferID: tr.ID,
  173. TransferType: tr.Type,
  174. })
  175. }
  176. }
  177. }
  178. }
  179. for username, transfers := range userAggregations {
  180. var ulSize, dlSize int64
  181. for _, tr := range transfers {
  182. ulSize += tr.CurrentULSize
  183. dlSize += tr.CurrentDLSize
  184. }
  185. logger.Debug(logSender, "", "username %#v, concurrent transfers: %v, quota (bytes) used in ongoing transfers, ul: %v, dl: %v",
  186. username, len(transfers), ulSize, dlSize)
  187. for _, tr := range transfers {
  188. if t.isDataTransferExceeded(usersMap[username], tr, ulSize, dlSize) {
  189. overquotaTransfers = append(overquotaTransfers, overquotaTransfer{
  190. ConnID: tr.ConnID,
  191. TransferID: tr.ID,
  192. TransferType: tr.Type,
  193. })
  194. }
  195. }
  196. }
  197. return overquotaTransfers
  198. }
  199. type transfersCheckerMem struct {
  200. sync.RWMutex
  201. baseTransferChecker
  202. }
  203. func (t *transfersCheckerMem) AddTransfer(transfer dataprovider.ActiveTransfer) {
  204. t.Lock()
  205. defer t.Unlock()
  206. t.transfers = append(t.transfers, transfer)
  207. }
  208. func (t *transfersCheckerMem) RemoveTransfer(ID int64, connectionID string) {
  209. t.Lock()
  210. defer t.Unlock()
  211. for idx, transfer := range t.transfers {
  212. if transfer.ID == ID && transfer.ConnID == connectionID {
  213. lastIdx := len(t.transfers) - 1
  214. t.transfers[idx] = t.transfers[lastIdx]
  215. t.transfers = t.transfers[:lastIdx]
  216. return
  217. }
  218. }
  219. }
  // UpdateTransferCurrentSizes stores the current upload and download sizes of
  // the first transfer matching ID and connectionID. It also refreshes that
  // transfer's UpdatedAt using util.GetTimeAsMsSinceEpoch.
  // Unknown transfers are ignored.
  func (t *transfersCheckerMem) UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) {
  	t.Lock()
  	defer t.Unlock()

  	for i := 0; i < len(t.transfers); i++ {
  		tr := &t.transfers[i]
  		if tr.ID != ID || tr.ConnID != connectionID {
  			continue
  		}
  		tr.CurrentDLSize = dlSize
  		tr.CurrentULSize = ulSize
  		tr.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
  		return
  	}
  }
  232. func (t *transfersCheckerMem) GetOverquotaTransfers() []overquotaTransfer {
  233. t.RLock()
  234. usersToFetch, uploadAggregations := t.aggregateUploadTransfers()
  235. usersToFetch, userAggregations := t.aggregateTransfersByUser(usersToFetch)
  236. t.RUnlock()
  237. return t.getOverquotaTransfers(usersToFetch, uploadAggregations, userAggregations)
  238. }
  239. type transfersCheckerDB struct {
  240. baseTransferChecker
  241. lastCleanup time.Time
  242. }
  243. func (t *transfersCheckerDB) AddTransfer(transfer dataprovider.ActiveTransfer) {
  244. dataprovider.AddActiveTransfer(transfer)
  245. }
  246. func (t *transfersCheckerDB) RemoveTransfer(ID int64, connectionID string) {
  247. dataprovider.RemoveActiveTransfer(ID, connectionID)
  248. }
  249. func (t *transfersCheckerDB) UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) {
  250. dataprovider.UpdateActiveTransferSizes(ulSize, dlSize, ID, connectionID)
  251. }
  252. func (t *transfersCheckerDB) GetOverquotaTransfers() []overquotaTransfer {
  253. if t.lastCleanup.IsZero() || t.lastCleanup.Add(periodicTimeoutCheckInterval*15).Before(time.Now()) {
  254. before := time.Now().Add(-periodicTimeoutCheckInterval * 5)
  255. err := dataprovider.CleanupActiveTransfers(before)
  256. logger.Debug(logSender, "", "cleanup active transfers completed, err: %v", err)
  257. if err == nil {
  258. t.lastCleanup = time.Now()
  259. }
  260. }
  261. var err error
  262. from := time.Now().Add(-periodicTimeoutCheckInterval * 2)
  263. t.transfers, err = dataprovider.GetActiveTransfers(from)
  264. if err != nil {
  265. logger.Error(logSender, "", "unable to check overquota transfers, error getting active transfers: %v", err)
  266. return nil
  267. }
  268. usersToFetch, uploadAggregations := t.aggregateUploadTransfers()
  269. usersToFetch, userAggregations := t.aggregateTransfersByUser(usersToFetch)
  270. return t.getOverquotaTransfers(usersToFetch, uploadAggregations, userAggregations)
  271. }