Browse Source

scp: fix quota update after file overwrite

added a test case too
Nicola Murino 6 years ago
parent
commit
dc5eeb54fd
4 changed files with 17 additions and 8 deletions
  1. 2 2
      README.md
  2. 4 5
      sftpd/scp.go
  3. 10 0
      sftpd/sftpd_test.go
  4. 1 1
      utils/version.go

+ 2 - 2
README.md

@@ -57,7 +57,7 @@ Version info, such as git commit and build date, can be embedded setting the fol
 For example you can build using the following command:
 
 ```bash
-go build -i -ldflags "-s -w -X github.com/drakkan/sftpgo/utils.commit=`git describe --tags --always --dirty` -X github.com/drakkan/sftpgo/utils.date=`date -u +%FT%TZ`" -o sftpgo
+go build -i -ldflags "-s -w -X github.com/drakkan/sftpgo/utils.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/utils.date=`date -u +%FT%TZ`" -o sftpgo
 ```
 
 and you will get a version that includes git commit and build date like this one:
@@ -71,7 +71,7 @@ For Linux, a systemd sample [service](https://github.com/drakkan/sftpgo/tree/mas
 
 Alternately you can use distro packages:
 
-- Arch Linux PKGBUILD is available on [AUR](https://aur.archlinux.org/packages/sftpgo-git/ "SFTPGo")
+- Arch Linux PKGBUILD is available on [AUR](https://aur.archlinux.org/packages/sftpgo/ "SFTPGo")
 
 ## Configuration
 

+ 4 - 5
sftpd/scp.go

@@ -188,8 +188,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err
 	return c.sendConfirmationMessage()
 }
 
-func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead int64) error {
-	logger.Debug(logSenderSCP, "upload to new file: %v", filePath)
+func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead int64, isNewFile bool) error {
 	if !c.connection.hasSpace(true) {
 		err := fmt.Errorf("denying file write due to space limit")
 		logger.Warn(logSenderSCP, "error uploading file: %v, err: %v", filePath, err)
@@ -225,7 +224,7 @@ func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead i
 		connectionID:  c.connection.ID,
 		transferType:  transferUpload,
 		lastActivity:  time.Now(),
-		isNewFile:     true,
+		isNewFile:     isNewFile,
 		protocol:      c.connection.protocol,
 	}
 	addTransfer(&transfer)
@@ -256,7 +255,7 @@ func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error
 	}
 	stat, statErr := os.Stat(p)
 	if os.IsNotExist(statErr) {
-		return c.handleUploadFile(p, filePath, sizeToRead)
+		return c.handleUploadFile(p, filePath, sizeToRead, true)
 	}
 
 	if statErr != nil {
@@ -284,7 +283,7 @@ func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error
 
 	dataprovider.UpdateUserQuota(dataProvider, c.connection.User, 0, -stat.Size(), false)
 
-	return c.handleUploadFile(p, filePath, sizeToRead)
+	return c.handleUploadFile(p, filePath, sizeToRead, false)
 }
 
 func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileInfo) error {

+ 10 - 0
sftpd/sftpd_test.go

@@ -1503,10 +1503,12 @@ func TestSCPUploadFileOverwrite(t *testing.T) {
 	}
 	usePubKey := true
 	u := getTestUser(usePubKey)
+	u.QuotaFiles = 1000
 	user, _, err := api.AddUser(u, http.StatusOK)
 	if err != nil {
 		t.Errorf("unable to add user: %v", err)
 	}
+	os.RemoveAll(user.GetHomeDir())
 	testFileName := "test_file.dat"
 	testFilePath := filepath.Join(homeBasePath, testFileName)
 	testFileSize := int64(32760)
@@ -1524,6 +1526,14 @@ func TestSCPUploadFileOverwrite(t *testing.T) {
 	if err != nil {
 		t.Errorf("error uploading existing file via scp: %v", err)
 	}
+	user, _, err = api.GetUserByID(user.ID, http.StatusOK)
+	if err != nil {
+		t.Errorf("error getting user: %v", err)
+	}
+	if user.UsedQuotaSize != testFileSize || user.UsedQuotaFiles != 1 {
+		t.Errorf("update quota error on file overwrite, actual size: %v, expected: %v actual files: %v, expected: 1",
+			user.UsedQuotaSize, testFileSize, user.UsedQuotaFiles)
+	}
 	remoteDownPath := fmt.Sprintf("%[email protected]:%v", user.Username, path.Join("/", testFileName))
 	localPath := filepath.Join(homeBasePath, "scp_download.dat")
 	err = scpDownload(localPath, remoteDownPath, false, false)

+ 1 - 1
utils/version.go

@@ -1,6 +1,6 @@
 package utils
 
-const version = "0.9.1"
+const version = "0.9.1-dev"
 
 var (
 	commit      = ""