浏览代码

check quota usage between ongoing transfers

Signed-off-by: Nicola Murino <[email protected]>
Nicola Murino 3 年之前
父节点
当前提交
d2a4178846

+ 101 - 30
common/common.go

@@ -53,9 +53,10 @@ const (
 	operationMkdir     = "mkdir"
 	operationRmdir     = "rmdir"
 	// SSH command action name
-	OperationSSHCmd          = "ssh_cmd"
-	chtimesFormat            = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS
-	idleTimeoutCheckInterval = 3 * time.Minute
+	OperationSSHCmd              = "ssh_cmd"
+	chtimesFormat                = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS
+	idleTimeoutCheckInterval     = 3 * time.Minute
+	periodicTimeoutCheckInterval = 1 * time.Minute
 )
 
 // Stat flags
@@ -110,6 +111,7 @@ var (
 	ErrCrtRevoked           = errors.New("your certificate has been revoked")
 	ErrNoCredentials        = errors.New("no credential provided")
 	ErrInternalFailure      = errors.New("internal failure")
+	ErrTransferAborted      = errors.New("transfer aborted")
 	errNoTransfer           = errors.New("requested transfer not found")
 	errTransferMismatch     = errors.New("transfer mismatch")
 )
@@ -120,10 +122,11 @@ var (
 	// Connections is the list of active connections
 	Connections ActiveConnections
 	// QuotaScans is the list of active quota scans
-	QuotaScans            ActiveScans
-	idleTimeoutTicker     *time.Ticker
-	idleTimeoutTickerDone chan bool
-	supportedProtocols    = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
+	QuotaScans                ActiveScans
+	transfersChecker          TransfersChecker
+	periodicTimeoutTicker     *time.Ticker
+	periodicTimeoutTickerDone chan bool
+	supportedProtocols        = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
 		ProtocolHTTP, ProtocolHTTPShare}
 	disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
 	// the map key is the protocol, for each protocol we can have multiple rate limiters
@@ -135,9 +138,7 @@ func Initialize(c Configuration) error {
 	Config = c
 	Config.idleLoginTimeout = 2 * time.Minute
 	Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute
-	if Config.IdleTimeout > 0 {
-		startIdleTimeoutTicker(idleTimeoutCheckInterval)
-	}
+	startPeriodicTimeoutTicker(periodicTimeoutCheckInterval)
 	Config.defender = nil
 	rateLimiters = make(map[string][]*rateLimiter)
 	for _, rlCfg := range c.RateLimitersConfig {
@@ -176,6 +177,7 @@ func Initialize(c Configuration) error {
 	}
 	vfs.SetTempPath(c.TempPath)
 	dataprovider.SetTempPath(c.TempPath)
+	transfersChecker = getTransfersChecker()
 	return nil
 }
 
@@ -267,41 +269,52 @@ func AddDefenderEvent(ip string, event HostEvent) {
 }
 
 // the ticker cannot be started/stopped from multiple goroutines
-func startIdleTimeoutTicker(duration time.Duration) {
-	stopIdleTimeoutTicker()
-	idleTimeoutTicker = time.NewTicker(duration)
-	idleTimeoutTickerDone = make(chan bool)
+func startPeriodicTimeoutTicker(duration time.Duration) {
+	stopPeriodicTimeoutTicker()
+	periodicTimeoutTicker = time.NewTicker(duration)
+	periodicTimeoutTickerDone = make(chan bool)
 	go func() {
+		counter := int64(0)
+		ratio := idleTimeoutCheckInterval / periodicTimeoutCheckInterval
 		for {
 			select {
-			case <-idleTimeoutTickerDone:
+			case <-periodicTimeoutTickerDone:
 				return
-			case <-idleTimeoutTicker.C:
-				Connections.checkIdles()
+			case <-periodicTimeoutTicker.C:
+				counter++
+				if Config.IdleTimeout > 0 && counter >= int64(ratio) {
+					counter = 0
+					Connections.checkIdles()
+				}
+				go Connections.checkTransfers()
 			}
 		}
 	}()
 }
 
-func stopIdleTimeoutTicker() {
-	if idleTimeoutTicker != nil {
-		idleTimeoutTicker.Stop()
-		idleTimeoutTickerDone <- true
-		idleTimeoutTicker = nil
+func stopPeriodicTimeoutTicker() {
+	if periodicTimeoutTicker != nil {
+		periodicTimeoutTicker.Stop()
+		periodicTimeoutTickerDone <- true
+		periodicTimeoutTicker = nil
 	}
 }
 
 // ActiveTransfer defines the interface for the current active transfers
 type ActiveTransfer interface {
-	GetID() uint64
+	GetID() int64
 	GetType() int
 	GetSize() int64
+	GetDownloadedSize() int64
+	GetUploadedSize() int64
 	GetVirtualPath() string
 	GetStartTime() time.Time
-	SignalClose()
+	SignalClose(err error)
 	Truncate(fsPath string, size int64) (int64, error)
 	GetRealFsPath(fsPath string) string
 	SetTimes(fsPath string, atime time.Time, mtime time.Time) bool
+	GetTruncatedSize() int64
+	GetMaxAllowedSize() int64
 }
 
 // ActiveConnection defines the interface for the current active connections
@@ -319,6 +332,7 @@ type ActiveConnection interface {
 	AddTransfer(t ActiveTransfer)
 	RemoveTransfer(t ActiveTransfer)
 	GetTransfers() []ConnectionTransfer
+	SignalTransferClose(transferID int64, err error)
 	CloseFS() error
 }
 
@@ -335,11 +349,14 @@ type StatAttributes struct {
 
 // ConnectionTransfer defines the trasfer details to expose
 type ConnectionTransfer struct {
-	ID            uint64 `json:"-"`
-	OperationType string `json:"operation_type"`
-	StartTime     int64  `json:"start_time"`
-	Size          int64  `json:"size"`
-	VirtualPath   string `json:"path"`
+	ID             int64  `json:"-"`
+	OperationType  string `json:"operation_type"`
+	StartTime      int64  `json:"start_time"`
+	Size           int64  `json:"size"`
+	VirtualPath    string `json:"path"`
+	MaxAllowedSize int64  `json:"-"`
+	ULSize         int64  `json:"-"`
+	DLSize         int64  `json:"-"`
 }
 
 func (t *ConnectionTransfer) getConnectionTransferAsString() string {
@@ -653,7 +670,8 @@ func (c *SSHConnection) Close() error {
 type ActiveConnections struct {
 	// clients contains both authenticated and estabilished connections and the ones waiting
 	// for authentication
-	clients clientsMap
+	clients              clientsMap
+	transfersCheckStatus int32
 	sync.RWMutex
 	connections    []ActiveConnection
 	sshConnections []*SSHConnection
@@ -825,6 +843,59 @@ func (conns *ActiveConnections) checkIdles() {
 	conns.RUnlock()
 }
 
+func (conns *ActiveConnections) checkTransfers() {
+	if atomic.LoadInt32(&conns.transfersCheckStatus) == 1 {
+		logger.Warn(logSender, "", "the previous transfer check is still running, skipping execution")
+		return
+	}
+	atomic.StoreInt32(&conns.transfersCheckStatus, 1)
+	defer atomic.StoreInt32(&conns.transfersCheckStatus, 0)
+
+	var wg sync.WaitGroup
+
+	logger.Debug(logSender, "", "start concurrent transfers check")
+	conns.RLock()
+
+	// update the current size for transfers to monitors
+	for _, c := range conns.connections {
+		for _, t := range c.GetTransfers() {
+			if t.MaxAllowedSize > 0 {
+				wg.Add(1)
+
+				go func(transfer ConnectionTransfer, connID string) {
+					defer wg.Done()
+					transfersChecker.UpdateTransferCurrentSize(transfer.ULSize, transfer.DLSize, transfer.ID, connID)
+				}(t, c.GetID())
+			}
+		}
+	}
+
+	conns.RUnlock()
+	logger.Debug(logSender, "", "waiting for the update of the transfers current size")
+	wg.Wait()
+
+	logger.Debug(logSender, "", "getting overquota transfers")
+	overquotaTransfers := transfersChecker.GetOverquotaTransfers()
+	logger.Debug(logSender, "", "number of overquota transfers: %v", len(overquotaTransfers))
+	if len(overquotaTransfers) == 0 {
+		return
+	}
+
+	conns.RLock()
+	defer conns.RUnlock()
+
+	for _, c := range conns.connections {
+		for _, overquotaTransfer := range overquotaTransfers {
+			if c.GetID() == overquotaTransfer.ConnID {
+				logger.Info(logSender, c.GetID(), "user %#v is overquota, try to close transfer id %v ",
+					c.GetUsername(), overquotaTransfer.TransferID)
+				c.SignalTransferClose(overquotaTransfer.TransferID, getQuotaExceededError(c.GetProtocol()))
+			}
+		}
+	}
+	logger.Debug(logSender, "", "transfers check completed")
+}
+
 // AddClientConnection stores a new client connection
 func (conns *ActiveConnections) AddClientConnection(ipAddr string) {
 	conns.clients.add(ipAddr)

+ 7 - 7
common/common_test.go

@@ -408,19 +408,19 @@ func TestIdleConnections(t *testing.T) {
 	assert.Len(t, Connections.sshConnections, 2)
 	Connections.RUnlock()
 
-	startIdleTimeoutTicker(100 * time.Millisecond)
+	startPeriodicTimeoutTicker(100 * time.Millisecond)
 	assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 1*time.Second, 200*time.Millisecond)
 	assert.Eventually(t, func() bool {
 		Connections.RLock()
 		defer Connections.RUnlock()
 		return len(Connections.sshConnections) == 1
 	}, 1*time.Second, 200*time.Millisecond)
-	stopIdleTimeoutTicker()
+	stopPeriodicTimeoutTicker()
 	assert.Len(t, Connections.GetStats(), 2)
 	c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
 	cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
 	sshConn2.lastActivity = c.lastActivity
-	startIdleTimeoutTicker(100 * time.Millisecond)
+	startPeriodicTimeoutTicker(100 * time.Millisecond)
 	assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond)
 	assert.Eventually(t, func() bool {
 		Connections.RLock()
@@ -428,7 +428,7 @@ func TestIdleConnections(t *testing.T) {
 		return len(Connections.sshConnections) == 0
 	}, 1*time.Second, 200*time.Millisecond)
 	assert.Equal(t, int32(0), Connections.GetClientConnections())
-	stopIdleTimeoutTicker()
+	stopPeriodicTimeoutTicker()
 	assert.True(t, customConn1.isClosed)
 	assert.True(t, customConn2.isClosed)
 
@@ -505,9 +505,9 @@ func TestConnectionStatus(t *testing.T) {
 	fakeConn1 := &fakeConnection{
 		BaseConnection: c1,
 	}
-	t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, true, fs)
+	t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, 0, true, fs)
 	t1.BytesReceived = 123
-	t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs)
+	t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs)
 	t2.BytesSent = 456
 	c2 := NewBaseConnection("id2", ProtocolSSH, "", "", user)
 	fakeConn2 := &fakeConnection{
@@ -519,7 +519,7 @@ func TestConnectionStatus(t *testing.T) {
 		BaseConnection: c3,
 		command:        "PROPFIND",
 	}
-	t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs)
+	t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs)
 	Connections.Add(fakeConn1)
 	Connections.Add(fakeConn2)
 	Connections.Add(fakeConn3)

+ 58 - 12
common/connection.go

@@ -27,7 +27,7 @@ type BaseConnection struct {
 	lastActivity int64
 	// unique ID for a transfer.
 	// This field is accessed atomically so we put it at the beginning of the struct to achieve 64 bit alignment
-	transferID uint64
+	transferID int64
 	// Unique identifier for the connection
 	ID string
 	// user associated with this connection if any
@@ -66,8 +66,8 @@ func (c *BaseConnection) Log(level logger.LogLevel, format string, v ...interfac
 }
 
 // GetTransferID returns an unique transfer ID for this connection
-func (c *BaseConnection) GetTransferID() uint64 {
-	return atomic.AddUint64(&c.transferID, 1)
+func (c *BaseConnection) GetTransferID() int64 {
+	return atomic.AddInt64(&c.transferID, 1)
 }
 
 // GetID returns the connection ID
@@ -125,6 +125,27 @@ func (c *BaseConnection) AddTransfer(t ActiveTransfer) {
 
 	c.activeTransfers = append(c.activeTransfers, t)
 	c.Log(logger.LevelDebug, "transfer added, id: %v, active transfers: %v", t.GetID(), len(c.activeTransfers))
+	if t.GetMaxAllowedSize() > 0 {
+		folderName := ""
+		if t.GetType() == TransferUpload {
+			vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(t.GetVirtualPath()))
+			if err == nil {
+				if !vfolder.IsIncludedInUserQuota() {
+					folderName = vfolder.Name
+				}
+			}
+		}
+		go transfersChecker.AddTransfer(dataprovider.ActiveTransfer{
+			ID:            t.GetID(),
+			Type:          t.GetType(),
+			ConnID:        c.ID,
+			Username:      c.GetUsername(),
+			FolderName:    folderName,
+			TruncatedSize: t.GetTruncatedSize(),
+			CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+			UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+		})
+	}
 }
 
 // RemoveTransfer removes the specified transfer from the active ones
@@ -132,6 +153,10 @@ func (c *BaseConnection) RemoveTransfer(t ActiveTransfer) {
 	c.Lock()
 	defer c.Unlock()
 
+	if t.GetMaxAllowedSize() > 0 {
+		go transfersChecker.RemoveTransfer(t.GetID(), c.ID)
+	}
+
 	for idx, transfer := range c.activeTransfers {
 		if transfer.GetID() == t.GetID() {
 			lastIdx := len(c.activeTransfers) - 1
@@ -145,6 +170,20 @@ func (c *BaseConnection) RemoveTransfer(t ActiveTransfer) {
 	c.Log(logger.LevelWarn, "transfer to remove with id %v not found!", t.GetID())
 }
 
+// SignalTransferClose makes the transfer fail on the next read/write with the
+// specified error
+func (c *BaseConnection) SignalTransferClose(transferID int64, err error) {
+	c.RLock()
+	defer c.RUnlock()
+
+	for _, t := range c.activeTransfers {
+		if t.GetID() == transferID {
+			c.Log(logger.LevelInfo, "signal transfer close for transfer id %v", transferID)
+			t.SignalClose(err)
+		}
+	}
+}
+
 // GetTransfers returns the active transfers
 func (c *BaseConnection) GetTransfers() []ConnectionTransfer {
 	c.RLock()
@@ -160,11 +199,14 @@ func (c *BaseConnection) GetTransfers() []ConnectionTransfer {
 			operationType = operationUpload
 		}
 		transfers = append(transfers, ConnectionTransfer{
-			ID:            t.GetID(),
-			OperationType: operationType,
-			StartTime:     util.GetTimeAsMsSinceEpoch(t.GetStartTime()),
-			Size:          t.GetSize(),
-			VirtualPath:   t.GetVirtualPath(),
+			ID:             t.GetID(),
+			OperationType:  operationType,
+			StartTime:      util.GetTimeAsMsSinceEpoch(t.GetStartTime()),
+			Size:           t.GetSize(),
+			VirtualPath:    t.GetVirtualPath(),
+			MaxAllowedSize: t.GetMaxAllowedSize(),
+			ULSize:         t.GetUploadedSize(),
+			DLSize:         t.GetDownloadedSize(),
 		})
 	}
 
@@ -181,7 +223,7 @@ func (c *BaseConnection) SignalTransfersAbort() error {
 	}
 
 	for _, t := range c.activeTransfers {
-		t.SignalClose()
+		t.SignalClose(ErrTransferAborted)
 	}
 	return nil
 }
@@ -1208,9 +1250,8 @@ func (c *BaseConnection) GetOpUnsupportedError() error {
 	}
 }
 
-// GetQuotaExceededError returns an appropriate storage limit exceeded error for the connection protocol
-func (c *BaseConnection) GetQuotaExceededError() error {
-	switch c.protocol {
+func getQuotaExceededError(protocol string) error {
+	switch protocol {
 	case ProtocolSFTP:
 		return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, ErrQuotaExceeded.Error())
 	case ProtocolFTP:
@@ -1220,6 +1261,11 @@ func (c *BaseConnection) GetQuotaExceededError() error {
 	}
 }
 
+// GetQuotaExceededError returns an appropriate storage limit exceeded error for the connection protocol
+func (c *BaseConnection) GetQuotaExceededError() error {
+	return getQuotaExceededError(c.protocol)
+}
+
 // IsQuotaExceededError returns true if the given error is a quota exceeded error
 func (c *BaseConnection) IsQuotaExceededError(err error) bool {
 	switch c.protocol {

+ 46 - 8
common/transfer.go

@@ -20,7 +20,7 @@ var (
 
 // BaseTransfer contains protocols common transfer details for an upload or a download.
 type BaseTransfer struct { //nolint:maligned
-	ID              uint64
+	ID              int64
 	BytesSent       int64
 	BytesReceived   int64
 	Fs              vfs.Fs
@@ -35,18 +35,21 @@ type BaseTransfer struct { //nolint:maligned
 	MaxWriteSize    int64
 	MinWriteOffset  int64
 	InitialSize     int64
+	truncatedSize   int64
 	isNewFile       bool
 	transferType    int
 	AbortTransfer   int32
 	aTime           time.Time
 	mTime           time.Time
 	sync.Mutex
+	errAbort    error
 	ErrTransfer error
 }
 
 // NewBaseTransfer returns a new BaseTransfer and adds it to the given connection
 func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPath, effectiveFsPath, requestPath string,
-	transferType int, minWriteOffset, initialSize, maxWriteSize int64, isNewFile bool, fs vfs.Fs) *BaseTransfer {
+	transferType int, minWriteOffset, initialSize, maxWriteSize, truncatedSize int64, isNewFile bool, fs vfs.Fs,
+) *BaseTransfer {
 	t := &BaseTransfer{
 		ID:              conn.GetTransferID(),
 		File:            file,
@@ -64,6 +67,7 @@ func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPat
 		BytesReceived:   0,
 		MaxWriteSize:    maxWriteSize,
 		AbortTransfer:   0,
+		truncatedSize:   truncatedSize,
 		Fs:              fs,
 	}
 
@@ -77,7 +81,7 @@ func (t *BaseTransfer) SetFtpMode(mode string) {
 }
 
 // GetID returns the transfer ID
-func (t *BaseTransfer) GetID() uint64 {
+func (t *BaseTransfer) GetID() int64 {
 	return t.ID
 }
 
@@ -94,19 +98,53 @@ func (t *BaseTransfer) GetSize() int64 {
 	return atomic.LoadInt64(&t.BytesReceived)
 }
 
+// GetDownloadedSize returns the transferred size
+func (t *BaseTransfer) GetDownloadedSize() int64 {
+	return atomic.LoadInt64(&t.BytesSent)
+}
+
+// GetUploadedSize returns the transferred size
+func (t *BaseTransfer) GetUploadedSize() int64 {
+	return atomic.LoadInt64(&t.BytesReceived)
+}
+
 // GetStartTime returns the start time
 func (t *BaseTransfer) GetStartTime() time.Time {
 	return t.start
 }
 
-// SignalClose signals that the transfer should be closed.
-// For same protocols, for example WebDAV, we have no
-// access to the network connection, so we use this method
-// to make the next read or write to fail
-func (t *BaseTransfer) SignalClose() {
+// GetAbortError returns the error to send to the client if the transfer was aborted
+func (t *BaseTransfer) GetAbortError() error {
+	t.Lock()
+	defer t.Unlock()
+
+	if t.errAbort != nil {
+		return t.errAbort
+	}
+	return getQuotaExceededError(t.Connection.protocol)
+}
+
+// SignalClose signals that the transfer should be closed after the next read/write.
+// The optional error argument allow to send a specific error, otherwise a generic
+// transfer aborted error is sent
+func (t *BaseTransfer) SignalClose(err error) {
+	t.Lock()
+	t.errAbort = err
+	t.Unlock()
 	atomic.StoreInt32(&(t.AbortTransfer), 1)
 }
 
+// GetTruncatedSize returns the truncated sized if this is an upload overwriting
+// an existing file
+func (t *BaseTransfer) GetTruncatedSize() int64 {
+	return t.truncatedSize
+}
+
+// GetMaxAllowedSize returns the max allowed size
+func (t *BaseTransfer) GetMaxAllowedSize() int64 {
+	return t.MaxWriteSize
+}
+
 // GetVirtualPath returns the transfer virtual path
 func (t *BaseTransfer) GetVirtualPath() string {
 	return t.requestPath

+ 15 - 10
common/transfer_test.go

@@ -65,7 +65,7 @@ func TestTransferThrottling(t *testing.T) {
 	wantedUploadElapsed -= wantedDownloadElapsed / 10
 	wantedDownloadElapsed -= wantedDownloadElapsed / 10
 	conn := NewBaseConnection("id", ProtocolSCP, "", "", u)
-	transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, true, fs)
+	transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, 0, true, fs)
 	transfer.BytesReceived = testFileSize
 	transfer.Connection.UpdateLastActivity()
 	startTime := transfer.Connection.GetLastActivity()
@@ -75,7 +75,7 @@ func TestTransferThrottling(t *testing.T) {
 	err := transfer.Close()
 	assert.NoError(t, err)
 
-	transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, true, fs)
+	transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs)
 	transfer.BytesSent = testFileSize
 	transfer.Connection.UpdateLastActivity()
 	startTime = transfer.Connection.GetLastActivity()
@@ -101,7 +101,8 @@ func TestRealPath(t *testing.T) {
 	file, err := os.Create(testFile)
 	require.NoError(t, err)
 	conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
-	transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs)
+	transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file",
+		TransferUpload, 0, 0, 0, 0, true, fs)
 	rPath := transfer.GetRealFsPath(testFile)
 	assert.Equal(t, testFile, rPath)
 	rPath = conn.getRealFsPath(testFile)
@@ -138,7 +139,8 @@ func TestTruncate(t *testing.T) {
 	_, err = file.Write([]byte("hello"))
 	assert.NoError(t, err)
 	conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
-	transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 5, 100, false, fs)
+	transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 5,
+		100, 0, false, fs)
 
 	err = conn.SetStat("/transfer_test_file", &StatAttributes{
 		Size:  2,
@@ -155,7 +157,8 @@ func TestTruncate(t *testing.T) {
 		assert.Equal(t, int64(2), fi.Size())
 	}
 
-	transfer = NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 100, true, fs)
+	transfer = NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0,
+		100, 0, true, fs)
 	// file.Stat will fail on a closed file
 	err = conn.SetStat("/transfer_test_file", &StatAttributes{
 		Size:  2,
@@ -165,7 +168,7 @@ func TestTruncate(t *testing.T) {
 	err = transfer.Close()
 	assert.NoError(t, err)
 
-	transfer = NewBaseTransfer(nil, conn, nil, testFile, testFile, "", TransferUpload, 0, 0, 0, true, fs)
+	transfer = NewBaseTransfer(nil, conn, nil, testFile, testFile, "", TransferUpload, 0, 0, 0, 0, true, fs)
 	_, err = transfer.Truncate("mismatch", 0)
 	assert.EqualError(t, err, errTransferMismatch.Error())
 	_, err = transfer.Truncate(testFile, 0)
@@ -202,7 +205,8 @@ func TestTransferErrors(t *testing.T) {
 		assert.FailNow(t, "unable to open test file")
 	}
 	conn := NewBaseConnection("id", ProtocolSFTP, "", "", u)
-	transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs)
+	transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload,
+		0, 0, 0, 0, true, fs)
 	assert.Nil(t, transfer.cancelFn)
 	assert.Equal(t, testFile, transfer.GetFsPath())
 	transfer.SetCancelFn(cancelFn)
@@ -228,7 +232,7 @@ func TestTransferErrors(t *testing.T) {
 		assert.FailNow(t, "unable to open test file")
 	}
 	fsPath := filepath.Join(os.TempDir(), "test_file")
-	transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, true, fs)
+	transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, fs)
 	transfer.BytesReceived = 9
 	transfer.TransferError(errFake)
 	assert.Error(t, transfer.ErrTransfer, errFake.Error())
@@ -247,7 +251,7 @@ func TestTransferErrors(t *testing.T) {
 	if !assert.NoError(t, err) {
 		assert.FailNow(t, "unable to open test file")
 	}
-	transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, true, fs)
+	transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, fs)
 	transfer.BytesReceived = 9
 	// the file is closed from the embedding struct before to call close
 	err = file.Close()
@@ -273,7 +277,8 @@ func TestRemovePartialCryptoFile(t *testing.T) {
 		},
 	}
 	conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
-	transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs)
+	transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload,
+		0, 0, 0, 0, true, fs)
 	transfer.ErrTransfer = errors.New("test error")
 	_, err = transfer.getUploadFileSize()
 	assert.Error(t, err)

+ 167 - 0
common/transferschecker.go

@@ -0,0 +1,167 @@
+package common
+
+import (
+	"errors"
+	"sync"
+	"time"
+
+	"github.com/drakkan/sftpgo/v2/dataprovider"
+	"github.com/drakkan/sftpgo/v2/logger"
+	"github.com/drakkan/sftpgo/v2/util"
+)
+
+type overquotaTransfer struct {
+	ConnID     string
+	TransferID int64
+}
+
+// TransfersChecker defines the interface that transfer checkers must implement.
+// A transfer checker ensure that multiple concurrent transfers does not exceeded
+// the remaining user quota
+type TransfersChecker interface {
+	AddTransfer(transfer dataprovider.ActiveTransfer)
+	RemoveTransfer(ID int64, connectionID string)
+	UpdateTransferCurrentSize(ulSize int64, dlSize int64, ID int64, connectionID string)
+	GetOverquotaTransfers() []overquotaTransfer
+}
+
+func getTransfersChecker() TransfersChecker {
+	return &transfersCheckerMem{}
+}
+
+type transfersCheckerMem struct {
+	sync.RWMutex
+	transfers []dataprovider.ActiveTransfer
+}
+
+func (t *transfersCheckerMem) AddTransfer(transfer dataprovider.ActiveTransfer) {
+	t.Lock()
+	defer t.Unlock()
+
+	t.transfers = append(t.transfers, transfer)
+}
+
+func (t *transfersCheckerMem) RemoveTransfer(ID int64, connectionID string) {
+	t.Lock()
+	defer t.Unlock()
+
+	for idx, transfer := range t.transfers {
+		if transfer.ID == ID && transfer.ConnID == connectionID {
+			lastIdx := len(t.transfers) - 1
+			t.transfers[idx] = t.transfers[lastIdx]
+			t.transfers = t.transfers[:lastIdx]
+			return
+		}
+	}
+}
+
+func (t *transfersCheckerMem) UpdateTransferCurrentSize(ulSize int64, dlSize int64, ID int64, connectionID string) {
+	t.Lock()
+	defer t.Unlock()
+
+	for idx := range t.transfers {
+		if t.transfers[idx].ID == ID && t.transfers[idx].ConnID == connectionID {
+			t.transfers[idx].CurrentDLSize = dlSize
+			t.transfers[idx].CurrentULSize = ulSize
+			t.transfers[idx].UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
+			return
+		}
+	}
+}
+
+func (t *transfersCheckerMem) getRemainingDiskQuota(user dataprovider.User, folderName string) (int64, error) {
+	var result int64
+
+	if folderName != "" {
+		for _, folder := range user.VirtualFolders {
+			if folder.Name == folderName {
+				if folder.QuotaSize > 0 {
+					return folder.QuotaSize - folder.UsedQuotaSize, nil
+				}
+			}
+		}
+	} else {
+		if user.QuotaSize > 0 {
+			return user.QuotaSize - user.UsedQuotaSize, nil
+		}
+	}
+
+	return result, errors.New("no quota limit defined")
+}
+
+func (t *transfersCheckerMem) aggregateTransfers() (map[string]bool, map[string][]dataprovider.ActiveTransfer) {
+	t.RLock()
+	defer t.RUnlock()
+
+	usersToFetch := make(map[string]bool)
+	aggregations := make(map[string][]dataprovider.ActiveTransfer)
+	for _, transfer := range t.transfers {
+		key := transfer.GetKey()
+		aggregations[key] = append(aggregations[key], transfer)
+		if len(aggregations[key]) > 1 {
+			if transfer.FolderName != "" {
+				usersToFetch[transfer.Username] = true
+			} else {
+				if _, ok := usersToFetch[transfer.Username]; !ok {
+					usersToFetch[transfer.Username] = false
+				}
+			}
+		}
+	}
+
+	return usersToFetch, aggregations
+}
+
+func (t *transfersCheckerMem) GetOverquotaTransfers() []overquotaTransfer {
+	usersToFetch, aggregations := t.aggregateTransfers()
+
+	if len(usersToFetch) == 0 {
+		return nil
+	}
+
+	users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch)
+	if err != nil {
+		logger.Warn(logSender, "", "unable to check transfers, error getting users quota: %v", err)
+		return nil
+	}
+
+	usersMap := make(map[string]dataprovider.User)
+
+	for _, user := range users {
+		usersMap[user.Username] = user
+	}
+
+	var overquotaTransfers []overquotaTransfer
+
+	for _, transfers := range aggregations {
+		if len(transfers) > 1 {
+			username := transfers[0].Username
+			folderName := transfers[0].FolderName
+			// transfer type is always upload for now
+			remaningDiskQuota, err := t.getRemainingDiskQuota(usersMap[username], folderName)
+			if err != nil {
+				continue
+			}
+			var usedDiskQuota int64
+			for _, tr := range transfers {
+				// We optimistically assume that a cloud transfer that replaces an existing
+				// file will be successful
+				usedDiskQuota += tr.CurrentULSize - tr.TruncatedSize
+			}
+			logger.Debug(logSender, "", "username %#v, folder %#v, concurrent transfers: %v, remaining disk quota: %v, disk quota used in ongoing transfers: %v",
+				username, folderName, len(transfers), remaningDiskQuota, usedDiskQuota)
+			if usedDiskQuota > remaningDiskQuota {
+				for _, tr := range transfers {
+					if tr.CurrentULSize > tr.TruncatedSize {
+						overquotaTransfers = append(overquotaTransfers, overquotaTransfer{
+							ConnID:     tr.ConnID,
+							TransferID: tr.ID,
+						})
+					}
+				}
+			}
+		}
+	}
+
+	return overquotaTransfers
+}

+ 449 - 0
common/transferschecker_test.go

@@ -0,0 +1,449 @@
+package common
+
+import (
+	"fmt"
+	"os"
+	"path"
+	"path/filepath"
+	"strconv"
+	"strings"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/rs/xid"
+	"github.com/sftpgo/sdk"
+	"github.com/stretchr/testify/assert"
+
+	"github.com/drakkan/sftpgo/v2/dataprovider"
+	"github.com/drakkan/sftpgo/v2/util"
+	"github.com/drakkan/sftpgo/v2/vfs"
+)
+
+func TestTransfersCheckerDiskQuota(t *testing.T) {
+	username := "transfers_check_username"
+	folderName := "test_transfers_folder"
+	vdirPath := "/vdir"
+	user := dataprovider.User{
+		BaseUser: sdk.BaseUser{
+			Username:  username,
+			Password:  "testpwd",
+			HomeDir:   filepath.Join(os.TempDir(), username),
+			Status:    1,
+			QuotaSize: 120,
+			Permissions: map[string][]string{
+				"/": {dataprovider.PermAny},
+			},
+		},
+		VirtualFolders: []vfs.VirtualFolder{
+			{
+				BaseVirtualFolder: vfs.BaseVirtualFolder{
+					Name:       folderName,
+					MappedPath: filepath.Join(os.TempDir(), folderName),
+				},
+				VirtualPath: vdirPath,
+				QuotaSize:   100,
+			},
+		},
+	}
+
+	err := dataprovider.AddUser(&user, "", "")
+	assert.NoError(t, err)
+	user, err = dataprovider.UserExists(username)
+	assert.NoError(t, err)
+
+	connID1 := xid.New().String()
+	fsUser, err := user.GetFilesystemForPath("/file1", connID1)
+	assert.NoError(t, err)
+	conn1 := NewBaseConnection(connID1, ProtocolSFTP, "", "", user)
+	fakeConn1 := &fakeConnection{
+		BaseConnection: conn1,
+	}
+	transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
+		"/file1", TransferUpload, 0, 0, 120, 0, true, fsUser)
+	transfer1.BytesReceived = 150
+	Connections.Add(fakeConn1)
+	// the transferschecker will do nothing if there is only one ongoing transfer
+	Connections.checkTransfers()
+	assert.Nil(t, transfer1.errAbort)
+
+	connID2 := xid.New().String()
+	conn2 := NewBaseConnection(connID2, ProtocolSFTP, "", "", user)
+	fakeConn2 := &fakeConnection{
+		BaseConnection: conn2,
+	}
+	transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
+		"/file2", TransferUpload, 0, 0, 120, 40, true, fsUser)
+	transfer1.BytesReceived = 50
+	transfer2.BytesReceived = 60
+	Connections.Add(fakeConn2)
+
+	connID3 := xid.New().String()
+	conn3 := NewBaseConnection(connID3, ProtocolSFTP, "", "", user)
+	fakeConn3 := &fakeConnection{
+		BaseConnection: conn3,
+	}
+	transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"),
+		"/file3", TransferDownload, 0, 0, 120, 0, true, fsUser)
+	transfer3.BytesReceived = 60 // this value will be ignored, this is a download
+	Connections.Add(fakeConn3)
+
+	// the transfers are not overquota
+	Connections.checkTransfers()
+	assert.Nil(t, transfer1.errAbort)
+	assert.Nil(t, transfer2.errAbort)
+	assert.Nil(t, transfer3.errAbort)
+
+	transfer1.BytesReceived = 80 // truncated size will be subtracted, we are not overquota
+	Connections.checkTransfers()
+	assert.Nil(t, transfer1.errAbort)
+	assert.Nil(t, transfer2.errAbort)
+	assert.Nil(t, transfer3.errAbort)
+	transfer1.BytesReceived = 120
+	// we are now overquota
+	// if another check is in progress nothing is done
+	atomic.StoreInt32(&Connections.transfersCheckStatus, 1)
+	Connections.checkTransfers()
+	assert.Nil(t, transfer1.errAbort)
+	assert.Nil(t, transfer2.errAbort)
+	assert.Nil(t, transfer3.errAbort)
+	atomic.StoreInt32(&Connections.transfersCheckStatus, 0)
+
+	Connections.checkTransfers()
+	assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort))
+	assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort))
+	assert.True(t, conn1.IsQuotaExceededError(transfer1.GetAbortError()))
+	assert.Nil(t, transfer3.errAbort)
+	assert.True(t, conn3.IsQuotaExceededError(transfer3.GetAbortError()))
+	// update the user quota size
+	user.QuotaSize = 1000
+	err = dataprovider.UpdateUser(&user, "", "")
+	assert.NoError(t, err)
+	transfer1.errAbort = nil
+	transfer2.errAbort = nil
+	Connections.checkTransfers()
+	assert.Nil(t, transfer1.errAbort)
+	assert.Nil(t, transfer2.errAbort)
+	assert.Nil(t, transfer3.errAbort)
+
+	user.QuotaSize = 0
+	err = dataprovider.UpdateUser(&user, "", "")
+	assert.NoError(t, err)
+	Connections.checkTransfers()
+	assert.Nil(t, transfer1.errAbort)
+	assert.Nil(t, transfer2.errAbort)
+	assert.Nil(t, transfer3.errAbort)
+	// now check a public folder
+	transfer1.BytesReceived = 0
+	transfer2.BytesReceived = 0
+	connID4 := xid.New().String()
+	fsFolder, err := user.GetFilesystemForPath(path.Join(vdirPath, "/file1"), connID4)
+	assert.NoError(t, err)
+	conn4 := NewBaseConnection(connID4, ProtocolSFTP, "", "", user)
+	fakeConn4 := &fakeConnection{
+		BaseConnection: conn4,
+	}
+	transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(os.TempDir(), folderName, "file1"),
+		filepath.Join(os.TempDir(), folderName, "file1"), path.Join(vdirPath, "/file1"), TransferUpload, 0, 0,
+		100, 0, true, fsFolder)
+	Connections.Add(fakeConn4)
+	connID5 := xid.New().String()
+	conn5 := NewBaseConnection(connID5, ProtocolSFTP, "", "", user)
+	fakeConn5 := &fakeConnection{
+		BaseConnection: conn5,
+	}
+	transfer5 := NewBaseTransfer(nil, conn5, nil, filepath.Join(os.TempDir(), folderName, "file2"),
+		filepath.Join(os.TempDir(), folderName, "file2"), path.Join(vdirPath, "/file2"), TransferUpload, 0, 0,
+		100, 0, true, fsFolder)
+
+	Connections.Add(fakeConn5)
+	transfer4.BytesReceived = 50
+	transfer5.BytesReceived = 40
+	Connections.checkTransfers()
+	assert.Nil(t, transfer4.errAbort)
+	assert.Nil(t, transfer5.errAbort)
+	transfer5.BytesReceived = 60
+	Connections.checkTransfers()
+	assert.Nil(t, transfer1.errAbort)
+	assert.Nil(t, transfer2.errAbort)
+	assert.Nil(t, transfer3.errAbort)
+	assert.True(t, conn1.IsQuotaExceededError(transfer4.errAbort))
+	assert.True(t, conn2.IsQuotaExceededError(transfer5.errAbort))
+
+	if dataprovider.GetProviderStatus().Driver != dataprovider.MemoryDataProviderName {
+		providerConf := dataprovider.GetProviderConfig()
+		err = dataprovider.Close()
+		assert.NoError(t, err)
+
+		transfer4.errAbort = nil
+		transfer5.errAbort = nil
+		Connections.checkTransfers()
+		assert.Nil(t, transfer1.errAbort)
+		assert.Nil(t, transfer2.errAbort)
+		assert.Nil(t, transfer3.errAbort)
+		assert.Nil(t, transfer4.errAbort)
+		assert.Nil(t, transfer5.errAbort)
+
+		err = dataprovider.Initialize(providerConf, configDir, true)
+		assert.NoError(t, err)
+	}
+
+	Connections.Remove(fakeConn1.GetID())
+	Connections.Remove(fakeConn2.GetID())
+	Connections.Remove(fakeConn3.GetID())
+	Connections.Remove(fakeConn4.GetID())
+	Connections.Remove(fakeConn5.GetID())
+	stats := Connections.GetStats()
+	assert.Len(t, stats, 0)
+
+	err = dataprovider.DeleteUser(user.Username, "", "")
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+
+	err = dataprovider.DeleteFolder(folderName, "", "")
+	assert.NoError(t, err)
+	err = os.RemoveAll(filepath.Join(os.TempDir(), folderName))
+	assert.NoError(t, err)
+}
+
+func TestAggregateTransfers(t *testing.T) {
+	checker := transfersCheckerMem{}
+	checker.AddTransfer(dataprovider.ActiveTransfer{
+		ID:            1,
+		Type:          TransferUpload,
+		ConnID:        "1",
+		Username:      "user",
+		FolderName:    "",
+		TruncatedSize: 0,
+		CurrentULSize: 100,
+		CurrentDLSize: 0,
+		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+	})
+	usersToFetch, aggregations := checker.aggregateTransfers()
+	assert.Len(t, usersToFetch, 0)
+	assert.Len(t, aggregations, 1)
+
+	checker.AddTransfer(dataprovider.ActiveTransfer{
+		ID:            1,
+		Type:          TransferDownload,
+		ConnID:        "2",
+		Username:      "user",
+		FolderName:    "",
+		TruncatedSize: 0,
+		CurrentULSize: 0,
+		CurrentDLSize: 100,
+		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+	})
+
+	usersToFetch, aggregations = checker.aggregateTransfers()
+	assert.Len(t, usersToFetch, 0)
+	assert.Len(t, aggregations, 2)
+
+	checker.AddTransfer(dataprovider.ActiveTransfer{
+		ID:            1,
+		Type:          TransferUpload,
+		ConnID:        "3",
+		Username:      "user",
+		FolderName:    "folder",
+		TruncatedSize: 0,
+		CurrentULSize: 10,
+		CurrentDLSize: 0,
+		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+	})
+
+	usersToFetch, aggregations = checker.aggregateTransfers()
+	assert.Len(t, usersToFetch, 0)
+	assert.Len(t, aggregations, 3)
+
+	checker.AddTransfer(dataprovider.ActiveTransfer{
+		ID:            1,
+		Type:          TransferUpload,
+		ConnID:        "4",
+		Username:      "user1",
+		FolderName:    "",
+		TruncatedSize: 0,
+		CurrentULSize: 100,
+		CurrentDLSize: 0,
+		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+	})
+
+	usersToFetch, aggregations = checker.aggregateTransfers()
+	assert.Len(t, usersToFetch, 0)
+	assert.Len(t, aggregations, 4)
+
+	checker.AddTransfer(dataprovider.ActiveTransfer{
+		ID:            1,
+		Type:          TransferUpload,
+		ConnID:        "5",
+		Username:      "user",
+		FolderName:    "",
+		TruncatedSize: 0,
+		CurrentULSize: 100,
+		CurrentDLSize: 0,
+		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+	})
+
+	usersToFetch, aggregations = checker.aggregateTransfers()
+	assert.Len(t, usersToFetch, 1)
+	val, ok := usersToFetch["user"]
+	assert.True(t, ok)
+	assert.False(t, val)
+	assert.Len(t, aggregations, 4)
+	aggregate, ok := aggregations["user0"]
+	assert.True(t, ok)
+	assert.Len(t, aggregate, 2)
+
+	checker.AddTransfer(dataprovider.ActiveTransfer{
+		ID:            1,
+		Type:          TransferUpload,
+		ConnID:        "6",
+		Username:      "user",
+		FolderName:    "",
+		TruncatedSize: 0,
+		CurrentULSize: 100,
+		CurrentDLSize: 0,
+		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+	})
+
+	usersToFetch, aggregations = checker.aggregateTransfers()
+	assert.Len(t, usersToFetch, 1)
+	val, ok = usersToFetch["user"]
+	assert.True(t, ok)
+	assert.False(t, val)
+	assert.Len(t, aggregations, 4)
+	aggregate, ok = aggregations["user0"]
+	assert.True(t, ok)
+	assert.Len(t, aggregate, 3)
+
+	checker.AddTransfer(dataprovider.ActiveTransfer{
+		ID:            1,
+		Type:          TransferUpload,
+		ConnID:        "7",
+		Username:      "user",
+		FolderName:    "folder",
+		TruncatedSize: 0,
+		CurrentULSize: 10,
+		CurrentDLSize: 0,
+		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+	})
+
+	usersToFetch, aggregations = checker.aggregateTransfers()
+	assert.Len(t, usersToFetch, 1)
+	val, ok = usersToFetch["user"]
+	assert.True(t, ok)
+	assert.True(t, val)
+	assert.Len(t, aggregations, 4)
+	aggregate, ok = aggregations["user0"]
+	assert.True(t, ok)
+	assert.Len(t, aggregate, 3)
+	aggregate, ok = aggregations["userfolder0"]
+	assert.True(t, ok)
+	assert.Len(t, aggregate, 2)
+
+	checker.AddTransfer(dataprovider.ActiveTransfer{
+		ID:            1,
+		Type:          TransferUpload,
+		ConnID:        "8",
+		Username:      "user",
+		FolderName:    "",
+		TruncatedSize: 0,
+		CurrentULSize: 100,
+		CurrentDLSize: 0,
+		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
+	})
+
+	usersToFetch, aggregations = checker.aggregateTransfers()
+	assert.Len(t, usersToFetch, 1)
+	val, ok = usersToFetch["user"]
+	assert.True(t, ok)
+	assert.True(t, val)
+	assert.Len(t, aggregations, 4)
+	aggregate, ok = aggregations["user0"]
+	assert.True(t, ok)
+	assert.Len(t, aggregate, 4)
+	aggregate, ok = aggregations["userfolder0"]
+	assert.True(t, ok)
+	assert.Len(t, aggregate, 2)
+}
+
+func TestGetUsersForQuotaCheck(t *testing.T) {
+	usersToFetch := make(map[string]bool)
+	for i := 0; i < 50; i++ {
+		usersToFetch[fmt.Sprintf("user%v", i)] = i%2 == 0
+	}
+
+	users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch)
+	assert.NoError(t, err)
+	assert.Len(t, users, 0)
+
+	for i := 0; i < 40; i++ {
+		user := dataprovider.User{
+			BaseUser: sdk.BaseUser{
+				Username:  fmt.Sprintf("user%v", i),
+				Password:  "pwd",
+				HomeDir:   filepath.Join(os.TempDir(), fmt.Sprintf("user%v", i)),
+				Status:    1,
+				QuotaSize: 120,
+				Permissions: map[string][]string{
+					"/": {dataprovider.PermAny},
+				},
+			},
+			VirtualFolders: []vfs.VirtualFolder{
+				{
+					BaseVirtualFolder: vfs.BaseVirtualFolder{
+						Name:       fmt.Sprintf("f%v", i),
+						MappedPath: filepath.Join(os.TempDir(), fmt.Sprintf("f%v", i)),
+					},
+					VirtualPath: "/vfolder",
+					QuotaSize:   100,
+				},
+			},
+		}
+		err = dataprovider.AddUser(&user, "", "")
+		assert.NoError(t, err)
+		err = dataprovider.UpdateVirtualFolderQuota(&vfs.BaseVirtualFolder{Name: fmt.Sprintf("f%v", i)}, 1, 50, false)
+		assert.NoError(t, err)
+	}
+
+	users, err = dataprovider.GetUsersForQuotaCheck(usersToFetch)
+	assert.NoError(t, err)
+	assert.Len(t, users, 40)
+
+	for _, user := range users {
+		userIdxStr := strings.Replace(user.Username, "user", "", 1)
+		userIdx, err := strconv.Atoi(userIdxStr)
+		assert.NoError(t, err)
+		if userIdx%2 == 0 {
+			if assert.Len(t, user.VirtualFolders, 1, user.Username) {
+				assert.Equal(t, int64(100), user.VirtualFolders[0].QuotaSize)
+				assert.Equal(t, int64(50), user.VirtualFolders[0].UsedQuotaSize)
+			}
+		} else {
+			switch dataprovider.GetProviderStatus().Driver {
+			case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
+				dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
+				assert.Len(t, user.VirtualFolders, 0, user.Username)
+			}
+		}
+	}
+
+	for i := 0; i < 40; i++ {
+		err = dataprovider.DeleteUser(fmt.Sprintf("user%v", i), "", "")
+		assert.NoError(t, err)
+		err = dataprovider.DeleteFolder(fmt.Sprintf("f%v", i), "", "")
+		assert.NoError(t, err)
+	}
+
+	users, err = dataprovider.GetUsersForQuotaCheck(usersToFetch)
+	assert.NoError(t, err)
+	assert.Len(t, users, 0)
+}

+ 47 - 0
dataprovider/bolt.go

@@ -647,6 +647,53 @@ func (p *BoltProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) {
 	return nil, nil
 }
 
+func (p *BoltProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
+	users := make([]User, 0, 30)
+
+	err := p.dbHandle.View(func(tx *bolt.Tx) error {
+		bucket, err := getUsersBucket(tx)
+		if err != nil {
+			return err
+		}
+		foldersBucket, err := getFoldersBucket(tx)
+		if err != nil {
+			return err
+		}
+		cursor := bucket.Cursor()
+		for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
+			var user User
+			err := json.Unmarshal(v, &user)
+			if err != nil {
+				return err
+			}
+			needFolders, ok := toFetch[user.Username]
+			if !ok {
+				continue
+			}
+			if needFolders && len(user.VirtualFolders) > 0 {
+				var folders []vfs.VirtualFolder
+				for idx := range user.VirtualFolders {
+					folder := &user.VirtualFolders[idx]
+					baseFolder, err := folderExistsInternal(folder.Name, foldersBucket)
+					if err != nil {
+						continue
+					}
+					folder.BaseVirtualFolder = baseFolder
+					folders = append(folders, *folder)
+				}
+				user.VirtualFolders = folders
+			}
+
+			user.SetEmptySecretsIfNil()
+			user.PrepareForRendering()
+			users = append(users, user)
+		}
+		return nil
+	})
+
+	return users, err
+}
+
 func (p *BoltProvider) getUsers(limit int, offset int, order string) ([]User, error) {
 	users := make([]User, 0, limit)
 	var err error

+ 26 - 0
dataprovider/dataprovider.go

@@ -381,6 +381,26 @@ func (c *Config) IsDefenderSupported() bool {
 	}
 }
 
+// ActiveTransfer defines an active protocol transfer
+type ActiveTransfer struct {
+	ID            int64
+	Type          int
+	ConnID        string
+	Username      string
+	FolderName    string
+	TruncatedSize int64
+	CurrentULSize int64
+	CurrentDLSize int64
+	CreatedAt     int64
+	UpdatedAt     int64
+}
+
+// GetKey returns an aggregation key.
+// The same key will be returned for similar transfers
+func (t *ActiveTransfer) GetKey() string {
+	return fmt.Sprintf("%v%v%v", t.Username, t.FolderName, t.Type)
+}
+
 // DefenderEntry defines a defender entry
 type DefenderEntry struct {
 	ID      int64     `json:"-"`
@@ -476,6 +496,7 @@ type Provider interface {
 	getUsers(limit int, offset int, order string) ([]User, error)
 	dumpUsers() ([]User, error)
 	getRecentlyUpdatedUsers(after int64) ([]User, error)
+	getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error)
 	updateLastLogin(username string) error
 	updateAdminLastLogin(username string) error
 	setUpdatedAt(username string)
@@ -1268,6 +1289,11 @@ func GetUsers(limit, offset int, order string) ([]User, error) {
 	return provider.getUsers(limit, offset, order)
 }
 
+// GetUsersForQuotaCheck returns the users with the fields required for a quota check
+func GetUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
+	return provider.getUsersForQuotaCheck(toFetch)
+}
+
 // AddFolder adds a new virtual folder.
 func AddFolder(folder *vfs.BaseVirtualFolder) error {
 	return provider.addFolder(folder)

+ 49 - 2
dataprovider/memory.go

@@ -349,6 +349,7 @@ func (p *MemoryProvider) dumpUsers() ([]User, error) {
 	for _, username := range p.dbHandle.usernames {
 		u := p.dbHandle.users[username]
 		user := u.getACopy()
+		p.addVirtualFoldersToUser(&user)
 		err = addCredentialsToUser(&user)
 		if err != nil {
 			return users, err
@@ -376,6 +377,28 @@ func (p *MemoryProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) {
 	return nil, nil
 }
 
+func (p *MemoryProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
+	users := make([]User, 0, 30)
+	p.dbHandle.Lock()
+	defer p.dbHandle.Unlock()
+	if p.dbHandle.isClosed {
+		return users, errMemoryProviderClosed
+	}
+	for _, username := range p.dbHandle.usernames {
+		if val, ok := toFetch[username]; ok {
+			u := p.dbHandle.users[username]
+			user := u.getACopy()
+			if val {
+				p.addVirtualFoldersToUser(&user)
+			}
+			user.PrepareForRendering()
+			users = append(users, user)
+		}
+	}
+
+	return users, nil
+}
+
 func (p *MemoryProvider) getUsers(limit int, offset int, order string) ([]User, error) {
 	users := make([]User, 0, limit)
 	var err error
@@ -396,6 +419,7 @@ func (p *MemoryProvider) getUsers(limit int, offset int, order string) ([]User,
 			}
 			u := p.dbHandle.users[username]
 			user := u.getACopy()
+			p.addVirtualFoldersToUser(&user)
 			user.PrepareForRendering()
 			users = append(users, user)
 			if len(users) >= limit {
@@ -411,6 +435,7 @@ func (p *MemoryProvider) getUsers(limit int, offset int, order string) ([]User,
 			username := p.dbHandle.usernames[i]
 			u := p.dbHandle.users[username]
 			user := u.getACopy()
+			p.addVirtualFoldersToUser(&user)
 			user.PrepareForRendering()
 			users = append(users, user)
 			if len(users) >= limit {
@@ -427,7 +452,12 @@ func (p *MemoryProvider) userExists(username string) (User, error) {
 	if p.dbHandle.isClosed {
 		return User{}, errMemoryProviderClosed
 	}
-	return p.userExistsInternal(username)
+	user, err := p.userExistsInternal(username)
+	if err != nil {
+		return user, err
+	}
+	p.addVirtualFoldersToUser(&user)
+	return user, nil
 }
 
 func (p *MemoryProvider) userExistsInternal(username string) (User, error) {
@@ -632,6 +662,22 @@ func (p *MemoryProvider) joinVirtualFoldersFields(user *User) []vfs.VirtualFolde
 	return folders
 }
 
+func (p *MemoryProvider) addVirtualFoldersToUser(user *User) {
+	if len(user.VirtualFolders) > 0 {
+		var folders []vfs.VirtualFolder
+		for idx := range user.VirtualFolders {
+			folder := &user.VirtualFolders[idx]
+			baseFolder, err := p.folderExistsInternal(folder.Name)
+			if err != nil {
+				continue
+			}
+			folder.BaseVirtualFolder = baseFolder.GetACopy()
+			folders = append(folders, *folder)
+		}
+		user.VirtualFolders = folders
+	}
+}
+
 func (p *MemoryProvider) removeUserFromFolderMapping(folderName, username string) {
 	folder, err := p.folderExistsInternal(folderName)
 	if err == nil {
@@ -655,7 +701,8 @@ func (p *MemoryProvider) updateFoldersMappingInternal(folder vfs.BaseVirtualFold
 }
 
 func (p *MemoryProvider) addOrUpdateFolderInternal(baseFolder *vfs.BaseVirtualFolder, username string, usedQuotaSize int64,
-	usedQuotaFiles int, lastQuotaUpdate int64) (vfs.BaseVirtualFolder, error) {
+	usedQuotaFiles int, lastQuotaUpdate int64) (vfs.BaseVirtualFolder, error,
+) {
 	folder, err := p.folderExistsInternal(baseFolder.Name)
 	if err == nil {
 		// exists

+ 4 - 0
dataprovider/mysql.go

@@ -186,6 +186,10 @@ func (p *MySQLProvider) getUsers(limit int, offset int, order string) ([]User, e
 	return sqlCommonGetUsers(limit, offset, order, p.dbHandle)
 }
 
+func (p *MySQLProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
+	return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle)
+}
+
 func (p *MySQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
 	return sqlCommonDumpFolders(p.dbHandle)
 }

+ 4 - 0
dataprovider/pgsql.go

@@ -198,6 +198,10 @@ func (p *PGSQLProvider) getUsers(limit int, offset int, order string) ([]User, e
 	return sqlCommonGetUsers(limit, offset, order, p.dbHandle)
 }
 
+func (p *PGSQLProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
+	return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle)
+}
+
 func (p *PGSQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
 	return sqlCommonDumpFolders(p.dbHandle)
 }

+ 84 - 0
dataprovider/sqlcommon.go

@@ -939,6 +939,90 @@ func sqlCommonGetRecentlyUpdatedUsers(after int64, dbHandle sqlQuerier) ([]User,
 	return getUsersWithVirtualFolders(ctx, users, dbHandle)
 }
 
+func sqlCommonGetUsersForQuotaCheck(toFetch map[string]bool, dbHandle sqlQuerier) ([]User, error) {
+	users := make([]User, 0, 30)
+
+	usernames := make([]string, 0, len(toFetch))
+	for k := range toFetch {
+		usernames = append(usernames, k)
+	}
+
+	maxUsers := 30
+	for len(usernames) > 0 {
+		if maxUsers > len(usernames) {
+			maxUsers = len(usernames)
+		}
+		usersRange, err := sqlCommonGetUsersRangeForQuotaCheck(usernames[:maxUsers], dbHandle)
+		if err != nil {
+			return users, err
+		}
+		users = append(users, usersRange...)
+		usernames = usernames[maxUsers:]
+	}
+
+	var usersWithFolders []User
+
+	validIdx := 0
+	for _, user := range users {
+		if toFetch[user.Username] {
+			usersWithFolders = append(usersWithFolders, user)
+		} else {
+			users[validIdx] = user
+			validIdx++
+		}
+	}
+	users = users[:validIdx]
+	if len(usersWithFolders) == 0 {
+		return users, nil
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
+	defer cancel()
+
+	usersWithFolders, err := getUsersWithVirtualFolders(ctx, usersWithFolders, dbHandle)
+	if err != nil {
+		return users, err
+	}
+	users = append(users, usersWithFolders...)
+	return users, nil
+}
+
+func sqlCommonGetUsersRangeForQuotaCheck(usernames []string, dbHandle sqlQuerier) ([]User, error) {
+	users := make([]User, 0, len(usernames))
+	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
+	defer cancel()
+
+	q := getUsersForQuotaCheckQuery(len(usernames))
+	stmt, err := dbHandle.PrepareContext(ctx, q)
+	if err != nil {
+		providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
+		return users, err
+	}
+	defer stmt.Close()
+
+	queryArgs := make([]interface{}, 0, len(usernames))
+	for idx := range usernames {
+		queryArgs = append(queryArgs, usernames[idx])
+	}
+
+	rows, err := stmt.QueryContext(ctx, queryArgs...)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	for rows.Next() {
+		var user User
+		err = rows.Scan(&user.ID, &user.Username, &user.QuotaSize, &user.UsedQuotaSize)
+		if err != nil {
+			return users, err
+		}
+		users = append(users, user)
+	}
+
+	return users, rows.Err()
+}
+
 func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier) ([]User, error) {
 	users := make([]User, 0, limit)
 	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)

+ 4 - 0
dataprovider/sqlite.go

@@ -183,6 +183,10 @@ func (p *SQLiteProvider) getUsers(limit int, offset int, order string) ([]User,
 	return sqlCommonGetUsers(limit, offset, order, p.dbHandle)
 }
 
+func (p *SQLiteProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
+	return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle)
+}
+
 func (p *SQLiteProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
 	return sqlCommonDumpFolders(p.dbHandle)
 }

+ 18 - 1
dataprovider/sqlqueries.go

@@ -21,7 +21,7 @@ const (
 
 func getSQLPlaceholders() []string {
 	var placeholders []string
-	for i := 1; i <= 30; i++ {
+	for i := 1; i <= 50; i++ {
 		if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName {
 			placeholders = append(placeholders, fmt.Sprintf("$%v", i))
 		} else {
@@ -263,6 +263,23 @@ func getUsersQuery(order string) string {
 		order, sqlPlaceholders[0], sqlPlaceholders[1])
 }
 
+func getUsersForQuotaCheckQuery(numArgs int) string {
+	var sb strings.Builder
+	for idx := 0; idx < numArgs; idx++ {
+		if sb.Len() == 0 {
+			sb.WriteString("(")
+		} else {
+			sb.WriteString(",")
+		}
+		sb.WriteString(sqlPlaceholders[idx])
+	}
+	if sb.Len() > 0 {
+		sb.WriteString(")")
+	}
+	return fmt.Sprintf(`SELECT id,username,quota_size,used_quota_size FROM %v WHERE username IN %v`,
+		sqlTableUsers, sb.String())
+}
+
 func getRecentlyUpdatedUsersQuery() string {
 	return fmt.Sprintf(`SELECT %v FROM %v WHERE updated_at >= %v`, selectUserFields, sqlTableUsers, sqlPlaceholders[0])
 }

+ 6 - 4
ftpd/handler.go

@@ -335,8 +335,8 @@ func (c *Connection) downloadFile(fs vfs.Fs, fsPath, ftpPath string, offset int6
 		return nil, c.GetFsError(fs, err)
 	}
 
-	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, ftpPath, common.TransferDownload,
-		0, 0, 0, false, fs)
+	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, ftpPath,
+		common.TransferDownload, 0, 0, 0, 0, false, fs)
 	baseTransfer.SetFtpMode(c.getFTPMode())
 	t := newTransfer(baseTransfer, nil, r, offset)
 
@@ -402,7 +402,7 @@ func (c *Connection) handleFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath,
 	maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
 
 	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
-		common.TransferUpload, 0, 0, maxWriteSize, true, fs)
+		common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs)
 	baseTransfer.SetFtpMode(c.getFTPMode())
 	t := newTransfer(baseTransfer, w, nil, 0)
 
@@ -452,6 +452,7 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve
 	}
 
 	initialSize := int64(0)
+	truncatedSize := int64(0) // bytes truncated and not included in quota
 	if isResume {
 		c.Log(logger.LevelDebug, "resuming upload requested, file path: %#v initial size: %v", filePath, fileSize)
 		minWriteOffset = fileSize
@@ -473,13 +474,14 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve
 			}
 		} else {
 			initialSize = fileSize
+			truncatedSize = fileSize
 		}
 	}
 
 	vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
 
 	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
-		common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, fs)
+		common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs)
 	baseTransfer.SetFtpMode(c.getFTPMode())
 	t := newTransfer(baseTransfer, w, nil, 0)
 

+ 3 - 3
ftpd/internal_test.go

@@ -808,7 +808,7 @@ func TestTransferErrors(t *testing.T) {
 		clientContext:  mockCC,
 	}
 	baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, file.Name(), file.Name(), testfile,
-		common.TransferDownload, 0, 0, 0, false, fs)
+		common.TransferDownload, 0, 0, 0, 0, false, fs)
 	tr := newTransfer(baseTransfer, nil, nil, 0)
 	err = tr.Close()
 	assert.NoError(t, err)
@@ -826,7 +826,7 @@ func TestTransferErrors(t *testing.T) {
 	r, _, err := pipeat.Pipe()
 	assert.NoError(t, err)
 	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile,
-		common.TransferUpload, 0, 0, 0, false, fs)
+		common.TransferUpload, 0, 0, 0, 0, false, fs)
 	tr = newTransfer(baseTransfer, nil, r, 10)
 	pos, err := tr.Seek(10, 0)
 	assert.NoError(t, err)
@@ -838,7 +838,7 @@ func TestTransferErrors(t *testing.T) {
 	assert.NoError(t, err)
 	pipeWriter := vfs.NewPipeWriter(w)
 	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile,
-		common.TransferUpload, 0, 0, 0, false, fs)
+		common.TransferUpload, 0, 0, 0, 0, false, fs)
 	tr = newTransfer(baseTransfer, pipeWriter, nil, 0)
 
 	err = r.Close()

+ 8 - 8
go.mod

@@ -7,8 +7,8 @@ require (
 	github.com/Azure/azure-storage-blob-go v0.14.0
 	github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962
 	github.com/alexedwards/argon2id v0.0.0-20211130144151-3585854a6387
-	github.com/aws/aws-sdk-go v1.42.35
-	github.com/cockroachdb/cockroach-go/v2 v2.2.5
+	github.com/aws/aws-sdk-go v1.42.37
+	github.com/cockroachdb/cockroach-go/v2 v2.2.6
 	github.com/eikenb/pipeat v0.0.0-20210603033007-44fc3ffce52b
 	github.com/fclairamb/ftpserverlib v0.17.0
 	github.com/fclairamb/go-log v0.2.0
@@ -35,7 +35,7 @@ require (
 	github.com/pires/go-proxyproto v0.6.1
 	github.com/pkg/sftp v1.13.5-0.20211217081921-1849af66afae
 	github.com/pquerna/otp v1.3.0
-	github.com/prometheus/client_golang v1.11.0
+	github.com/prometheus/client_golang v1.12.0
 	github.com/rs/cors v1.8.2
 	github.com/rs/xid v1.3.0
 	github.com/rs/zerolog v1.26.2-0.20211219225053-665519c4da50
@@ -62,8 +62,8 @@ require (
 
 require (
 	cloud.google.com/go v0.100.2 // indirect
-	cloud.google.com/go/compute v1.0.0 // indirect
-	cloud.google.com/go/iam v0.1.0 // indirect
+	cloud.google.com/go/compute v1.1.0 // indirect
+	cloud.google.com/go/iam v0.1.1 // indirect
 	github.com/Azure/azure-pipeline-go v0.2.3 // indirect
 	github.com/beorn7/perks v1.0.1 // indirect
 	github.com/boombuler/barcode v1.0.1 // indirect
@@ -79,7 +79,7 @@ require (
 	github.com/goccy/go-json v0.9.3 // indirect
 	github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
 	github.com/golang/protobuf v1.5.2 // indirect
-	github.com/google/go-cmp v0.5.6 // indirect
+	github.com/google/go-cmp v0.5.7 // indirect
 	github.com/googleapis/gax-go/v2 v2.1.1 // indirect
 	github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
 	github.com/hashicorp/hcl v1.0.0 // indirect
@@ -126,10 +126,10 @@ require (
 	golang.org/x/tools v0.1.8 // indirect
 	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
 	google.golang.org/appengine v1.6.7 // indirect
-	google.golang.org/genproto v0.0.0-20220114231437-d2e6a121cae0 // indirect
+	google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 // indirect
 	google.golang.org/grpc v1.43.0 // indirect
 	google.golang.org/protobuf v1.27.1 // indirect
-	gopkg.in/ini.v1 v1.66.2 // indirect
+	gopkg.in/ini.v1 v1.66.3 // indirect
 	gopkg.in/yaml.v2 v2.4.0 // indirect
 	gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
 )

+ 18 - 13
go.sum

@@ -46,14 +46,14 @@ cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUM
 cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc=
 cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ=
 cloud.google.com/go/compute v0.1.0/go.mod h1:GAesmwr110a34z04OlxYkATPBEfVhkymfTBXtfbBFow=
-cloud.google.com/go/compute v1.0.0 h1:SJYBzih8Jj9EUm6IDirxKG0I0AGWduhtb6BmdqWarw4=
-cloud.google.com/go/compute v1.0.0/go.mod h1:GAesmwr110a34z04OlxYkATPBEfVhkymfTBXtfbBFow=
+cloud.google.com/go/compute v1.1.0 h1:pyPhehLfZ6pVzRgJmXGYvCY4K7WSWRhVw0AwhgVvS84=
+cloud.google.com/go/compute v1.1.0/go.mod h1:2NIffxgWfORSI7EOYMFatGTfjMLnqrOKBEyYb6NoRgA=
 cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
 cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk=
 cloud.google.com/go/firestore v1.5.0/go.mod h1:c4nNYR1qdq7eaZ+jSc5fonrQN2k3M7sWATcYTiakjEo=
 cloud.google.com/go/firestore v1.6.1/go.mod h1:asNXNOzBdyVQmEU+ggO8UPodTkEVFW5Qx+rwHnAz+EY=
-cloud.google.com/go/iam v0.1.0 h1:W2vbGCrE3Z7J/x3WXLxxGl9LMSB2uhsAA7Ss/6u/qRY=
-cloud.google.com/go/iam v0.1.0/go.mod h1:vcUNEa0pEm0qRVpmWepWaFMIAI8/hjB9mO8rNCJtF6c=
+cloud.google.com/go/iam v0.1.1 h1:4CapQyNFjiksks1/x7jsvsygFPhihslYk5GptIrlX68=
+cloud.google.com/go/iam v0.1.1/go.mod h1:CKqrcnI/suGpybEHxZ7BMehL0oA4LpdyJdUlTl9jVMw=
 cloud.google.com/go/kms v0.1.0 h1:VXAb5OzejDcyhFzIDeZ5n5AUdlsFnCyexuascIwWMj0=
 cloud.google.com/go/kms v0.1.0/go.mod h1:8Qp8PCAypHg4FdmlyW1QRAv09BGQ9Uzh7JnmIZxPk+c=
 cloud.google.com/go/monitoring v0.1.0/go.mod h1:Hpm3XfzJv+UTiXzCG5Ffp0wijzHTC7Cv4eR7o3x/fEE=
@@ -141,8 +141,8 @@ github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgI
 github.com/aws/aws-sdk-go v1.15.27/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0=
 github.com/aws/aws-sdk-go v1.37.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro=
 github.com/aws/aws-sdk-go v1.40.34/go.mod h1:585smgzpB/KqRA+K3y/NL/oYRqQvpNJYvLm+LY1U59Q=
-github.com/aws/aws-sdk-go v1.42.35 h1:N4N9buNs4YlosI9N0+WYrq8cIZwdgv34yRbxzZlTvFs=
-github.com/aws/aws-sdk-go v1.42.35/go.mod h1:OGr6lGMAKGlG9CVrYnWYDKIyb829c6EVBRjxqjmPepc=
+github.com/aws/aws-sdk-go v1.42.37 h1:EIziSq3REaoi1LgUBgxoQr29DQS7GYHnBbZPajtJmXM=
+github.com/aws/aws-sdk-go v1.42.37/go.mod h1:OGr6lGMAKGlG9CVrYnWYDKIyb829c6EVBRjxqjmPepc=
 github.com/aws/aws-sdk-go-v2 v1.9.0/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4=
 github.com/aws/aws-sdk-go-v2/config v1.7.0/go.mod h1:w9+nMZ7soXCe5nT46Ri354SNhXDQ6v+V5wqDjnZE+GY=
 github.com/aws/aws-sdk-go-v2/credentials v1.4.0/go.mod h1:dgGR+Qq7Wjcd4AOAW5Rf5Tnv3+x7ed6kETXyS9WCuAY=
@@ -190,8 +190,8 @@ github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWH
 github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
 github.com/cncf/xds/go v0.0.0-20211130200136-a8f946100490/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
 github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
-github.com/cockroachdb/cockroach-go/v2 v2.2.5 h1:tfPdGHO5YpmrpN2ikJZYpaSGgU8WALwwjH3s+msiTQ0=
-github.com/cockroachdb/cockroach-go/v2 v2.2.5/go.mod h1:q4ZRgO6CQpwNyEvEwSxwNrOSVchsmzrBnAv3HuZ3Abc=
+github.com/cockroachdb/cockroach-go/v2 v2.2.6 h1:LTh++UIVvmDBihDo1oYbM8+OruXheusw+ILCONlAm/w=
+github.com/cockroachdb/cockroach-go/v2 v2.2.6/go.mod h1:q4ZRgO6CQpwNyEvEwSxwNrOSVchsmzrBnAv3HuZ3Abc=
 github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
 github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
 github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f h1:JOrtw2xFKzlg+cbHpyrpLDmnN1HqhBfnX7WDiW7eG2c=
@@ -343,8 +343,9 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
 github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
 github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
+github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
 github.com/google/go-replayers/grpcreplay v1.1.0/go.mod h1:qzAvJ8/wi57zq7gWqaE6AwLM6miiXUQwP1S+I9icmhk=
 github.com/google/go-replayers/httpreplay v1.0.0/go.mod h1:LJhKoTwS5Wy5Ld/peq8dFFG5OfJyHEz7ft+DsTUv25M=
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -649,8 +650,9 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
 github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
 github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU=
 github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
-github.com/prometheus/client_golang v1.11.0 h1:HNkLOAEQMIDv/K+04rukrLx6ch7msSRwf3/SASFAGtQ=
 github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
+github.com/prometheus/client_golang v1.12.0 h1:C+UIj/QWtmqY13Arb8kwMt5j34/0Z2iKamrJ+ryC0Gg=
+github.com/prometheus/client_golang v1.12.0/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY=
 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
 github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@@ -1075,6 +1077,7 @@ google.golang.org/api v0.59.0/go.mod h1:sT2boj7M9YJxZzgeZqXogmhfmRWDtPzT31xkieUb
 google.golang.org/api v0.61.0/go.mod h1:xQRti5UdCmoCEqFxcz93fTl338AVqDgyaDRuOZ3hg9I=
 google.golang.org/api v0.62.0/go.mod h1:dKmwPCydfsad4qCH08MSdgWjfHOyfpd4VtDGgRFdavw=
 google.golang.org/api v0.63.0/go.mod h1:gs4ij2ffTRXwuzzgJl/56BdwJaA194ijkfn++9tDuPo=
+google.golang.org/api v0.64.0/go.mod h1:931CdxA8Rm4t6zqTFGSsgwbAEZ2+GMYurbndwSimebM=
 google.golang.org/api v0.65.0 h1:MTW9c+LIBAbwoS1Gb+YV7NjFBt2f7GtAS5hIzh2NjgQ=
 google.golang.org/api v0.65.0/go.mod h1:ArYhxgGadlWmqO1IqVujw6Cs8IdD33bTmzKo2Sh+cbg=
 google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
@@ -1162,8 +1165,9 @@ google.golang.org/genproto v0.0.0-20211208223120-3a66f561d7aa/go.mod h1:5CzLGKJ6
 google.golang.org/genproto v0.0.0-20211221195035-429b39de9b1c/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
 google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
 google.golang.org/genproto v0.0.0-20220107163113-42d7afdf6368/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
-google.golang.org/genproto v0.0.0-20220114231437-d2e6a121cae0 h1:aCsSLXylHWFno0r4S3joLpiaWayvqd2Mn4iSvx4WZZc=
-google.golang.org/genproto v0.0.0-20220114231437-d2e6a121cae0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
+google.golang.org/genproto v0.0.0-20220111164026-67b88f271998/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
+google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 h1:zzNejm+EgrbLfDZ6lu9Uud2IVvHySPl8vQzf04laR5Q=
+google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
 google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
 google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
 google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
@@ -1217,8 +1221,9 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN
 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
 gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
 gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s=
-gopkg.in/ini.v1 v1.66.2 h1:XfR1dOYubytKy4Shzc2LHrrGhU0lDCfDGG1yLPmpgsI=
 gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
+gopkg.in/ini.v1 v1.66.3 h1:jRskFVxYaMGAMUbN0UZ7niA9gzL9B49DOqE78vg0k3w=
+gopkg.in/ini.v1 v1.66.3/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
 gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8=
 gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k=
 gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

+ 6 - 5
httpd/file.go

@@ -1,7 +1,6 @@
 package httpd
 
 import (
-	"errors"
 	"io"
 	"sync/atomic"
 
@@ -11,8 +10,6 @@ import (
 	"github.com/drakkan/sftpgo/v2/vfs"
 )
 
-var errTransferAborted = errors.New("transfer aborted")
-
 type httpdFile struct {
 	*common.BaseTransfer
 	writer     io.WriteCloser
@@ -42,7 +39,9 @@ func newHTTPDFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter,
 // Read reads the contents to downloads.
 func (f *httpdFile) Read(p []byte) (n int, err error) {
 	if atomic.LoadInt32(&f.AbortTransfer) == 1 {
-		return 0, errTransferAborted
+		err := f.GetAbortError()
+		f.TransferError(err)
+		return 0, err
 	}
 
 	f.Connection.UpdateLastActivity()
@@ -61,7 +60,9 @@ func (f *httpdFile) Read(p []byte) (n int, err error) {
 // Write writes the contents to upload
 func (f *httpdFile) Write(p []byte) (n int, err error) {
 	if atomic.LoadInt32(&f.AbortTransfer) == 1 {
-		return 0, errTransferAborted
+		err := f.GetAbortError()
+		f.TransferError(err)
+		return 0, err
 	}
 
 	f.Connection.UpdateLastActivity()

+ 40 - 6
httpd/handler.go

@@ -6,6 +6,7 @@ import (
 	"os"
 	"path"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"time"
 
@@ -113,7 +114,7 @@ func (c *Connection) getFileReader(name string, offset int64, method string) (io
 	}
 
 	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, name, common.TransferDownload,
-		0, 0, 0, false, fs)
+		0, 0, 0, 0, false, fs)
 	return newHTTPDFile(baseTransfer, nil, r), nil
 }
 
@@ -190,6 +191,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
 	}
 
 	initialSize := int64(0)
+	truncatedSize := int64(0) // bytes truncated and not included in quota
 	if !isNewFile {
 		if vfs.IsLocalOrSFTPFs(fs) {
 			vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
@@ -203,6 +205,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
 			}
 		} else {
 			initialSize = fileSize
+			truncatedSize = fileSize
 		}
 		if maxWriteSize > 0 {
 			maxWriteSize += fileSize
@@ -212,7 +215,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
 	vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
 
 	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
-		common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, fs)
+		common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs)
 	return newHTTPDFile(baseTransfer, w, nil), nil
 }
 
@@ -232,15 +235,17 @@ func newThrottledReader(r io.ReadCloser, limit int64, conn *Connection) *throttl
 
 type throttledReader struct {
 	bytesRead     int64
-	id            uint64
+	id            int64
 	limit         int64
 	r             io.ReadCloser
 	abortTransfer int32
 	start         time.Time
 	conn          *Connection
+	mu            sync.Mutex
+	errAbort      error
 }
 
-func (t *throttledReader) GetID() uint64 {
+func (t *throttledReader) GetID() int64 {
 	return t.id
 }
 
@@ -252,6 +257,14 @@ func (t *throttledReader) GetSize() int64 {
 	return atomic.LoadInt64(&t.bytesRead)
 }
 
+func (t *throttledReader) GetDownloadedSize() int64 {
+	return 0
+}
+
+func (t *throttledReader) GetUploadedSize() int64 {
+	return atomic.LoadInt64(&t.bytesRead)
+}
+
 func (t *throttledReader) GetVirtualPath() string {
 	return "**reading request body**"
 }
@@ -260,10 +273,31 @@ func (t *throttledReader) GetStartTime() time.Time {
 	return t.start
 }
 
-func (t *throttledReader) SignalClose() {
+func (t *throttledReader) GetAbortError() error {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	if t.errAbort != nil {
+		return t.errAbort
+	}
+	return common.ErrTransferAborted
+}
+
+func (t *throttledReader) SignalClose(err error) {
+	t.mu.Lock()
+	t.errAbort = err
+	t.mu.Unlock()
 	atomic.StoreInt32(&(t.abortTransfer), 1)
 }
 
+func (t *throttledReader) GetTruncatedSize() int64 {
+	return 0
+}
+
+func (t *throttledReader) GetMaxAllowedSize() int64 {
+	return 0
+}
+
 func (t *throttledReader) Truncate(fsPath string, size int64) (int64, error) {
 	return 0, vfs.ErrVfsUnsupported
 }
@@ -278,7 +312,7 @@ func (t *throttledReader) SetTimes(fsPath string, atime time.Time, mtime time.Ti
 
 func (t *throttledReader) Read(p []byte) (n int, err error) {
 	if atomic.LoadInt32(&t.abortTransfer) == 1 {
-		return 0, errTransferAborted
+		return 0, t.GetAbortError()
 	}
 
 	t.conn.UpdateLastActivity()

+ 6 - 3
httpd/internal_test.go

@@ -1844,12 +1844,15 @@ func TestThrottledHandler(t *testing.T) {
 	tr := &throttledReader{
 		r: io.NopCloser(bytes.NewBuffer(nil)),
 	}
+	assert.Equal(t, int64(0), tr.GetTruncatedSize())
 	err := tr.Close()
 	assert.NoError(t, err)
 	assert.Empty(t, tr.GetRealFsPath("real path"))
 	assert.False(t, tr.SetTimes("p", time.Now(), time.Now()))
 	_, err = tr.Truncate("", 0)
 	assert.ErrorIs(t, err, vfs.ErrVfsUnsupported)
+	err = tr.GetAbortError()
+	assert.ErrorIs(t, err, common.ErrTransferAborted)
 }
 
 func TestHTTPDFile(t *testing.T) {
@@ -1879,7 +1882,7 @@ func TestHTTPDFile(t *testing.T) {
 	assert.NoError(t, err)
 
 	baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, p, p, name, common.TransferDownload,
-		0, 0, 0, false, fs)
+		0, 0, 0, 0, false, fs)
 	httpdFile := newHTTPDFile(baseTransfer, nil, nil)
 	// the file is closed, read should fail
 	buf := make([]byte, 100)
@@ -1899,9 +1902,9 @@ func TestHTTPDFile(t *testing.T) {
 	assert.Error(t, err)
 	assert.Error(t, httpdFile.ErrTransfer)
 	assert.Equal(t, err, httpdFile.ErrTransfer)
-	httpdFile.SignalClose()
+	httpdFile.SignalClose(nil)
 	_, err = httpdFile.Write(nil)
-	assert.ErrorIs(t, err, errTransferAborted)
+	assert.ErrorIs(t, err, common.ErrQuotaExceeded)
 }
 
 func TestChangeUserPwd(t *testing.T) {

+ 5 - 3
sftpd/handler.go

@@ -85,7 +85,7 @@ func (c *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
 	}
 
 	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, request.Filepath, common.TransferDownload,
-		0, 0, 0, false, fs)
+		0, 0, 0, 0, false, fs)
 	t := newTransfer(baseTransfer, nil, r, nil)
 
 	return t, nil
@@ -364,7 +364,7 @@ func (c *Connection) handleSFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath
 	maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
 
 	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
-		common.TransferUpload, 0, 0, maxWriteSize, true, fs)
+		common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs)
 	t := newTransfer(baseTransfer, w, nil, errForRead)
 
 	return t, nil
@@ -415,6 +415,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO
 	}
 
 	initialSize := int64(0)
+	truncatedSize := int64(0) // bytes truncated and not included in quota
 	if isResume {
 		c.Log(logger.LevelDebug, "resuming upload requested, file path %#v initial size: %v has append flag %v",
 			filePath, fileSize, pflags.Append)
@@ -436,13 +437,14 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO
 			}
 		} else {
 			initialSize = fileSize
+			truncatedSize = fileSize
 		}
 	}
 
 	vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
 
 	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
-		common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, fs)
+		common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs)
 	t := newTransfer(baseTransfer, w, nil, errForRead)
 
 	return t, nil

+ 16 - 10
sftpd/internal_test.go

@@ -162,7 +162,8 @@ func TestUploadResumeInvalidOffset(t *testing.T) {
 	}
 	fs := vfs.NewOsFs("", os.TempDir(), "")
 	conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user)
-	baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferUpload, 10, 0, 0, false, fs)
+	baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile,
+		common.TransferUpload, 10, 0, 0, 0, false, fs)
 	transfer := newTransfer(baseTransfer, nil, nil, nil)
 	_, err = transfer.WriteAt([]byte("test"), 0)
 	assert.Error(t, err, "upload with invalid offset must fail")
@@ -193,7 +194,8 @@ func TestReadWriteErrors(t *testing.T) {
 	}
 	fs := vfs.NewOsFs("", os.TempDir(), "")
 	conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user)
-	baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
+	baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload,
+		0, 0, 0, 0, false, fs)
 	transfer := newTransfer(baseTransfer, nil, nil, nil)
 	err = file.Close()
 	assert.NoError(t, err)
@@ -207,7 +209,8 @@ func TestReadWriteErrors(t *testing.T) {
 
 	r, _, err := pipeat.Pipe()
 	assert.NoError(t, err)
-	baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
+	baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload,
+		0, 0, 0, 0, false, fs)
 	transfer = newTransfer(baseTransfer, nil, r, nil)
 	err = transfer.Close()
 	assert.NoError(t, err)
@@ -217,7 +220,8 @@ func TestReadWriteErrors(t *testing.T) {
 	r, w, err := pipeat.Pipe()
 	assert.NoError(t, err)
 	pipeWriter := vfs.NewPipeWriter(w)
-	baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
+	baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload,
+		0, 0, 0, 0, false, fs)
 	transfer = newTransfer(baseTransfer, pipeWriter, nil, nil)
 
 	err = r.Close()
@@ -264,7 +268,8 @@ func TestTransferCancelFn(t *testing.T) {
 	}
 	fs := vfs.NewOsFs("", os.TempDir(), "")
 	conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user)
-	baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
+	baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), file.Name(), testfile, common.TransferDownload,
+		0, 0, 0, 0, false, fs)
 	transfer := newTransfer(baseTransfer, nil, nil, nil)
 
 	errFake := errors.New("fake error, this will trigger cancelFn")
@@ -971,8 +976,8 @@ func TestSystemCommandErrors(t *testing.T) {
 		WriteError:   nil,
 	}
 	sshCmd.connection.channel = &mockSSHChannel
-	baseTransfer := common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", "", common.TransferDownload,
-		0, 0, 0, false, fs)
+	baseTransfer := common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", "",
+		common.TransferDownload, 0, 0, 0, 0, false, fs)
 	transfer := newTransfer(baseTransfer, nil, nil, nil)
 	destBuff := make([]byte, 65535)
 	dst := bytes.NewBuffer(destBuff)
@@ -1639,7 +1644,7 @@ func TestSCPUploadFiledata(t *testing.T) {
 	assert.NoError(t, err)
 
 	baseTransfer := common.NewBaseTransfer(file, scpCommand.connection.BaseConnection, nil, file.Name(), file.Name(),
-		"/"+testfile, common.TransferDownload, 0, 0, 0, true, fs)
+		"/"+testfile, common.TransferDownload, 0, 0, 0, 0, true, fs)
 	transfer := newTransfer(baseTransfer, nil, nil, nil)
 
 	err = scpCommand.getUploadFileData(2, transfer)
@@ -1724,7 +1729,7 @@ func TestUploadError(t *testing.T) {
 	file, err := os.Create(fileTempName)
 	assert.NoError(t, err)
 	baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, testfile, file.Name(),
-		testfile, common.TransferUpload, 0, 0, 0, true, fs)
+		testfile, common.TransferUpload, 0, 0, 0, 0, true, fs)
 	transfer := newTransfer(baseTransfer, nil, nil, nil)
 
 	errFake := errors.New("fake error")
@@ -1782,7 +1787,8 @@ func TestTransferFailingReader(t *testing.T) {
 
 	r, _, err := pipeat.Pipe()
 	assert.NoError(t, err)
-	baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, fsPath, fsPath, filepath.Base(fsPath), common.TransferUpload, 0, 0, 0, false, fs)
+	baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, fsPath, fsPath, filepath.Base(fsPath),
+		common.TransferUpload, 0, 0, 0, 0, false, fs)
 	errRead := errors.New("read is not allowed")
 	tr := newTransfer(baseTransfer, nil, r, errRead)
 	_, err = tr.ReadAt(buf, 0)

+ 4 - 2
sftpd/scp.go

@@ -238,6 +238,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string,
 	}
 
 	initialSize := int64(0)
+	truncatedSize := int64(0) // bytes truncated and not included in quota
 	if !isNewFile {
 		if vfs.IsLocalOrSFTPFs(fs) {
 			vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath))
@@ -251,6 +252,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string,
 			}
 		} else {
 			initialSize = fileSize
+			truncatedSize = initialSize
 		}
 		if maxWriteSize > 0 {
 			maxWriteSize += fileSize
@@ -260,7 +262,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string,
 	vfs.SetPathPermissions(fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID())
 
 	baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
-		common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, fs)
+		common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs)
 	t := newTransfer(baseTransfer, w, nil, nil)
 
 	return c.getUploadFileData(sizeToRead, t)
@@ -529,7 +531,7 @@ func (c *scpCommand) handleDownload(filePath string) error {
 	}
 
 	baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, p, p, filePath,
-		common.TransferDownload, 0, 0, 0, false, fs)
+		common.TransferDownload, 0, 0, 0, 0, false, fs)
 	t := newTransfer(baseTransfer, nil, r, nil)
 
 	err = c.sendDownloadFileData(fs, p, stat, t)

+ 3 - 3
sftpd/ssh_cmd.go

@@ -356,7 +356,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
 	go func() {
 		defer stdin.Close()
 		baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath,
-			common.TransferUpload, 0, 0, remainingQuotaSize, false, command.fs)
+			common.TransferUpload, 0, 0, remainingQuotaSize, 0, false, command.fs)
 		transfer := newTransfer(baseTransfer, nil, nil, nil)
 
 		w, e := transfer.copyFromReaderToWriter(stdin, c.connection.channel)
@@ -369,7 +369,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
 
 	go func() {
 		baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath,
-			common.TransferDownload, 0, 0, 0, false, command.fs)
+			common.TransferDownload, 0, 0, 0, 0, false, command.fs)
 		transfer := newTransfer(baseTransfer, nil, nil, nil)
 
 		w, e := transfer.copyFromReaderToWriter(c.connection.channel, stdout)
@@ -383,7 +383,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
 
 	go func() {
 		baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath,
-			common.TransferDownload, 0, 0, 0, false, command.fs)
+			common.TransferDownload, 0, 0, 0, 0, false, command.fs)
 		transfer := newTransfer(baseTransfer, nil, nil, nil)
 
 		w, e := transfer.copyFromReaderToWriter(c.connection.channel.(ssh.Channel).Stderr(), stderr)

+ 5 - 5
tests/eventsearcher/go.mod

@@ -4,23 +4,23 @@ go 1.17
 
 require (
 	github.com/hashicorp/go-plugin v1.4.3
-	github.com/sftpgo/sdk v0.0.0-20220106101837-50e87c59705a
+	github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea
 )
 
 require (
 	github.com/fatih/color v1.13.0 // indirect
 	github.com/golang/protobuf v1.5.2 // indirect
 	github.com/google/go-cmp v0.5.6 // indirect
-	github.com/hashicorp/go-hclog v1.0.0 // indirect
+	github.com/hashicorp/go-hclog v1.1.0 // indirect
 	github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87 // indirect
 	github.com/mattn/go-colorable v0.1.12 // indirect
 	github.com/mattn/go-isatty v0.0.14 // indirect
 	github.com/mitchellh/go-testing-interface v1.14.1 // indirect
 	github.com/oklog/run v1.1.0 // indirect
-	golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98 // indirect
-	golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e // indirect
+	golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d // indirect
+	golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect
 	golang.org/x/text v0.3.7 // indirect
-	google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb // indirect
+	google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 // indirect
 	google.golang.org/grpc v1.43.0 // indirect
 	google.golang.org/protobuf v1.27.1 // indirect
 	gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect

+ 10 - 6
tests/eventsearcher/go.sum

@@ -57,8 +57,9 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
 github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
 github.com/hashicorp/go-hclog v0.14.1/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
-github.com/hashicorp/go-hclog v1.0.0 h1:bkKf0BeBXcSYa7f5Fyi9gMuQ8gNsxeiNpZjR6VxNZeo=
 github.com/hashicorp/go-hclog v1.0.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
+github.com/hashicorp/go-hclog v1.1.0 h1:QsGcniKx5/LuX2eYoeL+Np3UKYPNaN7YKpTh29h8rbw=
+github.com/hashicorp/go-hclog v1.1.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
 github.com/hashicorp/go-plugin v1.4.3 h1:DXmvivbWD5qdiBts9TpBC7BYL1Aia5sxbRgQB+v6UZM=
 github.com/hashicorp/go-plugin v1.4.3/go.mod h1:5fGEH17QVwTTcR0zV7yhDPLLmFX9YSZ38b18Udy6vYQ=
 github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM=
@@ -85,8 +86,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
-github.com/sftpgo/sdk v0.0.0-20220106101837-50e87c59705a h1:JJc19rE0eW2knPa/KIFYvqyu25CwzKltJ5Cw1kK3o4A=
-github.com/sftpgo/sdk v0.0.0-20220106101837-50e87c59705a/go.mod h1:Bhgac6kiwIziILXLzH4wepT8lQXyhF83poDXqZorN6Q=
+github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea h1:ouwL3x9tXiAXIhdXtJGONd905f1dBLu3HhfFoaTq24k=
+github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea/go.mod h1:Bhgac6kiwIziILXLzH4wepT8lQXyhF83poDXqZorN6Q=
 github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
@@ -110,8 +111,9 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn
 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
 golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
 golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
-golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98 h1:+6WJMRLHlD7X7frgp7TUZ36RnQzSf9wVVTNakEp+nqY=
 golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d h1:1n1fc535VhN8SYtD4cDUyNlfpAF2ROMM9+11equK3hs=
+golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
 golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -132,8 +134,9 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
 golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM=
 golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 h1:XfKQ4OlFl8okEOr5UvAqFRVj8pY/4yfcXrddB8qAbU0=
+golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -156,8 +159,9 @@ google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoA
 google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
 google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
 google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
-google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb h1:ZrsicilzPCS/Xr8qtBZZLpy4P9TYXAfl49ctG1/5tgw=
 google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
+google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 h1:zzNejm+EgrbLfDZ6lu9Uud2IVvHySPl8vQzf04laR5Q=
+google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
 google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
 google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
 google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=

+ 6 - 4
webdavd/handler.go

@@ -149,8 +149,8 @@ func (c *Connection) getFile(fs vfs.Fs, fsPath, virtualPath string) (webdav.File
 		}
 	}
 
-	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, virtualPath, common.TransferDownload,
-		0, 0, 0, false, fs)
+	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, virtualPath,
+		common.TransferDownload, 0, 0, 0, 0, false, fs)
 
 	return newWebDavFile(baseTransfer, nil, r), nil
 }
@@ -214,7 +214,7 @@ func (c *Connection) handleUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, re
 	maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
 
 	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
-		common.TransferUpload, 0, 0, maxWriteSize, true, fs)
+		common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs)
 
 	return newWebDavFile(baseTransfer, w, nil), nil
 }
@@ -252,6 +252,7 @@ func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePat
 		return nil, c.GetFsError(fs, err)
 	}
 	initialSize := int64(0)
+	truncatedSize := int64(0) // bytes truncated and not included in quota
 	if vfs.IsLocalOrSFTPFs(fs) {
 		vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
 		if err == nil {
@@ -264,12 +265,13 @@ func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePat
 		}
 	} else {
 		initialSize = fileSize
+		truncatedSize = fileSize
 	}
 
 	vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
 
 	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
-		common.TransferUpload, 0, initialSize, maxWriteSize, false, fs)
+		common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, false, fs)
 
 	return newWebDavFile(baseTransfer, w, nil), nil
 }

+ 13 - 13
webdavd/internal_test.go

@@ -695,7 +695,7 @@ func TestContentType(t *testing.T) {
 	testFilePath := filepath.Join(user.HomeDir, testFile)
 	ctx := context.Background()
 	baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
-		common.TransferDownload, 0, 0, 0, false, fs)
+		common.TransferDownload, 0, 0, 0, 0, false, fs)
 	fs = newMockOsFs(nil, false, fs.ConnectionID(), user.GetHomeDir(), nil)
 	err := os.WriteFile(testFilePath, []byte(""), os.ModePerm)
 	assert.NoError(t, err)
@@ -745,7 +745,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
 	}
 	testFilePath := filepath.Join(user.HomeDir, testFile)
 	baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
-		common.TransferUpload, 0, 0, 0, false, fs)
+		common.TransferUpload, 0, 0, 0, 0, false, fs)
 	davFile := newWebDavFile(baseTransfer, nil, nil)
 	p := make([]byte, 1)
 	_, err := davFile.Read(p)
@@ -763,7 +763,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
 	assert.NoError(t, err)
 
 	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
-		common.TransferDownload, 0, 0, 0, false, fs)
+		common.TransferDownload, 0, 0, 0, 0, false, fs)
 	davFile = newWebDavFile(baseTransfer, nil, nil)
 	_, err = davFile.Read(p)
 	assert.True(t, os.IsNotExist(err))
@@ -771,7 +771,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
 	assert.True(t, os.IsNotExist(err))
 
 	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
-		common.TransferDownload, 0, 0, 0, false, fs)
+		common.TransferDownload, 0, 0, 0, 0, false, fs)
 	err = os.WriteFile(testFilePath, []byte(""), os.ModePerm)
 	assert.NoError(t, err)
 	f, err := os.Open(testFilePath)
@@ -796,7 +796,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
 	assert.NoError(t, err)
 	mockFs := newMockOsFs(nil, false, fs.ConnectionID(), user.HomeDir, r)
 	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
-		common.TransferDownload, 0, 0, 0, false, mockFs)
+		common.TransferDownload, 0, 0, 0, 0, false, mockFs)
 	davFile = newWebDavFile(baseTransfer, nil, nil)
 
 	writeContent := []byte("content\r\n")
@@ -816,7 +816,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
 	assert.NoError(t, err)
 
 	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
-		common.TransferDownload, 0, 0, 0, false, fs)
+		common.TransferDownload, 0, 0, 0, 0, false, fs)
 	davFile = newWebDavFile(baseTransfer, nil, nil)
 	davFile.writer = f
 	err = davFile.Close()
@@ -841,7 +841,7 @@ func TestTransferSeek(t *testing.T) {
 	testFilePath := filepath.Join(user.HomeDir, testFile)
 	testFileContents := []byte("content")
 	baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
-		common.TransferUpload, 0, 0, 0, false, fs)
+		common.TransferUpload, 0, 0, 0, 0, false, fs)
 	davFile := newWebDavFile(baseTransfer, nil, nil)
 	_, err := davFile.Seek(0, io.SeekStart)
 	assert.EqualError(t, err, common.ErrOpUnsupported.Error())
@@ -849,7 +849,7 @@ func TestTransferSeek(t *testing.T) {
 	assert.NoError(t, err)
 
 	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
-		common.TransferDownload, 0, 0, 0, false, fs)
+		common.TransferDownload, 0, 0, 0, 0, false, fs)
 	davFile = newWebDavFile(baseTransfer, nil, nil)
 	_, err = davFile.Seek(0, io.SeekCurrent)
 	assert.True(t, os.IsNotExist(err))
@@ -863,14 +863,14 @@ func TestTransferSeek(t *testing.T) {
 		assert.NoError(t, err)
 	}
 	baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
-		common.TransferDownload, 0, 0, 0, false, fs)
+		common.TransferDownload, 0, 0, 0, 0, false, fs)
 	davFile = newWebDavFile(baseTransfer, nil, nil)
 	_, err = davFile.Seek(0, io.SeekStart)
 	assert.Error(t, err)
 	davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
 
 	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
-		common.TransferDownload, 0, 0, 0, false, fs)
+		common.TransferDownload, 0, 0, 0, 0, false, fs)
 	davFile = newWebDavFile(baseTransfer, nil, nil)
 	res, err := davFile.Seek(0, io.SeekStart)
 	assert.NoError(t, err)
@@ -885,14 +885,14 @@ func TestTransferSeek(t *testing.T) {
 	assert.Nil(t, err)
 
 	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile,
-		common.TransferDownload, 0, 0, 0, false, fs)
+		common.TransferDownload, 0, 0, 0, 0, false, fs)
 	davFile = newWebDavFile(baseTransfer, nil, nil)
 	_, err = davFile.Seek(0, io.SeekEnd)
 	assert.True(t, os.IsNotExist(err))
 	davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
 
 	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
-		common.TransferDownload, 0, 0, 0, false, fs)
+		common.TransferDownload, 0, 0, 0, 0, false, fs)
 	davFile = newWebDavFile(baseTransfer, nil, nil)
 	davFile.reader = f
 	davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil)
@@ -907,7 +907,7 @@ func TestTransferSeek(t *testing.T) {
 	assert.Equal(t, int64(5), res)
 
 	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile,
-		common.TransferDownload, 0, 0, 0, false, fs)
+		common.TransferDownload, 0, 0, 0, 0, false, fs)
 
 	davFile = newWebDavFile(baseTransfer, nil, nil)
 	davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil)