|  | @@ -1,6 +1,7 @@
 | 
	
		
			
				|  |  |  package vfs
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import (
 | 
	
		
			
				|  |  | +	"bufio"
 | 
	
		
			
				|  |  |  	"errors"
 | 
	
		
			
				|  |  |  	"fmt"
 | 
	
		
			
				|  |  |  	"io"
 | 
	
	
		
			
				|  | @@ -43,8 +44,14 @@ type SFTPFsConfig struct {
 | 
	
		
			
				|  |  |  	// Concurrent reads are safe to use and disabling them will degrade performance.
 | 
	
		
			
				|  |  |  	// Some servers automatically delete files once they are downloaded.
 | 
	
		
			
				|  |  |  	// Using concurrent reads is problematic with such servers.
 | 
	
		
			
				|  |  | -	DisableCouncurrentReads bool     `json:"disable_concurrent_reads,omitempty"`
 | 
	
		
			
				|  |  | -	forbiddenSelfUsernames  []string `json:"-"`
 | 
	
		
			
				|  |  | +	DisableCouncurrentReads bool `json:"disable_concurrent_reads,omitempty"`
 | 
	
		
			
				|  |  | +	// The buffer size (in MB) to use for transfers.
 | 
	
		
			
				|  |  | +	// Buffering could improve performance for high latency networks.
 | 
	
		
			
				|  |  | +	// With buffering enabled upload resume is not supported and a file
 | 
	
		
			
				|  |  | +	// cannot be opened for both reading and writing at the same time
 | 
	
		
			
				|  |  | +	// 0 means disabled.
 | 
	
		
			
				|  |  | +	BufferSize             int64    `json:"buffer_size,omitempty"`
 | 
	
		
			
				|  |  | +	forbiddenSelfUsernames []string `json:"-"`
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  func (c *SFTPFsConfig) isEqual(other *SFTPFsConfig) bool {
 | 
	
	
		
			
				|  | @@ -60,6 +67,9 @@ func (c *SFTPFsConfig) isEqual(other *SFTPFsConfig) bool {
 | 
	
		
			
				|  |  |  	if c.DisableCouncurrentReads != other.DisableCouncurrentReads {
 | 
	
		
			
				|  |  |  		return false
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | +	if c.BufferSize != other.BufferSize {
 | 
	
		
			
				|  |  | +		return false
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  |  	if len(c.Fingerprints) != len(other.Fingerprints) {
 | 
	
		
			
				|  |  |  		return false
 | 
	
		
			
				|  |  |  	}
 | 
	
	
		
			
				|  | @@ -98,6 +108,21 @@ func (c *SFTPFsConfig) Validate() error {
 | 
	
		
			
				|  |  |  	if c.Username == "" {
 | 
	
		
			
				|  |  |  		return errors.New("username cannot be empty")
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | +	if c.BufferSize < 0 || c.BufferSize > 64 {
 | 
	
		
			
				|  |  | +		return errors.New("invalid buffer_size, valid range is 0-64")
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	if err := c.validateCredentials(); err != nil {
 | 
	
		
			
				|  |  | +		return err
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	if c.Prefix != "" {
 | 
	
		
			
				|  |  | +		c.Prefix = utils.CleanPath(c.Prefix)
 | 
	
		
			
				|  |  | +	} else {
 | 
	
		
			
				|  |  | +		c.Prefix = "/"
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	return nil
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +func (c *SFTPFsConfig) validateCredentials() error {
 | 
	
		
			
				|  |  |  	if c.Password.IsEmpty() && c.PrivateKey.IsEmpty() {
 | 
	
		
			
				|  |  |  		return errors.New("credentials cannot be empty")
 | 
	
		
			
				|  |  |  	}
 | 
	
	
		
			
				|  | @@ -113,11 +138,6 @@ func (c *SFTPFsConfig) Validate() error {
 | 
	
		
			
				|  |  |  	if !c.PrivateKey.IsEmpty() && !c.PrivateKey.IsValidInput() {
 | 
	
		
			
				|  |  |  		return errors.New("invalid private key")
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	if c.Prefix != "" {
 | 
	
		
			
				|  |  | -		c.Prefix = utils.CleanPath(c.Prefix)
 | 
	
		
			
				|  |  | -	} else {
 | 
	
		
			
				|  |  | -		c.Prefix = "/"
 | 
	
		
			
				|  |  | -	}
 | 
	
		
			
				|  |  |  	return nil
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -143,15 +163,19 @@ type SFTPFs struct {
 | 
	
		
			
				|  |  |  	sync.Mutex
 | 
	
		
			
				|  |  |  	connectionID string
 | 
	
		
			
				|  |  |  	// if not empty this fs is mouted as virtual folder in the specified path
 | 
	
		
			
				|  |  | -	mountPath  string
 | 
	
		
			
				|  |  | -	config     *SFTPFsConfig
 | 
	
		
			
				|  |  | -	sshClient  *ssh.Client
 | 
	
		
			
				|  |  | -	sftpClient *sftp.Client
 | 
	
		
			
				|  |  | -	err        chan error
 | 
	
		
			
				|  |  | +	mountPath    string
 | 
	
		
			
				|  |  | +	localTempDir string
 | 
	
		
			
				|  |  | +	config       *SFTPFsConfig
 | 
	
		
			
				|  |  | +	sshClient    *ssh.Client
 | 
	
		
			
				|  |  | +	sftpClient   *sftp.Client
 | 
	
		
			
				|  |  | +	err          chan error
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  // NewSFTPFs returns an SFTPFa object that allows to interact with an SFTP server
 | 
	
		
			
				|  |  | -func NewSFTPFs(connectionID, mountPath string, forbiddenSelfUsernames []string, config SFTPFsConfig) (Fs, error) {
 | 
	
		
			
				|  |  | +func NewSFTPFs(connectionID, mountPath, localTempDir string, forbiddenSelfUsernames []string, config SFTPFsConfig) (Fs, error) {
 | 
	
		
			
				|  |  | +	if localTempDir == "" {
 | 
	
		
			
				|  |  | +		localTempDir = filepath.Clean(os.TempDir())
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  |  	if err := config.Validate(); err != nil {
 | 
	
		
			
				|  |  |  		return nil, err
 | 
	
		
			
				|  |  |  	}
 | 
	
	
		
			
				|  | @@ -169,6 +193,7 @@ func NewSFTPFs(connectionID, mountPath string, forbiddenSelfUsernames []string,
 | 
	
		
			
				|  |  |  	sftpFs := &SFTPFs{
 | 
	
		
			
				|  |  |  		connectionID: connectionID,
 | 
	
		
			
				|  |  |  		mountPath:    mountPath,
 | 
	
		
			
				|  |  | +		localTempDir: localTempDir,
 | 
	
		
			
				|  |  |  		config:       &config,
 | 
	
		
			
				|  |  |  		err:          make(chan error, 1),
 | 
	
		
			
				|  |  |  	}
 | 
	
	
		
			
				|  | @@ -220,7 +245,32 @@ func (fs *SFTPFs) Open(name string, offset int64) (File, *pipeat.PipeReaderAt, f
 | 
	
		
			
				|  |  |  		return nil, nil, nil, err
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  	f, err := fs.sftpClient.Open(name)
 | 
	
		
			
				|  |  | -	return f, nil, nil, err
 | 
	
		
			
				|  |  | +	if fs.config.BufferSize == 0 {
 | 
	
		
			
				|  |  | +		return f, nil, nil, err
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	if offset > 0 {
 | 
	
		
			
				|  |  | +		_, err = f.Seek(offset, io.SeekStart)
 | 
	
		
			
				|  |  | +		if err != nil {
 | 
	
		
			
				|  |  | +			f.Close()
 | 
	
		
			
				|  |  | +			return nil, nil, nil, err
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	r, w, err := pipeat.PipeInDir(fs.localTempDir)
 | 
	
		
			
				|  |  | +	if err != nil {
 | 
	
		
			
				|  |  | +		f.Close()
 | 
	
		
			
				|  |  | +		return nil, nil, nil, err
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	go func() {
 | 
	
		
			
				|  |  | +		br := bufio.NewReaderSize(f, int(fs.config.BufferSize)*1024*1024)
 | 
	
		
			
				|  |  | +		// we don't use io.Copy since bufio.Reader implements io.ReadFrom and
 | 
	
		
			
				|  |  | +		// so it calls the sftp.File ReadFrom method without buffering
 | 
	
		
			
				|  |  | +		n, err := fs.copy(w, br)
 | 
	
		
			
				|  |  | +		w.CloseWithError(err) //nolint:errcheck
 | 
	
		
			
				|  |  | +		f.Close()
 | 
	
		
			
				|  |  | +		fsLog(fs, logger.LevelDebug, "download completed, path: %#v size: %v, err: %v", name, n, err)
 | 
	
		
			
				|  |  | +	}()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	return nil, r, nil, nil
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  // Create creates or opens the named file for writing
 | 
	
	
		
			
				|  | @@ -229,13 +279,51 @@ func (fs *SFTPFs) Create(name string, flag int) (File, *PipeWriter, func(), erro
 | 
	
		
			
				|  |  |  	if err != nil {
 | 
	
		
			
				|  |  |  		return nil, nil, nil, err
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	var f File
 | 
	
		
			
				|  |  | -	if flag == 0 {
 | 
	
		
			
				|  |  | -		f, err = fs.sftpClient.Create(name)
 | 
	
		
			
				|  |  | -	} else {
 | 
	
		
			
				|  |  | -		f, err = fs.sftpClient.OpenFile(name, flag)
 | 
	
		
			
				|  |  | +	if fs.config.BufferSize == 0 {
 | 
	
		
			
				|  |  | +		var f File
 | 
	
		
			
				|  |  | +		if flag == 0 {
 | 
	
		
			
				|  |  | +			f, err = fs.sftpClient.Create(name)
 | 
	
		
			
				|  |  | +		} else {
 | 
	
		
			
				|  |  | +			f, err = fs.sftpClient.OpenFile(name, flag)
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +		return f, nil, nil, err
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	// buffering is enabled
 | 
	
		
			
				|  |  | +	f, err := fs.sftpClient.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC)
 | 
	
		
			
				|  |  | +	if err != nil {
 | 
	
		
			
				|  |  | +		return nil, nil, nil, err
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	r, w, err := pipeat.PipeInDir(fs.localTempDir)
 | 
	
		
			
				|  |  | +	if err != nil {
 | 
	
		
			
				|  |  | +		f.Close()
 | 
	
		
			
				|  |  | +		return nil, nil, nil, err
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	return f, nil, nil, err
 | 
	
		
			
				|  |  | +	p := NewPipeWriter(w)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	go func() {
 | 
	
		
			
				|  |  | +		bw := bufio.NewWriterSize(f, int(fs.config.BufferSize)*1024*1024)
 | 
	
		
			
				|  |  | +		// we don't use io.Copy since bufio.Writer implements io.WriterTo and
 | 
	
		
			
				|  |  | +		// so it calls the sftp.File WriteTo method without buffering
 | 
	
		
			
				|  |  | +		n, err := fs.copy(bw, r)
 | 
	
		
			
				|  |  | +		errFlush := bw.Flush()
 | 
	
		
			
				|  |  | +		if err == nil && errFlush != nil {
 | 
	
		
			
				|  |  | +			err = errFlush
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +		errClose := f.Close()
 | 
	
		
			
				|  |  | +		if err == nil && errClose != nil {
 | 
	
		
			
				|  |  | +			err = errClose
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +		r.CloseWithError(err) //nolint:errcheck
 | 
	
		
			
				|  |  | +		var errTruncate error
 | 
	
		
			
				|  |  | +		if err != nil {
 | 
	
		
			
				|  |  | +			errTruncate = f.Truncate(n)
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +		p.Done(err)
 | 
	
		
			
				|  |  | +		fsLog(fs, logger.LevelDebug, "upload completed, path: %#v, readed bytes: %v, err: %v err truncate: %v",
 | 
	
		
			
				|  |  | +			name, n, err, errTruncate)
 | 
	
		
			
				|  |  | +	}()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	return nil, p, nil, nil
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  // Rename renames (moves) source to target.
 | 
	
	
		
			
				|  | @@ -340,14 +428,14 @@ func (fs *SFTPFs) ReadDir(dirname string) ([]os.FileInfo, error) {
 | 
	
		
			
				|  |  |  	return result, nil
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -// IsUploadResumeSupported returns true if upload resume is supported.
 | 
	
		
			
				|  |  | -func (*SFTPFs) IsUploadResumeSupported() bool {
 | 
	
		
			
				|  |  | -	return true
 | 
	
		
			
				|  |  | +// IsUploadResumeSupported returns true if resuming uploads is supported.
 | 
	
		
			
				|  |  | +func (fs *SFTPFs) IsUploadResumeSupported() bool {
 | 
	
		
			
				|  |  | +	return fs.config.BufferSize == 0
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  // IsAtomicUploadSupported returns true if atomic upload is supported.
 | 
	
		
			
				|  |  | -func (*SFTPFs) IsAtomicUploadSupported() bool {
 | 
	
		
			
				|  |  | -	return true
 | 
	
		
			
				|  |  | +func (fs *SFTPFs) IsAtomicUploadSupported() bool {
 | 
	
		
			
				|  |  | +	return fs.config.BufferSize == 0
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  // IsNotExist returns a boolean indicating whether the error is known to
 | 
	
	
		
			
				|  | @@ -372,6 +460,11 @@ func (*SFTPFs) IsNotSupported(err error) bool {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  // CheckRootPath creates the specified local root directory if it does not exists
 | 
	
		
			
				|  |  |  func (fs *SFTPFs) CheckRootPath(username string, uid int, gid int) bool {
 | 
	
		
			
				|  |  | +	if fs.config.BufferSize > 0 {
 | 
	
		
			
				|  |  | +		// we need a local directory for temporary files
 | 
	
		
			
				|  |  | +		osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "")
 | 
	
		
			
				|  |  | +		osFs.CheckRootPath(username, uid, gid)
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  |  	if fs.config.Prefix == "/" {
 | 
	
		
			
				|  |  |  		return true
 | 
	
		
			
				|  |  |  	}
 | 
	
	
		
			
				|  | @@ -595,6 +688,38 @@ func (fs *SFTPFs) Close() error {
 | 
	
		
			
				|  |  |  	return sshErr
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +func (fs *SFTPFs) copy(dst io.Writer, src io.Reader) (written int64, err error) {
 | 
	
		
			
				|  |  | +	buf := make([]byte, 32768)
 | 
	
		
			
				|  |  | +	for {
 | 
	
		
			
				|  |  | +		nr, er := src.Read(buf)
 | 
	
		
			
				|  |  | +		if nr > 0 {
 | 
	
		
			
				|  |  | +			nw, ew := dst.Write(buf[0:nr])
 | 
	
		
			
				|  |  | +			if nw < 0 || nr < nw {
 | 
	
		
			
				|  |  | +				nw = 0
 | 
	
		
			
				|  |  | +				if ew == nil {
 | 
	
		
			
				|  |  | +					ew = errors.New("invalid write")
 | 
	
		
			
				|  |  | +				}
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +			written += int64(nw)
 | 
	
		
			
				|  |  | +			if ew != nil {
 | 
	
		
			
				|  |  | +				err = ew
 | 
	
		
			
				|  |  | +				break
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +			if nr != nw {
 | 
	
		
			
				|  |  | +				err = io.ErrShortWrite
 | 
	
		
			
				|  |  | +				break
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +		if er != nil {
 | 
	
		
			
				|  |  | +			if er != io.EOF {
 | 
	
		
			
				|  |  | +				err = er
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +			break
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	return written, err
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  func (fs *SFTPFs) checkConnection() error {
 | 
	
		
			
				|  |  |  	err := fs.closed()
 | 
	
		
			
				|  |  |  	if err == nil {
 | 
	
	
		
			
				|  | @@ -659,6 +784,11 @@ func (fs *SFTPFs) createConnection() error {
 | 
	
		
			
				|  |  |  		opt := sftp.UseConcurrentReads(false)
 | 
	
		
			
				|  |  |  		opt(fs.sftpClient) //nolint:errcheck
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | +	if fs.config.BufferSize > 0 {
 | 
	
		
			
				|  |  | +		fsLog(fs, logger.LevelDebug, "enabling concurrent writes")
 | 
	
		
			
				|  |  | +		opt := sftp.UseConcurrentWrites(true)
 | 
	
		
			
				|  |  | +		opt(fs.sftpClient) //nolint:errcheck
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  |  	go fs.wait()
 | 
	
		
			
				|  |  |  	return nil
 | 
	
		
			
				|  |  |  }
 |