Browse Source

quota: move user and folder management to a common method

Signed-off-by: Nicola Murino <[email protected]>
Nicola Murino 1 year ago
parent
commit
92849ca473

+ 16 - 43
internal/common/connection.go

@@ -449,10 +449,7 @@ func (c *BaseConnection) RemoveFile(fs vfs.Fs, fsPath, virtualPath string, info
 	if updateQuota && info.Mode()&os.ModeSymlink == 0 {
 	if updateQuota && info.Mode()&os.ModeSymlink == 0 {
 		vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
 		vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
 		if err == nil {
 		if err == nil {
-			dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, -1, -size, false) //nolint:errcheck
-			if vfolder.IsIncludedInUserQuota() {
-				dataprovider.UpdateUserQuota(&c.User, -1, -size, false) //nolint:errcheck
-			}
+			dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, -1, -size, false)
 		} else {
 		} else {
 			dataprovider.UpdateUserQuota(&c.User, -1, -size, false) //nolint:errcheck
 			dataprovider.UpdateUserQuota(&c.User, -1, -size, false) //nolint:errcheck
 		}
 		}
@@ -1121,10 +1118,7 @@ func (c *BaseConnection) truncateFile(fs vfs.Fs, fsPath, virtualPath string, siz
 		sizeDiff := initialSize - size
 		sizeDiff := initialSize - size
 		vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
 		vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
 		if err == nil {
 		if err == nil {
-			dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -sizeDiff, false) //nolint:errcheck
-			if vfolder.IsIncludedInUserQuota() {
-				dataprovider.UpdateUserQuota(&c.User, 0, -sizeDiff, false) //nolint:errcheck
-			}
+			dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -sizeDiff, false)
 		} else {
 		} else {
 			dataprovider.UpdateUserQuota(&c.User, 0, -sizeDiff, false) //nolint:errcheck
 			dataprovider.UpdateUserQuota(&c.User, 0, -sizeDiff, false) //nolint:errcheck
 		}
 		}
@@ -1518,61 +1512,40 @@ func (c *BaseConnection) updateQuotaMoveBetweenVFolders(sourceFolder, dstFolder
 	if sourceFolder.Name == dstFolder.Name {
 	if sourceFolder.Name == dstFolder.Name {
 		// both files are inside the same virtual folder
 		// both files are inside the same virtual folder
 		if initialSize != -1 {
 		if initialSize != -1 {
-			dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, -numFiles, -initialSize, false) //nolint:errcheck
-			if dstFolder.IsIncludedInUserQuota() {
-				dataprovider.UpdateUserQuota(&c.User, -numFiles, -initialSize, false) //nolint:errcheck
-			}
+			dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, -numFiles, -initialSize, false)
 		}
 		}
 		return
 		return
 	}
 	}
 	// files are inside different virtual folders
 	// files are inside different virtual folders
-	dataprovider.UpdateVirtualFolderQuota(&sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck
-	if sourceFolder.IsIncludedInUserQuota() {
-		dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
-	}
+	dataprovider.UpdateUserFolderQuota(sourceFolder, &c.User, -numFiles, -filesSize, false)
 	if initialSize == -1 {
 	if initialSize == -1 {
-		dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck
-		if dstFolder.IsIncludedInUserQuota() {
-			dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
-		}
-	} else {
-		// we cannot have a directory here, initialSize != -1 only for files
-		dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck
-		if dstFolder.IsIncludedInUserQuota() {
-			dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
-		}
+		dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, numFiles, filesSize, false)
+		return
 	}
 	}
+	// we cannot have a directory here, initialSize != -1 only for files
+	dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, 0, filesSize-initialSize, false)
 }
 }
 
 
 func (c *BaseConnection) updateQuotaMoveFromVFolder(sourceFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
 func (c *BaseConnection) updateQuotaMoveFromVFolder(sourceFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
 	// move between a virtual folder and the user home dir
 	// move between a virtual folder and the user home dir
-	dataprovider.UpdateVirtualFolderQuota(&sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck
-	if sourceFolder.IsIncludedInUserQuota() {
-		dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
-	}
+	dataprovider.UpdateUserFolderQuota(sourceFolder, &c.User, -numFiles, -filesSize, false)
 	if initialSize == -1 {
 	if initialSize == -1 {
 		dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
 		dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
-	} else {
-		// we cannot have a directory here, initialSize != -1 only for files
-		dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
+		return
 	}
 	}
+	// we cannot have a directory here, initialSize != -1 only for files
+	dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
 }
 }
 
 
 func (c *BaseConnection) updateQuotaMoveToVFolder(dstFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
 func (c *BaseConnection) updateQuotaMoveToVFolder(dstFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
 	// move between the user home dir and a virtual folder
 	// move between the user home dir and a virtual folder
 	dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
 	dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
 	if initialSize == -1 {
 	if initialSize == -1 {
-		dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck
-		if dstFolder.IsIncludedInUserQuota() {
-			dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
-		}
-	} else {
-		// we cannot have a directory here, initialSize != -1 only for files
-		dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck
-		if dstFolder.IsIncludedInUserQuota() {
-			dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
-		}
+		dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, numFiles, filesSize, false)
+		return
 	}
 	}
+	// we cannot have a directory here, initialSize != -1 only for files
+	dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, 0, filesSize-initialSize, false)
 }
 }
 
 
 func (c *BaseConnection) updateQuotaAfterRename(fs vfs.Fs, virtualSourcePath, virtualTargetPath, targetPath string,
 func (c *BaseConnection) updateQuotaAfterRename(fs vfs.Fs, virtualSourcePath, virtualTargetPath, targetPath string,

+ 1 - 4
internal/common/eventmanager.go

@@ -908,10 +908,7 @@ func updateUserQuotaAfterFileWrite(conn *BaseConnection, virtualPath string, num
 		dataprovider.UpdateUserQuota(&conn.User, numFiles, fileSize, false) //nolint:errcheck
 		dataprovider.UpdateUserQuota(&conn.User, numFiles, fileSize, false) //nolint:errcheck
 		return
 		return
 	}
 	}
-	dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, numFiles, fileSize, false) //nolint:errcheck
-	if vfolder.IsIncludedInUserQuota() {
-		dataprovider.UpdateUserQuota(&conn.User, numFiles, fileSize, false) //nolint:errcheck
-	}
+	dataprovider.UpdateUserFolderQuota(&vfolder, &conn.User, numFiles, fileSize, false)
 }
 }
 
 
 func checkWriterPermsAndQuota(conn *BaseConnection, virtualPath string, numFiles int, expectedSize, truncatedSize int64) error {
 func checkWriterPermsAndQuota(conn *BaseConnection, virtualPath string, numFiles int, expectedSize, truncatedSize int64) error {

+ 1 - 4
internal/common/transfer.go

@@ -521,11 +521,8 @@ func (t *BaseTransfer) updateQuota(numFiles int, fileSize int64) bool {
 	if t.transferType == TransferUpload && (numFiles != 0 || sizeDiff != 0) {
 	if t.transferType == TransferUpload && (numFiles != 0 || sizeDiff != 0) {
 		vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath))
 		vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath))
 		if err == nil {
 		if err == nil {
-			dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, numFiles, //nolint:errcheck
+			dataprovider.UpdateUserFolderQuota(&vfolder, &t.Connection.User, numFiles,
 				sizeDiff, false)
 				sizeDiff, false)
-			if vfolder.IsIncludedInUserQuota() {
-				dataprovider.UpdateUserQuota(&t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck
-			}
 		} else {
 		} else {
 			dataprovider.UpdateUserQuota(&t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck
 			dataprovider.UpdateUserQuota(&t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck
 		}
 		}

+ 8 - 0
internal/dataprovider/dataprovider.go

@@ -1515,6 +1515,14 @@ func UpdateUserQuota(user *User, filesAdd int, sizeAdd int64, reset bool) error
 	return nil
 	return nil
 }
 }
 
 
+// UpdateUserFolderQuota updates the quota for the given user and virtual folder.
+func UpdateUserFolderQuota(folder *vfs.VirtualFolder, user *User, filesAdd int, sizeAdd int64, reset bool) {
+	UpdateVirtualFolderQuota(&folder.BaseVirtualFolder, filesAdd, sizeAdd, reset) //nolint:errcheck
+	if folder.IsIncludedInUserQuota() {
+		UpdateUserQuota(user, filesAdd, sizeAdd, reset) //nolint:errcheck
+	}
+}
+
 // UpdateVirtualFolderQuota updates the quota for the given virtual folder adding filesAdd and sizeAdd.
 // UpdateVirtualFolderQuota updates the quota for the given virtual folder adding filesAdd and sizeAdd.
 // If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference.
 // If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference.
 func UpdateVirtualFolderQuota(vfolder *vfs.BaseVirtualFolder, filesAdd int, sizeAdd int64, reset bool) error {
 func UpdateVirtualFolderQuota(vfolder *vfs.BaseVirtualFolder, filesAdd int, sizeAdd int64, reset bool) error {

+ 1 - 4
internal/ftpd/handler.go

@@ -493,10 +493,7 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve
 		if vfs.HasTruncateSupport(fs) {
 		if vfs.HasTruncateSupport(fs) {
 			vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
 			vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
 			if err == nil {
 			if err == nil {
-				dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
-				if vfolder.IsIncludedInUserQuota() {
-					dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
-				}
+				dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false)
 			} else {
 			} else {
 				dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
 				dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
 			}
 			}

+ 1 - 4
internal/httpd/handler.go

@@ -213,10 +213,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
 		if vfs.HasTruncateSupport(fs) {
 		if vfs.HasTruncateSupport(fs) {
 			vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
 			vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
 			if err == nil {
 			if err == nil {
-				dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
-				if vfolder.IsIncludedInUserQuota() {
-					dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
-				}
+				dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false)
 			} else {
 			} else {
 				dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
 				dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
 			}
 			}

+ 3 - 6
internal/sftpd/handler.go

@@ -560,13 +560,10 @@ func (c *Connection) getStatVFSFromQuotaResult(fs vfs.Fs, name string, quotaResu
 func (c *Connection) updateQuotaAfterTruncate(requestPath string, fileSize int64) {
 func (c *Connection) updateQuotaAfterTruncate(requestPath string, fileSize int64) {
 	vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
 	vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
 	if err == nil {
 	if err == nil {
-		dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
-		if vfolder.IsIncludedInUserQuota() {
-			dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
-		}
-	} else {
-		dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
+		dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false)
+		return
 	}
 	}
+	dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
 }
 }
 
 
 func getOSOpenFlags(requestFlags sftp.FileOpenFlags) (flags int) {
 func getOSOpenFlags(requestFlags sftp.FileOpenFlags) (flags int) {

+ 1 - 4
internal/sftpd/scp.go

@@ -258,10 +258,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string,
 		if vfs.HasTruncateSupport(fs) {
 		if vfs.HasTruncateSupport(fs) {
 			vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath))
 			vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath))
 			if err == nil {
 			if err == nil {
-				dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
-				if vfolder.IsIncludedInUserQuota() {
-					dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck
-				}
+				dataprovider.UpdateUserFolderQuota(&vfolder, &c.connection.User, 0, -fileSize, false)
 			} else {
 			} else {
 				dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck
 				dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck
 			}
 			}

+ 3 - 6
internal/sftpd/ssh_cmd.go

@@ -192,13 +192,10 @@ func (c *sshCommand) handleSFTPGoRemove() error {
 func (c *sshCommand) updateQuota(sshDestPath string, filesNum int, filesSize int64) {
 func (c *sshCommand) updateQuota(sshDestPath string, filesNum int, filesSize int64) {
 	vfolder, err := c.connection.User.GetVirtualFolderForPath(sshDestPath)
 	vfolder, err := c.connection.User.GetVirtualFolderForPath(sshDestPath)
 	if err == nil {
 	if err == nil {
-		dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, filesNum, filesSize, false) //nolint:errcheck
-		if vfolder.IsIncludedInUserQuota() {
-			dataprovider.UpdateUserQuota(&c.connection.User, filesNum, filesSize, false) //nolint:errcheck
-		}
-	} else {
-		dataprovider.UpdateUserQuota(&c.connection.User, filesNum, filesSize, false) //nolint:errcheck
+		dataprovider.UpdateUserFolderQuota(&vfolder, &c.connection.User, filesNum, filesSize, false)
+		return
 	}
 	}
+	dataprovider.UpdateUserQuota(&c.connection.User, filesNum, filesSize, false) //nolint:errcheck
 }
 }
 
 
 func (c *sshCommand) handleHashCommands() error {
 func (c *sshCommand) handleHashCommands() error {

+ 1 - 4
internal/webdavd/handler.go

@@ -272,10 +272,7 @@ func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePat
 	if vfs.HasTruncateSupport(fs) {
 	if vfs.HasTruncateSupport(fs) {
 		vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
 		vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
 		if err == nil {
 		if err == nil {
-			dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
-			if vfolder.IsIncludedInUserQuota() {
-				dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
-			}
+			dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false)
 		} else {
 		} else {
 			dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
 			dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
 		}
 		}