浏览代码

S3: fix quota update after an upload error

S3 uploads are atomic, if the upload fails we have no partial file so we
have to update the user quota only if the upload succeed
Nicola Murino 5 年之前
父节点
当前提交
d481294519
共有 5 个文件被更改,包括 74 次插入13 次删除
  1. 7 1
      sftpd/handler.go
  2. 37 0
      sftpd/internal_test.go
  3. 12 5
      sftpd/scp.go
  4. 4 4
      sftpd/sftpd_test.go
  5. 14 3
      sftpd/transfer.go

+ 7 - 1
sftpd/handler.go

@@ -479,11 +479,16 @@ func (c Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, re
 		return nil, vfs.GetSFTPError(c.fs, err)
 		return nil, vfs.GetSFTPError(c.fs, err)
 	}
 	}
 
 
+	initialSize := int64(0)
 	if pflags.Append && osFlags&os.O_TRUNC == 0 {
 	if pflags.Append && osFlags&os.O_TRUNC == 0 {
 		c.Log(logger.LevelDebug, logSender, "upload resume requested, file path: %#v initial size: %v", filePath, fileSize)
 		c.Log(logger.LevelDebug, logSender, "upload resume requested, file path: %#v initial size: %v", filePath, fileSize)
 		minWriteOffset = fileSize
 		minWriteOffset = fileSize
 	} else {
 	} else {
-		dataprovider.UpdateUserQuota(dataProvider, c.User, 0, -fileSize, false)
+		if vfs.IsLocalOsFs(c.fs) {
+			dataprovider.UpdateUserQuota(dataProvider, c.User, 0, -fileSize, false)
+		} else {
+			initialSize = fileSize
+		}
 	}
 	}
 
 
 	vfs.SetPathPermissions(c.fs, filePath, c.User.GetUID(), c.User.GetGID())
 	vfs.SetPathPermissions(c.fs, filePath, c.User.GetUID(), c.User.GetGID())
@@ -506,6 +511,7 @@ func (c Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, re
 		transferError:  nil,
 		transferError:  nil,
 		isFinished:     false,
 		isFinished:     false,
 		minWriteOffset: minWriteOffset,
 		minWriteOffset: minWriteOffset,
+		initialSize:    initialSize,
 		lock:           new(sync.Mutex),
 		lock:           new(sync.Mutex),
 	}
 	}
 	addTransfer(&transfer)
 	addTransfer(&transfer)

+ 37 - 0
sftpd/internal_test.go

@@ -425,6 +425,28 @@ func TestUploadFiles(t *testing.T) {
 	if err == nil {
 	if err == nil {
 		t.Errorf("upload new file in missing path must fail")
 		t.Errorf("upload new file in missing path must fail")
 	}
 	}
+	c.fs = newMockOsFs(nil, nil, false, "123", os.TempDir())
+	f, _ := ioutil.TempFile("", "temp")
+	f.Close()
+	_, err = c.handleSFTPUploadToExistingFile(flags, f.Name(), f.Name(), 123)
+	if err != nil {
+		t.Errorf("unexpected error: %v", err)
+	}
+	if len(activeTransfers) != 1 {
+		t.Errorf("unexpected number of transfer, expected 1, current: %v", len(activeTransfers))
+	}
+	transfer := activeTransfers[0]
+	if transfer.initialSize != 123 {
+		t.Errorf("unexpected initial size: %v", transfer.initialSize)
+	}
+	err = transfer.Close()
+	if err != nil {
+		t.Errorf("unexpected error: %v", err)
+	}
+	if len(activeTransfers) != 0 {
+		t.Errorf("unexpected number of transfer, expected 0, current: %v", len(activeTransfers))
+	}
+	os.Remove(f.Name())
 	uploadMode = oldUploadMode
 	uploadMode = oldUploadMode
 }
 }
 
 
@@ -899,6 +921,17 @@ func TestSystemCommandErrors(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestTransferUpdateQuota(t *testing.T) {
+	transfer := Transfer{
+		transferType:  transferUpload,
+		bytesReceived: 123,
+		lock:          new(sync.Mutex)}
+	transfer.TransferError(errors.New("fake error"))
+	if transfer.updateQuota(1) {
+		t.Errorf("update quota must fail, there is a error and this is a remote upload")
+	}
+}
+
 func TestGetConnectionInfo(t *testing.T) {
 func TestGetConnectionInfo(t *testing.T) {
 	c := ConnectionStatus{
 	c := ConnectionStatus{
 		Username:      "test_user",
 		Username:      "test_user",
@@ -1222,6 +1255,10 @@ func TestSCPErrorsMockFs(t *testing.T) {
 	if err != errFake {
 	if err != errFake {
 		t.Errorf("unexpected error: %v", err)
 		t.Errorf("unexpected error: %v", err)
 	}
 	}
+	err = scpCommand.handleUploadFile(testfile, testfile, 0, false, 4)
+	if err != nil {
+		t.Errorf("unexpected error: %v", err)
+	}
 	os.Remove(testfile)
 	os.Remove(testfile)
 }
 }
 
 

+ 12 - 5
sftpd/scp.go

@@ -181,7 +181,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err
 	return c.sendConfirmationMessage()
 	return c.sendConfirmationMessage()
 }
 }
 
 
-func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead int64, isNewFile bool) error {
+func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64) error {
 	if !c.connection.hasSpace(true) {
 	if !c.connection.hasSpace(true) {
 		err := fmt.Errorf("denying file write due to space limit")
 		err := fmt.Errorf("denying file write due to space limit")
 		c.connection.Log(logger.LevelWarn, logSenderSCP, "error uploading file: %#v, err: %v", filePath, err)
 		c.connection.Log(logger.LevelWarn, logSenderSCP, "error uploading file: %#v, err: %v", filePath, err)
@@ -189,6 +189,14 @@ func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead i
 		return err
 		return err
 	}
 	}
 
 
+	initialSize := int64(0)
+	if !isNewFile {
+		if vfs.IsLocalOsFs(c.connection.fs) {
+			dataprovider.UpdateUserQuota(dataProvider, c.connection.User, 0, -fileSize, false)
+		} else {
+			initialSize = fileSize
+		}
+	}
 	file, w, cancelFn, err := c.connection.fs.Create(filePath, 0)
 	file, w, cancelFn, err := c.connection.fs.Create(filePath, 0)
 	if err != nil {
 	if err != nil {
 		c.connection.Log(logger.LevelError, logSenderSCP, "error creating file %#v: %v", requestPath, err)
 		c.connection.Log(logger.LevelError, logSenderSCP, "error creating file %#v: %v", requestPath, err)
@@ -216,6 +224,7 @@ func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead i
 		transferError:  nil,
 		transferError:  nil,
 		isFinished:     false,
 		isFinished:     false,
 		minWriteOffset: 0,
 		minWriteOffset: 0,
+		initialSize:    initialSize,
 		lock:           new(sync.Mutex),
 		lock:           new(sync.Mutex),
 	}
 	}
 	addTransfer(&transfer)
 	addTransfer(&transfer)
@@ -246,7 +255,7 @@ func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error
 			c.sendErrorMessage(err.Error())
 			c.sendErrorMessage(err.Error())
 			return err
 			return err
 		}
 		}
-		return c.handleUploadFile(p, filePath, sizeToRead, true)
+		return c.handleUploadFile(p, filePath, sizeToRead, true, 0)
 	}
 	}
 
 
 	if statErr != nil {
 	if statErr != nil {
@@ -279,9 +288,7 @@ func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error
 		}
 		}
 	}
 	}
 
 
-	dataprovider.UpdateUserQuota(dataProvider, c.connection.User, 0, -stat.Size(), false)
-
-	return c.handleUploadFile(p, filePath, sizeToRead, false)
+	return c.handleUploadFile(p, filePath, sizeToRead, false, stat.Size())
 }
 }
 
 
 func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileInfo) error {
 func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileInfo) error {

+ 4 - 4
sftpd/sftpd_test.go

@@ -3123,7 +3123,7 @@ func TestResolvePaths(t *testing.T) {
 		}
 		}
 		path = "../test/sub"
 		path = "../test/sub"
 		resolved, err = fs.ResolvePath(filepath.ToSlash(path))
 		resolved, err = fs.ResolvePath(filepath.ToSlash(path))
-		if fs.Name() == "osfs" {
+		if vfs.IsLocalOsFs(fs) {
 			if err == nil {
 			if err == nil {
 				t.Errorf("Unexpected resolved path: %v for: %v, fs: %v", resolved, path, fs.Name())
 				t.Errorf("Unexpected resolved path: %v for: %v, fs: %v", resolved, path, fs.Name())
 			}
 			}
@@ -3134,7 +3134,7 @@ func TestResolvePaths(t *testing.T) {
 		}
 		}
 		path = "../../../test/../sub"
 		path = "../../../test/../sub"
 		resolved, err = fs.ResolvePath(filepath.ToSlash(path))
 		resolved, err = fs.ResolvePath(filepath.ToSlash(path))
-		if fs.Name() == "osfs" {
+		if vfs.IsLocalOsFs(fs) {
 			if err == nil {
 			if err == nil {
 				t.Errorf("Unexpected resolved path: %v for: %v, fs: %v", resolved, path, fs.Name())
 				t.Errorf("Unexpected resolved path: %v for: %v, fs: %v", resolved, path, fs.Name())
 			}
 			}
@@ -4624,7 +4624,7 @@ func getKeyboardInteractiveScriptContent(questions []string, sleepTime int, nonJ
 	content := []byte("#!/bin/sh\n\n")
 	content := []byte("#!/bin/sh\n\n")
 	q, _ := json.Marshal(questions)
 	q, _ := json.Marshal(questions)
 	echos := []bool{}
 	echos := []bool{}
-	for index, _ := range questions {
+	for index := range questions {
 		echos = append(echos, index%2 == 0)
 		echos = append(echos, index%2 == 0)
 	}
 	}
 	e, _ := json.Marshal(echos)
 	e, _ := json.Marshal(echos)
@@ -4633,7 +4633,7 @@ func getKeyboardInteractiveScriptContent(questions []string, sleepTime int, nonJ
 	} else {
 	} else {
 		content = append(content, []byte(fmt.Sprintf("echo '{\"questions\":%v,\"echos\":%v}'\n", string(q), string(e)))...)
 		content = append(content, []byte(fmt.Sprintf("echo '{\"questions\":%v,\"echos\":%v}'\n", string(q), string(e)))...)
 	}
 	}
-	for index, _ := range questions {
+	for index := range questions {
 		content = append(content, []byte(fmt.Sprintf("read ANSWER%v\n", index))...)
 		content = append(content, []byte(fmt.Sprintf("read ANSWER%v\n", index))...)
 	}
 	}
 	if sleepTime > 0 {
 	if sleepTime > 0 {

+ 14 - 3
sftpd/transfer.go

@@ -44,6 +44,7 @@ type Transfer struct {
 	isFinished     bool
 	isFinished     bool
 	minWriteOffset int64
 	minWriteOffset int64
 	expectedSize   int64
 	expectedSize   int64
+	initialSize    int64
 	lock           *sync.Mutex
 	lock           *sync.Mutex
 }
 }
 
 
@@ -163,9 +164,7 @@ func (t *Transfer) Close() error {
 	}
 	}
 	metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError)
 	metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError)
 	removeTransfer(t)
 	removeTransfer(t)
-	if t.transferType == transferUpload && (numFiles != 0 || t.bytesReceived > 0) {
-		dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived, false)
-	}
+	t.updateQuota(numFiles)
 	return err
 	return err
 }
 }
 
 
@@ -181,6 +180,18 @@ func (t *Transfer) closeIO() error {
 	return err
 	return err
 }
 }
 
 
+func (t *Transfer) updateQuota(numFiles int) bool {
+	// S3 uploads are atomic, if there is an error nothing is uploaded
+	if t.file == nil && t.transferError != nil {
+		return false
+	}
+	if t.transferType == transferUpload && (numFiles != 0 || t.bytesReceived > 0) {
+		dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived-t.initialSize, false)
+		return true
+	}
+	return false
+}
+
 func (t *Transfer) checkDownloadSize() {
 func (t *Transfer) checkDownloadSize() {
 	if t.transferType == transferDownload && t.transferError == nil && t.bytesSent < t.expectedSize {
 	if t.transferType == transferDownload && t.transferError == nil && t.bytesSent < t.expectedSize {
 		t.transferError = fmt.Errorf("incomplete download: %v/%v bytes transferred", t.bytesSent, t.expectedSize)
 		t.transferError = fmt.Errorf("incomplete download: %v/%v bytes transferred", t.bytesSent, t.expectedSize)