
S3: don't use manager for uploads

Signed-off-by: Nicola Murino <[email protected]>
Nicola Murino, 4 months ago
parent commit ff5ea7cd40
3 changed files with 255 additions and 106 deletions
  1. internal/vfs/azblobfs.go (+ 5 - 75)
  2. internal/vfs/s3fs.go (+ 179 - 31)
  3. internal/vfs/vfs.go (+ 71 - 0)

+ 5 - 75
internal/vfs/azblobfs.go

@@ -946,6 +946,8 @@ func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *b
 	guard := make(chan struct{}, fs.config.DownloadConcurrency)
 	blockCtxTimeout := time.Duration(fs.config.DownloadPartSize/(1024*1024)) * time.Minute
 	pool := newBufferAllocator(int(partSize))
+	defer pool.free()
+
 	finished := false
 	var wg sync.WaitGroup
 	var errOnce sync.Once
@@ -999,7 +1001,6 @@ func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *b
 
 	wg.Wait()
 	close(guard)
-	pool.free()
 
 	return poolError
 }
@@ -1014,6 +1015,8 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
 	// sync.Pool seems to use a lot of memory so prefer our own, very simple, allocator
 	// we only need to recycle a few byte slices
 	pool := newBufferAllocator(int(partSize))
+	defer pool.free()
+
 	finished := false
 	var blocks []string
 	var wg sync.WaitGroup
@@ -1027,7 +1030,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
 	for part := 0; !finished; part++ {
 		buf := pool.getBuffer()
 
-		n, err := fs.readFill(reader, buf)
+		n, err := readFill(reader, buf)
 		if err == io.EOF {
 			// read finished, if n > 0 we need to process the last data chunk
 			if n == 0 {
@@ -1037,7 +1040,6 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
 			finished = true
 		} else if err != nil {
 			pool.releaseBuffer(buf)
-			pool.free()
 			return err
 		}
 
@@ -1046,7 +1048,6 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
 		generatedUUID, err := uuid.NewRandom()
 		if err != nil {
 			pool.releaseBuffer(buf)
-			pool.free()
 			return fmt.Errorf("unable to generate block ID: %w", err)
 		}
 		blockID := base64.StdEncoding.EncodeToString([]byte(generatedUUID.String()))
@@ -1087,7 +1088,6 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
 
 	wg.Wait()
 	close(guard)
-	pool.free()
 
 	if poolError != nil {
 		return poolError
@@ -1117,16 +1117,6 @@ func (*AzureBlobFs) writeAtFull(w io.WriterAt, buf []byte, offset int64, count i
 	return written, nil
 }
 
-// copied from rclone
-func (*AzureBlobFs) readFill(r io.Reader, buf []byte) (n int, err error) {
-	var nn int
-	for n < len(buf) && err == nil {
-		nn, err = r.Read(buf[n:])
-		n += nn
-	}
-	return n, err
-}
-
 func (fs *AzureBlobFs) getCopyOptions(srcInfo os.FileInfo, updateModTime bool) *blob.StartCopyFromURLOptions {
 	copyOptions := &blob.StartCopyFromURLOptions{}
 	if fs.config.AccessTier != "" {
@@ -1187,66 +1177,6 @@ func getAzContainerClientOptions() *container.ClientOptions {
 	}
 }
 
-type bytesReaderWrapper struct {
-	*bytes.Reader
-}
-
-func (b *bytesReaderWrapper) Close() error {
-	return nil
-}
-
-type bufferAllocator struct {
-	sync.Mutex
-	available  [][]byte
-	bufferSize int
-	finalized  bool
-}
-
-func newBufferAllocator(size int) *bufferAllocator {
-	return &bufferAllocator{
-		bufferSize: size,
-		finalized:  false,
-	}
-}
-
-func (b *bufferAllocator) getBuffer() []byte {
-	b.Lock()
-	defer b.Unlock()
-
-	if len(b.available) > 0 {
-		var result []byte
-
-		truncLength := len(b.available) - 1
-		result = b.available[truncLength]
-
-		b.available[truncLength] = nil
-		b.available = b.available[:truncLength]
-
-		return result
-	}
-
-	return make([]byte, b.bufferSize)
-}
-
-func (b *bufferAllocator) releaseBuffer(buf []byte) {
-	b.Lock()
-	defer b.Unlock()
-
-	if b.finalized || len(buf) != b.bufferSize {
-		return
-	}
-
-	b.available = append(b.available, buf)
-}
-
-func (b *bufferAllocator) free() {
-	b.Lock()
-	defer b.Unlock()
-
-	b.available = nil
-	b.finalized = true
-}
-
 type azureBlobDirLister struct {
 	baseDirLister
 	paginator     *runtime.Pager[container.ListBlobsHierarchyResponse]
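
Both the Azure loop above and the new S3 upload loop below fill each part buffer with the shared readFill helper, moved to internal/vfs/vfs.go in this commit. Unlike io.ReadFull, which reports a short read as io.ErrUnexpectedEOF, readFill keeps reading until the buffer is full or the reader fails, and it returns io.EOF together with whatever it did read, which is how the "if err == io.EOF" branches detect and upload a final, short part. A standalone sketch of that behaviour; the reader contents and the 4-byte part size are made up for illustration:

package main

import (
	"fmt"
	"io"
	"strings"
)

// readFill is the helper added to internal/vfs/vfs.go below: it keeps calling
// Read until buf is full or an error occurs, and reports io.EOF together with
// any bytes read in the final, short chunk.
func readFill(r io.Reader, buf []byte) (n int, err error) {
	var nn int
	for n < len(buf) && err == nil {
		nn, err = r.Read(buf[n:])
		n += nn
	}
	return n, err
}

func main() {
	// 10 bytes read with a 4-byte part buffer: two full parts, then a final
	// 2-byte part reported together with io.EOF.
	r := strings.NewReader("0123456789")
	buf := make([]byte, 4)
	for part := 1; ; part++ {
		n, err := readFill(r, buf)
		fmt.Printf("part %d: %d bytes, err=%v\n", part, n, err)
		if err != nil { // io.EOF included: the read loop is done
			break
		}
	}
}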

+ 179 - 31
internal/vfs/s3fs.go

@@ -17,6 +17,7 @@
 package vfs
 
 import (
+	"bytes"
 	"context"
 	"crypto/md5"
 	"crypto/sha256"
@@ -255,7 +256,7 @@ func (fs *S3Fs) Open(name string, offset int64) (File, PipeReader, func(), error
 
 	var streamRange *string
 	if offset > 0 {
-		streamRange = aws.String(fmt.Sprintf("bytes=%v-", offset))
+		streamRange = aws.String(fmt.Sprintf("bytes=%d-", offset))
 	}
 
 	go func() {
@@ -295,16 +296,6 @@ func (fs *S3Fs) Create(name string, flag, checks int) (File, PipeWriter, func(),
 		p = NewPipeWriter(w)
 	}
 	ctx, cancelFn := context.WithCancel(context.Background())
-	uploader := manager.NewUploader(fs.svc, func(u *manager.Uploader) {
-		u.Concurrency = fs.config.UploadConcurrency
-		u.PartSize = fs.config.UploadPartSize
-		if fs.config.UploadPartMaxTime > 0 {
-			u.ClientOptions = append(u.ClientOptions, func(o *s3.Options) {
-				o.HTTPClient = getAWSHTTPClient(fs.config.UploadPartMaxTime, 100*time.Millisecond,
-					fs.config.SkipTLSVerify)
-			})
-		}
-	})
 
 	go func() {
 		defer cancelFn()
@@ -315,17 +306,7 @@ func (fs *S3Fs) Create(name string, flag, checks int) (File, PipeWriter, func(),
 		} else {
 			contentType = mime.TypeByExtension(path.Ext(name))
 		}
-		_, err := uploader.Upload(ctx, &s3.PutObjectInput{
-			Bucket:               aws.String(fs.config.Bucket),
-			Key:                  aws.String(name),
-			Body:                 r,
-			ACL:                  types.ObjectCannedACL(fs.config.ACL),
-			StorageClass:         types.StorageClass(fs.config.StorageClass),
-			ContentType:          util.NilIfEmpty(contentType),
-			SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
-			SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
-			SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
-		})
+		err := fs.handleMultipartUpload(ctx, r, name, contentType)
 		r.CloseWithError(err) //nolint:errcheck
 		p.Done(err)
 		fsLog(fs, logger.LevelDebug, "upload completed, path: %q, acl: %q, read bytes: %d, err: %+v",
@@ -834,6 +815,181 @@ func (fs *S3Fs) hasContents(name string) (bool, error) {
 	return false, nil
 }
 
+func (fs *S3Fs) initiateMultipartUpload(ctx context.Context, name, contentType string) (string, error) {
+	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(fs.ctxTimeout))
+	defer cancelFn()
+
+	res, err := fs.svc.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
+		Bucket:               aws.String(fs.config.Bucket),
+		Key:                  aws.String(name),
+		StorageClass:         types.StorageClass(fs.config.StorageClass),
+		ACL:                  types.ObjectCannedACL(fs.config.ACL),
+		ContentType:          util.NilIfEmpty(contentType),
+		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
+		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
+		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
+	})
+	if err != nil {
+		return "", fmt.Errorf("unable to create multipart upload request: %w", err)
+	}
+	uploadID := util.GetStringFromPointer(res.UploadId)
+	if uploadID == "" {
+		return "", errors.New("unable to get multipart upload ID")
+	}
+	return uploadID, nil
+}
+
+func (fs *S3Fs) uploadPart(ctx context.Context, name, uploadID string, partNumber int32, data []byte) (*string, error) {
+	timeout := time.Duration(fs.config.UploadPartSize/(1024*1024)) * time.Minute
+	if fs.config.UploadPartMaxTime > 0 {
+		timeout = time.Duration(fs.config.UploadPartMaxTime) * time.Second
+	}
+	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(timeout))
+	defer cancelFn()
+
+	resp, err := fs.svc.UploadPart(ctx, &s3.UploadPartInput{
+		Bucket:               aws.String(fs.config.Bucket),
+		Key:                  aws.String(name),
+		PartNumber:           &partNumber,
+		UploadId:             aws.String(uploadID),
+		Body:                 bytes.NewReader(data),
+		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
+		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
+		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
+	})
+	if err != nil {
+		return nil, fmt.Errorf("unable to upload part number %d: %w", partNumber, err)
+	}
+	return resp.ETag, nil
+}
+
+func (fs *S3Fs) completeMultipartUpload(ctx context.Context, name, uploadID string, completedParts []types.CompletedPart) error {
+	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(fs.ctxTimeout))
+	defer cancelFn()
+
+	_, err := fs.svc.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
+		Bucket:   aws.String(fs.config.Bucket),
+		Key:      aws.String(name),
+		UploadId: aws.String(uploadID),
+		MultipartUpload: &types.CompletedMultipartUpload{
+			Parts: completedParts,
+		},
+	})
+	return err
+}
+
+func (fs *S3Fs) abortMultipartUpload(name, uploadID string) error {
+	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
+	defer cancelFn()
+
+	_, err := fs.svc.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
+		Bucket:   aws.String(fs.config.Bucket),
+		Key:      aws.String(name),
+		UploadId: aws.String(uploadID),
+	})
+	return err
+}
+
+func (fs *S3Fs) handleMultipartUpload(ctx context.Context, reader io.Reader, name, contentType string) error {
+	uploadID, err := fs.initiateMultipartUpload(ctx, name, contentType)
+	if err != nil {
+		return err
+	}
+	guard := make(chan struct{}, fs.config.UploadConcurrency)
+	finished := false
+	var partMutex sync.Mutex
+	var completedParts []types.CompletedPart
+	var wg sync.WaitGroup
+	var hasError atomic.Bool
+	var poolErr error
+	var errOnce sync.Once
+	var partNumber int32
+
+	pool := newBufferAllocator(int(fs.config.UploadPartSize))
+	defer pool.free()
+
+	poolCtx, poolCancel := context.WithCancel(ctx)
+	defer poolCancel()
+
+	finalizeFailedUpload := func(err error) {
+		fsLog(fs, logger.LevelError, "finalize failed multipart upload after error: %v", err)
+		hasError.Store(true)
+		poolErr = err
+		poolCancel()
+		if abortErr := fs.abortMultipartUpload(name, uploadID); abortErr != nil {
+			fsLog(fs, logger.LevelError, "unable to abort multipart upload: %+v", abortErr)
+		}
+	}
+
+	for partNumber = 1; !finished; partNumber++ {
+		buf := pool.getBuffer()
+
+		n, err := readFill(reader, buf)
+		if err == io.EOF {
+			if n == 0 && partNumber > 1 {
+				pool.releaseBuffer(buf)
+				break
+			}
+			finished = true
+		} else if err != nil {
+			pool.releaseBuffer(buf)
+			errOnce.Do(func() {
+				finalizeFailedUpload(err)
+			})
+			return err
+		}
+		guard <- struct{}{}
+		if hasError.Load() {
+			fsLog(fs, logger.LevelError, "pool error, upload for part %d not started", partNumber)
+			pool.releaseBuffer(buf)
+			break
+		}
+
+		wg.Add(1)
+		go func(partNum int32, buf []byte, bytesRead int) {
+			defer func() {
+				pool.releaseBuffer(buf)
+				<-guard
+				wg.Done()
+			}()
+
+			etag, err := fs.uploadPart(poolCtx, name, uploadID, partNum, buf[:bytesRead])
+			if err != nil {
+				errOnce.Do(func() {
+					finalizeFailedUpload(err)
+				})
+				return
+			}
+			partMutex.Lock()
+			completedParts = append(completedParts, types.CompletedPart{
+				PartNumber: &partNum,
+				ETag:       etag,
+			})
+			partMutex.Unlock()
+		}(partNumber, buf, n)
+	}
+
+	wg.Wait()
+	close(guard)
+
+	if poolErr != nil {
+		return poolErr
+	}
+
+	sort.Slice(completedParts, func(i, j int) bool {
+		getPartNumber := func(number *int32) int32 {
+			if number == nil {
+				return 0
+			}
+			return *number
+		}
+
+		return getPartNumber(completedParts[i].PartNumber) < getPartNumber(completedParts[j].PartNumber)
+	})
+
+	return fs.completeMultipartUpload(ctx, name, uploadID, completedParts)
+}
+
 func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int64) error {
 	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
 	defer cancelFn()
@@ -921,15 +1077,7 @@ func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int
 					copyError = fmt.Errorf("error copying part number %d: %w", partNum, err)
 					opCancel()
 
-					abortCtx, abortCancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
-					defer abortCancelFn()
-
-					_, errAbort := fs.svc.AbortMultipartUpload(abortCtx, &s3.AbortMultipartUploadInput{
-						Bucket:   aws.String(fs.config.Bucket),
-						Key:      aws.String(target),
-						UploadId: aws.String(uploadID),
-					})
-					if errAbort != nil {
+					if errAbort := fs.abortMultipartUpload(target, uploadID); errAbort != nil {
 						fsLog(fs, logger.LevelError, "unable to abort multipart copy: %+v", errAbort)
 					}
 				})
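
The functions added above replace the SDK's manager.Uploader with explicit CreateMultipartUpload, UploadPart, CompleteMultipartUpload and AbortMultipartUpload calls. Stripped of the concurrency, buffer recycling, ACL/SSE options and per-part timeouts, the underlying flow is roughly the sketch below. It is an illustration only: multipartUploadSketch, the bucket, key and part size are placeholders, and io.ReadFull stands in for readFill.

package sketch

import (
	"bytes"
	"context"
	"io"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/s3"
	"github.com/aws/aws-sdk-go-v2/service/s3/types"
)

// multipartUploadSketch shows the same initiate/upload/complete sequence used
// above, but sequentially and without the SFTPGo-specific options.
func multipartUploadSketch(ctx context.Context, client *s3.Client, bucket, key string, r io.Reader) error {
	const partSize = 5 * 1024 * 1024 // S3 minimum size for every part except the last

	create, err := client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
		Bucket: aws.String(bucket),
		Key:    aws.String(key),
	})
	if err != nil {
		return err
	}
	uploadID := create.UploadId

	abort := func() {
		// best-effort cleanup, mirroring abortMultipartUpload above
		_, _ = client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
			Bucket:   aws.String(bucket),
			Key:      aws.String(key),
			UploadId: uploadID,
		})
	}

	var parts []types.CompletedPart
	buf := make([]byte, partSize)
	for partNumber := int32(1); ; partNumber++ {
		// io.ReadFull plays the role of readFill here: a short final read is
		// reported as io.ErrUnexpectedEOF together with the remaining bytes.
		n, readErr := io.ReadFull(r, buf)
		if readErr == io.ErrUnexpectedEOF {
			readErr = io.EOF
		}
		if readErr != nil && readErr != io.EOF {
			abort()
			return readErr
		}
		if n > 0 || partNumber == 1 { // at least one part is needed, even for empty files
			num := partNumber
			part, err := client.UploadPart(ctx, &s3.UploadPartInput{
				Bucket:     aws.String(bucket),
				Key:        aws.String(key),
				UploadId:   uploadID,
				PartNumber: &num,
				Body:       bytes.NewReader(buf[:n]),
			})
			if err != nil {
				abort()
				return err
			}
			parts = append(parts, types.CompletedPart{PartNumber: &num, ETag: part.ETag})
		}
		if readErr == io.EOF {
			break
		}
	}

	// Parts are uploaded in order here; the concurrent version above has to
	// sort completedParts by part number before completing.
	_, err = client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
		Bucket:          aws.String(bucket),
		Key:             aws.String(key),
		UploadId:        uploadID,
		MultipartUpload: &types.CompletedMultipartUpload{Parts: parts},
	})
	return err
}

Aborting on every failure matters because the already uploaded parts of an incomplete multipart upload keep consuming bucket storage until they are aborted or removed by a lifecycle rule; and because the sequential sketch uploads parts in order, it can skip the sort by part number that the concurrent handleMultipartUpload performs before completing.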

+ 71 - 0
internal/vfs/vfs.go

@@ -16,6 +16,7 @@
 package vfs
 
 import (
+	"bytes"
 	"errors"
 	"fmt"
 	"io"
@@ -1271,6 +1272,76 @@ func doRecursiveRename(fs Fs, source, target string,
 	}
 }
 
+// copied from rclone
+func readFill(r io.Reader, buf []byte) (n int, err error) {
+	var nn int
+	for n < len(buf) && err == nil {
+		nn, err = r.Read(buf[n:])
+		n += nn
+	}
+	return n, err
+}
+
+type bytesReaderWrapper struct {
+	*bytes.Reader
+}
+
+func (b *bytesReaderWrapper) Close() error {
+	return nil
+}
+
+type bufferAllocator struct {
+	sync.Mutex
+	available  [][]byte
+	bufferSize int
+	finalized  bool
+}
+
+func newBufferAllocator(size int) *bufferAllocator {
+	return &bufferAllocator{
+		bufferSize: size,
+		finalized:  false,
+	}
+}
+
+func (b *bufferAllocator) getBuffer() []byte {
+	b.Lock()
+	defer b.Unlock()
+
+	if len(b.available) > 0 {
+		var result []byte
+
+		truncLength := len(b.available) - 1
+		result = b.available[truncLength]
+
+		b.available[truncLength] = nil
+		b.available = b.available[:truncLength]
+
+		return result
+	}
+
+	return make([]byte, b.bufferSize)
+}
+
+func (b *bufferAllocator) releaseBuffer(buf []byte) {
+	b.Lock()
+	defer b.Unlock()
+
+	if b.finalized || len(buf) != b.bufferSize {
+		return
+	}
+
+	b.available = append(b.available, buf)
+}
+
+func (b *bufferAllocator) free() {
+	b.Lock()
+	defer b.Unlock()
+
+	b.available = nil
+	b.finalized = true
+}
+
 func fsLog(fs Fs, level logger.LogLevel, format string, v ...any) {
 	logger.Log(level, fs.Name(), fs.ConnectionID(), format, v...)
 }
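
The bufferAllocator hoisted here is the "very simple" allocator the comment kept in azblobfs.go refers to: getBuffer returns a recycled slice when one is available and allocates a new one otherwise, releaseBuffer puts correctly sized slices back on the free list, and free clears the list and marks the allocator finalized so buffers released by goroutines that finish after cleanup are dropped rather than retained. A usage sketch, assuming it sits in the same package as the code above with fmt imported; bufferAllocatorUsageSketch and the 1 KiB size are illustrative:

func bufferAllocatorUsageSketch() {
	pool := newBufferAllocator(1024) // 1 KiB buffers instead of multi-MiB upload parts
	defer pool.free()                // the pattern the upload/download paths now use

	buf := pool.getBuffer() // free list empty: a new 1 KiB slice is allocated
	fmt.Println(len(buf))   // 1024

	pool.releaseBuffer(buf)    // back on the free list
	reused := pool.getBuffer() // recycled, no new allocation
	fmt.Println(len(reused))   // 1024

	pool.releaseBuffer(make([]byte, 10)) // wrong size: silently dropped

	pool.free()                // finalize: the free list is cleared
	pool.releaseBuffer(reused) // late release after free: dropped, left to the GC
}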