فهرست منبع

update data transfer quota only if the current IP has some limits

Signed-off-by: Nicola Murino <[email protected]>
Nicola Murino 3 سال پیش
والد
کامیت
d51adb041e
5فایلهای تغییر یافته به همراه50 افزوده شده و 18 حذف شده
  1. 18 0
      common/protocol_test.go
  2. 10 8
      common/transfer.go
  3. 5 0
      dataprovider/dataprovider.go
  4. 10 3
      webdavd/file.go
  5. 7 7
      webdavd/internal_test.go

+ 18 - 0
common/protocol_test.go

@@ -1002,6 +1002,12 @@ func TestVirtualFoldersQuotaRenameOverwrite(t *testing.T) {
 func TestQuotaRenameOverwrite(t *testing.T) {
 	u := getTestUser()
 	u.QuotaFiles = 100
+	u.Filters.DataTransferLimits = []sdk.DataTransferLimit{
+		{
+			Sources:           []string{"10.8.0.0/8"},
+			TotalDataTransfer: 1,
+		},
+	}
 	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
 	assert.NoError(t, err)
 	conn, client, err := getSftpClient(user)
@@ -1013,16 +1019,28 @@ func TestQuotaRenameOverwrite(t *testing.T) {
 		testFileName1 := "test_file1.dat"
 		err = writeSFTPFile(testFileName, testFileSize, client)
 		assert.NoError(t, err)
+		f, err := client.Open(testFileName)
+		assert.NoError(t, err)
+		contents := make([]byte, testFileSize)
+		n, err := io.ReadFull(f, contents)
+		assert.NoError(t, err)
+		assert.Equal(t, int(testFileSize), n)
+		err = f.Close()
+		assert.NoError(t, err)
 		err = writeSFTPFile(testFileName1, testFileSize1, client)
 		assert.NoError(t, err)
 		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
 		assert.NoError(t, err)
+		assert.Equal(t, int64(0), user.UsedDownloadDataTransfer)
+		assert.Equal(t, int64(0), user.UsedUploadDataTransfer)
 		assert.Equal(t, 2, user.UsedQuotaFiles)
 		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
 		err = client.Rename(testFileName, testFileName1)
 		assert.NoError(t, err)
 		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
 		assert.NoError(t, err)
+		assert.Equal(t, int64(0), user.UsedDownloadDataTransfer)
+		assert.Equal(t, int64(0), user.UsedUploadDataTransfer)
 		assert.Equal(t, 1, user.UsedQuotaFiles)
 		assert.Equal(t, testFileSize, user.UsedQuotaSize)
 		err = client.Remove(testFileName1)

+ 10 - 8
common/transfer.go

@@ -153,8 +153,7 @@ func (t *BaseTransfer) HasSizeLimit() bool {
 	if t.MaxWriteSize > 0 {
 		return true
 	}
-	if t.transferQuota.AllowedDLSize > 0 || t.transferQuota.AllowedULSize > 0 ||
-		t.transferQuota.AllowedTotalSize > 0 {
+	if t.transferQuota.HasSizeLimits() {
 		return true
 	}
 
@@ -249,10 +248,11 @@ func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) {
 					sizeDiff := initialSize - size
 					t.MaxWriteSize += sizeDiff
 					metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.transferType, t.ErrTransfer)
-					go func(ulSize, dlSize int64, user dataprovider.User) {
-						dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck
-					}(atomic.LoadInt64(&t.BytesReceived), atomic.LoadInt64(&t.BytesSent), t.Connection.User)
-
+					if t.transferQuota.HasSizeLimits() {
+						go func(ulSize, dlSize int64, user dataprovider.User) {
+							dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck
+						}(atomic.LoadInt64(&t.BytesReceived), atomic.LoadInt64(&t.BytesSent), t.Connection.User)
+					}
 					atomic.StoreInt64(&t.BytesReceived, 0)
 				}
 				t.Unlock()
@@ -321,8 +321,10 @@ func (t *BaseTransfer) Close() error {
 	}
 	metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived),
 		t.transferType, t.ErrTransfer)
-	dataprovider.UpdateUserTransferQuota(&t.Connection.User, atomic.LoadInt64(&t.BytesReceived), //nolint:errcheck
-		atomic.LoadInt64(&t.BytesSent), false)
+	if t.transferQuota.HasSizeLimits() {
+		dataprovider.UpdateUserTransferQuota(&t.Connection.User, atomic.LoadInt64(&t.BytesReceived), //nolint:errcheck
+			atomic.LoadInt64(&t.BytesSent), false)
+	}
 	if t.File != nil && t.Connection.IsQuotaExceededError(t.ErrTransfer) {
 		// if quota is exceeded we try to remove the partial file for uploads to local filesystem
 		err = t.Fs.Remove(t.File.Name(), false)

+ 5 - 0
dataprovider/dataprovider.go

@@ -461,6 +461,11 @@ type TransferQuota struct {
 	AllowedTotalSize int64
 }
 
+// HasSizeLimits returns true if any size limit is set
+func (q *TransferQuota) HasSizeLimits() bool {
+	return q.AllowedDLSize > 0 || q.AllowedULSize > 0 || q.AllowedTotalSize > 0
+}
+
 // HasUploadSpace returns true if there is transfer upload space available
 func (q *TransferQuota) HasUploadSpace() bool {
 	if q.TotalSize <= 0 && q.ULSize <= 0 {

+ 10 - 3
webdavd/file.go

@@ -233,6 +233,15 @@ func (f *webDavFile) updateStatInfo() error {
 	return nil
 }
 
+func (f *webDavFile) updateTransferQuotaOnSeek() {
+	transferQuota := f.GetTransferQuota()
+	if transferQuota.HasSizeLimits() {
+		go func(ulSize, dlSize int64, user dataprovider.User) {
+			dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck
+		}(atomic.LoadInt64(&f.BytesReceived), atomic.LoadInt64(&f.BytesSent), f.Connection.User)
+	}
+}
+
 // Seek sets the offset for the next Read or Write on the writer to offset,
 // interpreted according to whence: 0 means relative to the origin of the file,
 // 1 means relative to the current offset, and 2 means relative to the end.
@@ -267,9 +276,7 @@ func (f *webDavFile) Seek(offset int64, whence int) (int64, error) {
 		startByte := int64(0)
 		atomic.StoreInt64(&f.BytesReceived, 0)
 		atomic.StoreInt64(&f.BytesSent, 0)
-		go func(ulSize, dlSize int64, user dataprovider.User) {
-			dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck
-		}(atomic.LoadInt64(&f.BytesReceived), atomic.LoadInt64(&f.BytesSent), f.Connection.User)
+		f.updateTransferQuotaOnSeek()
 
 		switch whence {
 		case io.SeekStart:

+ 7 - 7
webdavd/internal_test.go

@@ -841,7 +841,7 @@ func TestTransferSeek(t *testing.T) {
 	testFilePath := filepath.Join(user.HomeDir, testFile)
 	testFileContents := []byte("content")
 	baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
-		common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{})
+		common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100})
 	davFile := newWebDavFile(baseTransfer, nil, nil)
 	_, err := davFile.Seek(0, io.SeekStart)
 	assert.EqualError(t, err, common.ErrOpUnsupported.Error())
@@ -849,7 +849,7 @@ func TestTransferSeek(t *testing.T) {
 	assert.NoError(t, err)
 
 	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
-		common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{})
+		common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100})
 	davFile = newWebDavFile(baseTransfer, nil, nil)
 	_, err = davFile.Seek(0, io.SeekCurrent)
 	assert.True(t, os.IsNotExist(err))
@@ -863,14 +863,14 @@ func TestTransferSeek(t *testing.T) {
 		assert.NoError(t, err)
 	}
 	baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
-		common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{})
+		common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100})
 	davFile = newWebDavFile(baseTransfer, nil, nil)
 	_, err = davFile.Seek(0, io.SeekStart)
 	assert.Error(t, err)
 	davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
 
 	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
-		common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{})
+		common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100})
 	davFile = newWebDavFile(baseTransfer, nil, nil)
 	res, err := davFile.Seek(0, io.SeekStart)
 	assert.NoError(t, err)
@@ -885,14 +885,14 @@ func TestTransferSeek(t *testing.T) {
 	assert.Nil(t, err)
 
 	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile,
-		common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{})
+		common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100})
 	davFile = newWebDavFile(baseTransfer, nil, nil)
 	_, err = davFile.Seek(0, io.SeekEnd)
 	assert.True(t, os.IsNotExist(err))
 	davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
 
 	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
-		common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{})
+		common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100})
 	davFile = newWebDavFile(baseTransfer, nil, nil)
 	davFile.reader = f
 	davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil)
@@ -907,7 +907,7 @@ func TestTransferSeek(t *testing.T) {
 	assert.Equal(t, int64(5), res)
 
 	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile,
-		common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{})
+		common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100})
 
 	davFile = newWebDavFile(baseTransfer, nil, nil)
 	davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil)