Browse Source

WebDAV: try to preserve the lock fs as much as possible

Nicola Murino 4 years ago
parent
commit
9ad750da54

+ 142 - 0
dataprovider/cacheduser.go

@@ -0,0 +1,142 @@
+package dataprovider
+
+import (
+	"sync"
+	"time"
+
+	"golang.org/x/net/webdav"
+
+	"github.com/drakkan/sftpgo/utils"
+)
+
+var (
+	webDAVUsersCache *usersCache
+)
+
+func init() {
+	webDAVUsersCache = &usersCache{
+		users: map[string]CachedUser{},
+	}
+}
+
+// InitializeWebDAVUserCache initializes the cache for webdav users
+func InitializeWebDAVUserCache(maxSize int) {
+	webDAVUsersCache = &usersCache{
+		users:   map[string]CachedUser{},
+		maxSize: maxSize,
+	}
+}
+
+// CachedUser adds fields useful for caching to a SFTPGo user
+type CachedUser struct {
+	User       User
+	Expiration time.Time
+	Password   string
+	LockSystem webdav.LockSystem
+}
+
+// IsExpired returns true if the cached user is expired
+func (c *CachedUser) IsExpired() bool {
+	if c.Expiration.IsZero() {
+		return false
+	}
+	return c.Expiration.Before(time.Now())
+}
+
+type usersCache struct {
+	sync.RWMutex
+	users   map[string]CachedUser
+	maxSize int
+}
+
+func (cache *usersCache) updateLastLogin(username string) {
+	cache.Lock()
+	defer cache.Unlock()
+
+	if cachedUser, ok := cache.users[username]; ok {
+		cachedUser.User.LastLogin = utils.GetTimeAsMsSinceEpoch(time.Now())
+		cache.users[username] = cachedUser
+	}
+}
+
+// swapWebDAVUser updates an existing cached user with the specified one
+// preserving the lock fs if possible
+func (cache *usersCache) swap(user *User) {
+	cache.Lock()
+	defer cache.Unlock()
+
+	if cachedUser, ok := cache.users[user.Username]; ok {
+		if cachedUser.User.Password != user.Password {
+			// the password changed, the cached user is no longer valid
+			delete(cache.users, user.Username)
+			return
+		}
+		if cachedUser.User.isFsEqual(user) {
+			// the updated user has the same fs as the cached one, we can preserve the lock filesystem
+			cachedUser.User = *user
+			cache.users[user.Username] = cachedUser
+		} else {
+			// filesystem changed, the cached user is no longer valid
+			delete(cache.users, user.Username)
+		}
+	}
+}
+
+func (cache *usersCache) add(cachedUser *CachedUser) {
+	cache.Lock()
+	defer cache.Unlock()
+
+	if cache.maxSize > 0 && len(cache.users) >= cache.maxSize {
+		var userToRemove string
+		var expirationTime time.Time
+
+		for k, v := range cache.users {
+			if userToRemove == "" {
+				userToRemove = k
+				expirationTime = v.Expiration
+				continue
+			}
+			expireTime := v.Expiration
+			if !expireTime.IsZero() && expireTime.Before(expirationTime) {
+				userToRemove = k
+				expirationTime = expireTime
+			}
+		}
+
+		delete(cache.users, userToRemove)
+	}
+
+	if cachedUser.User.Username != "" {
+		cache.users[cachedUser.User.Username] = *cachedUser
+	}
+}
+
+func (cache *usersCache) remove(username string) {
+	cache.Lock()
+	defer cache.Unlock()
+
+	delete(cache.users, username)
+}
+
+func (cache *usersCache) get(username string) (*CachedUser, bool) {
+	cache.RLock()
+	defer cache.RUnlock()
+
+	cachedUser, ok := cache.users[username]
+	return &cachedUser, ok
+}
+
+// CacheWebDAVUser add a user to the WebDAV cache
+func CacheWebDAVUser(cachedUser *CachedUser) {
+	webDAVUsersCache.add(cachedUser)
+}
+
+// GetCachedWebDAVUser returns a previously cached WebDAV user
+func GetCachedWebDAVUser(username string) (*CachedUser, bool) {
+	return webDAVUsersCache.get(username)
+}
+
+// RemoveCachedWebDAVUser removes a cached WebDAV user
+func RemoveCachedWebDAVUser(username string) {
+	webDAVUsersCache.remove(username)
+}

+ 19 - 62
dataprovider/dataprovider.go

@@ -111,7 +111,6 @@ var (
 	// ErrInvalidCredentials defines the error to return if the supplied credentials are invalid
 	ErrInvalidCredentials = errors.New("invalid credentials")
 	validTLSUsernames     = []string{string(TLSUsernameNone), string(TLSUsernameCN)}
-	webDAVUsersCache      sync.Map
 	config                Config
 	provider              Provider
 	sqlPlaceholders       []string
@@ -750,7 +749,7 @@ func UpdateLastLogin(user *User) error {
 	if diff < 0 || diff > lastLoginMinDelay {
 		err := provider.updateLastLogin(user.Username)
 		if err == nil {
-			updateWebDavCachedUserLastLogin(user.Username)
+			webDAVUsersCache.updateLastLogin(user.Username)
 		}
 		return err
 	}
@@ -841,7 +840,7 @@ func AddUser(user *User) error {
 func UpdateUser(user *User) error {
 	err := provider.updateUser(user)
 	if err == nil {
-		RemoveCachedWebDAVUser(user.Username)
+		webDAVUsersCache.swap(user)
 		executeAction(operationUpdate, user)
 	}
 	return err
@@ -2190,6 +2189,9 @@ func executePreLoginHook(username, loginMethod, ip, protocol string) (User, erro
 		err = provider.addUser(&u)
 	} else {
 		err = provider.updateUser(&u)
+		if err == nil {
+			webDAVUsersCache.swap(&u)
+		}
 	}
 	if err != nil {
 		return u, err
@@ -2328,6 +2330,15 @@ func getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip,
 	return cmd.Output()
 }
 
+func updateUserFromExtAuthResponse(user *User, password, pkey string) {
+	if password != "" {
+		user.Password = password
+	}
+	if pkey != "" && !utils.IsStringPrefixInSlice(pkey, user.PublicKeys) {
+		user.PublicKeys = append(user.PublicKeys, pkey)
+	}
+}
+
 func doExternalAuth(username, password string, pubKey []byte, keyboardInteractive, ip, protocol string, tlsCert *x509.Certificate) (User, error) {
 	var user User
 
@@ -2358,15 +2369,11 @@ func doExternalAuth(username, password string, pubKey []byte, keyboardInteractiv
 	if err != nil {
 		return user, fmt.Errorf("invalid external auth response: %v", err)
 	}
+	// an empty username means authentication failure
 	if user.Username == "" {
 		return user, ErrInvalidCredentials
 	}
-	if password != "" {
-		user.Password = password
-	}
-	if pkey != "" && !utils.IsStringPrefixInSlice(pkey, user.PublicKeys) {
-		user.PublicKeys = append(user.PublicKeys, pkey)
-	}
+	updateUserFromExtAuthResponse(&user, password, pkey)
 	// some users want to map multiple login usernames with a single SFTPGo account
 	// for example an SFTP user logins using "user1" or "user2" and the external auth
 	// returns "user" in both cases, so we use the username returned from
@@ -2381,6 +2388,9 @@ func doExternalAuth(username, password string, pubKey []byte, keyboardInteractiv
 		user.LastQuotaUpdate = u.LastQuotaUpdate
 		user.LastLogin = u.LastLogin
 		err = provider.updateUser(&user)
+		if err == nil {
+			webDAVUsersCache.swap(&user)
+		}
 		return user, err
 	}
 	err = provider.addUser(&user)
@@ -2485,56 +2495,3 @@ func executeAction(operation string, user *User) {
 		}
 	}()
 }
-
-func updateWebDavCachedUserLastLogin(username string) {
-	result, ok := webDAVUsersCache.Load(username)
-	if ok {
-		cachedUser := result.(*CachedUser)
-		cachedUser.User.LastLogin = utils.GetTimeAsMsSinceEpoch(time.Now())
-		webDAVUsersCache.Store(cachedUser.User.Username, cachedUser)
-	}
-}
-
-// CacheWebDAVUser add a user to the WebDAV cache
-func CacheWebDAVUser(cachedUser *CachedUser, maxSize int) {
-	if maxSize > 0 {
-		var cacheSize int
-		var userToRemove string
-		var expirationTime time.Time
-
-		webDAVUsersCache.Range(func(k, v interface{}) bool {
-			cacheSize++
-			if userToRemove == "" {
-				userToRemove = k.(string)
-				expirationTime = v.(*CachedUser).Expiration
-				return true
-			}
-			expireTime := v.(*CachedUser).Expiration
-			if !expireTime.IsZero() && expireTime.Before(expirationTime) {
-				userToRemove = k.(string)
-				expirationTime = expireTime
-			}
-			return true
-		})
-
-		if cacheSize >= maxSize {
-			RemoveCachedWebDAVUser(userToRemove)
-		}
-	}
-
-	if cachedUser.User.Username != "" {
-		webDAVUsersCache.Store(cachedUser.User.Username, cachedUser)
-	}
-}
-
-// GetCachedWebDAVUser returns a previously cached WebDAV user
-func GetCachedWebDAVUser(username string) (interface{}, bool) {
-	return webDAVUsersCache.Load(username)
-}
-
-// RemoveCachedWebDAVUser removes a cached WebDAV user
-func RemoveCachedWebDAVUser(username string) {
-	if username != "" {
-		webDAVUsersCache.Delete(username)
-	}
-}

+ 33 - 18
dataprovider/user.go

@@ -13,8 +13,6 @@ import (
 	"strings"
 	"time"
 
-	"golang.org/x/net/webdav"
-
 	"github.com/drakkan/sftpgo/kms"
 	"github.com/drakkan/sftpgo/logger"
 	"github.com/drakkan/sftpgo/utils"
@@ -75,22 +73,6 @@ var (
 	errNoMatchingVirtualFolder = errors.New("no matching virtual folder found")
 )
 
-// CachedUser adds fields useful for caching to a SFTPGo user
-type CachedUser struct {
-	User       User
-	Expiration time.Time
-	Password   string
-	LockSystem webdav.LockSystem
-}
-
-// IsExpired returns true if the cached user is expired
-func (c *CachedUser) IsExpired() bool {
-	if c.Expiration.IsZero() {
-		return false
-	}
-	return c.Expiration.Before(time.Now())
-}
-
 // ExtensionsFilter defines filters based on file extensions.
 // These restrictions do not apply to files listing for performance reasons, so
 // a denied file cannot be downloaded/overwritten/renamed but will still be
@@ -279,6 +261,39 @@ func (u *User) CheckFsRoot(connectionID string) error {
 	return nil
 }
 
+// isFsEqual returns true if the fs has the same configuration
+func (u *User) isFsEqual(other *User) bool {
+	if u.FsConfig.Provider == vfs.LocalFilesystemProvider && u.GetHomeDir() != other.GetHomeDir() {
+		return false
+	}
+	if !u.FsConfig.IsEqual(&other.FsConfig) {
+		return false
+	}
+	if len(u.VirtualFolders) != len(other.VirtualFolders) {
+		return false
+	}
+	for idx := range u.VirtualFolders {
+		f := &u.VirtualFolders[idx]
+		found := false
+		for idx1 := range other.VirtualFolders {
+			f1 := &other.VirtualFolders[idx1]
+			if f.VirtualPath == f1.VirtualPath {
+				found = true
+				if f.FsConfig.Provider == vfs.LocalFilesystemProvider && f.MappedPath != f1.MappedPath {
+					return false
+				}
+				if !f.FsConfig.IsEqual(&f1.FsConfig) {
+					return false
+				}
+			}
+		}
+		if !found {
+			return false
+		}
+	}
+	return true
+}
+
 // hideConfidentialData hides user confidential data
 func (u *User) hideConfidentialData() {
 	u.Password = ""

+ 3 - 3
docs/external-auth.md

@@ -18,7 +18,7 @@ The program can inspect the SFTPGo user, if it exists, using the `SFTPGO_AUTHD_U
 The program must write, on its standard output:
 
 - a valid SFTPGo user serialized as JSON if the authentication succeeds. The user will be added/updated within the defined data provider
-- an empty string, or no response at all, if authentication succeeds and the existing SFTPGo user does not need to be updated
+- an empty string, or no response at all, if authentication succeeds and the existing SFTPGo user does not need to be updated. Please note that in versions 2.0.x and earlier an empty response was interpreted as an authentication error
 - a user with an empty username if the authentication fails
 
 If the hook is an HTTP URL then it will be invoked as HTTP POST. The request body will contain a JSON serialized struct with the following fields:
@@ -35,9 +35,9 @@ If the hook is an HTTP URL then it will be invoked as HTTP POST. The request bod
 If authentication succeeds the HTTP response code must be 200 and the response body can be:
 
 - a valid SFTPGo user serialized as JSON. The user will be added/updated within the defined data provider
-- empty, the existing SFTPGo user does not need to be updated
+- empty, the existing SFTPGo user does not need to be updated. Please note that in versions 2.0.x and earlier an empty response was interpreted as an authentication error
 
-If the authentication fails the HTTP response code must be != 200.
+If the authentication fails the HTTP response code must be != 200 or the returned SFTPGo user must have an empty username.
 
 Actions defined for users added/updated will not be executed in this case and an already logged in user with the same username will not be disconnected.
 

+ 0 - 68
httpd/httpd_test.go

@@ -824,74 +824,6 @@ func TestAddUserInvalidVirtualFolders(t *testing.T) {
 	})
 	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
 	assert.NoError(t, err)
-	/*u.VirtualFolders = nil
-	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
-		BaseVirtualFolder: vfs.BaseVirtualFolder{
-			MappedPath: filepath.Join(os.TempDir(), "mapped_dir", "subdir"),
-			Name:       folderName + "2",
-		},
-		VirtualPath: "/vdir1",
-	})
-	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
-		BaseVirtualFolder: vfs.BaseVirtualFolder{
-			MappedPath: filepath.Join(os.TempDir(), "mapped_dir"), // invalid, contains mapped_dir/subdir
-			Name:       folderName,
-		},
-		VirtualPath: "/vdir2",
-	})
-	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
-	assert.NoError(t, err)
-	u.VirtualFolders = nil
-	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
-		BaseVirtualFolder: vfs.BaseVirtualFolder{
-			MappedPath: filepath.Join(os.TempDir(), "mapped_dir"),
-			Name:       folderName,
-		},
-		VirtualPath: "/vdir1",
-	})
-	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
-		BaseVirtualFolder: vfs.BaseVirtualFolder{
-			MappedPath: filepath.Join(os.TempDir(), "mapped_dir", "subdir"), // invalid, contained in mapped_dir
-			Name:       folderName + "3",
-		},
-		VirtualPath: "/vdir2",
-	})
-	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
-	assert.NoError(t, err)
-	u.VirtualFolders = nil
-	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
-		BaseVirtualFolder: vfs.BaseVirtualFolder{
-			MappedPath: filepath.Join(os.TempDir(), "mapped_dir1"),
-			Name:       folderName + "1",
-		},
-		VirtualPath: "/vdir1/subdir",
-	})
-	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
-		BaseVirtualFolder: vfs.BaseVirtualFolder{
-			MappedPath: filepath.Join(os.TempDir(), "mapped_dir2"),
-			Name:       folderName + "2",
-		},
-		VirtualPath: "/vdir1/../vdir1", // invalid, overlaps with /vdir1/subdir
-	})
-	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
-	assert.NoError(t, err)
-	u.VirtualFolders = nil
-	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
-		BaseVirtualFolder: vfs.BaseVirtualFolder{
-			MappedPath: filepath.Join(os.TempDir(), "mapped_dir1"),
-			Name:       folderName + "1",
-		},
-		VirtualPath: "/vdir1/",
-	})
-	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
-		BaseVirtualFolder: vfs.BaseVirtualFolder{
-			MappedPath: filepath.Join(os.TempDir(), "mapped_dir2"),
-			Name:       folderName + "2",
-		},
-		VirtualPath: "/vdir1/subdir", // invalid, contained inside /vdir1
-	})
-	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
-	assert.NoError(t, err)*/
 	u.VirtualFolders = nil
 	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
 		BaseVirtualFolder: vfs.BaseVirtualFolder{

+ 32 - 0
kms/kms.go

@@ -197,6 +197,26 @@ func (s *Secret) UnmarshalJSON(data []byte) error {
 	return nil
 }
 
+// IsEqual returns true if all the secrets fields are equal
+func (s *Secret) IsEqual(other *Secret) bool {
+	if s.GetStatus() != other.GetStatus() {
+		return false
+	}
+	if s.GetPayload() != other.GetPayload() {
+		return false
+	}
+	if s.GetKey() != other.GetKey() {
+		return false
+	}
+	if s.GetAdditionalData() != other.GetAdditionalData() {
+		return false
+	}
+	if s.GetMode() != other.GetMode() {
+		return false
+	}
+	return true
+}
+
 // Clone returns a copy of the secret object
 func (s *Secret) Clone() *Secret {
 	s.RLock()
@@ -414,3 +434,15 @@ func (s *Secret) Decrypt() error {
 
 	return s.provider.Decrypt()
 }
+
+// TryDecrypt decrypts a Secret object if encrypted.
+// It returns a nil error if the object is not encrypted
+func (s *Secret) TryDecrypt() error {
+	s.Lock()
+	defer s.Unlock()
+
+	if s.provider.IsEncrypted() {
+		return s.provider.Decrypt()
+	}
+	return nil
+}

+ 3 - 5
vfs/azblobfs.go

@@ -67,11 +67,9 @@ func NewAzBlobFs(connectionID, localTempDir, mountPath string, config AzBlobFsCo
 	if err := fs.config.Validate(); err != nil {
 		return fs, err
 	}
-	if fs.config.AccountKey.IsEncrypted() {
-		err := fs.config.AccountKey.Decrypt()
-		if err != nil {
-			return fs, err
-		}
+
+	if err := fs.config.AccountKey.TryDecrypt(); err != nil {
+		return fs, err
 	}
 	fs.setConfigDefaults()
 

+ 2 - 4
vfs/cryptfs.go

@@ -35,10 +35,8 @@ func NewCryptFs(connectionID, rootDir, mountPath string, config CryptFsConfig) (
 	if err := config.Validate(); err != nil {
 		return nil, err
 	}
-	if config.Passphrase.IsEncrypted() {
-		if err := config.Passphrase.Decrypt(); err != nil {
-			return nil, err
-		}
+	if err := config.Passphrase.TryDecrypt(); err != nil {
+		return nil, err
 	}
 	fs := &CryptFs{
 		OsFs: &OsFs{

+ 21 - 0
vfs/filesystem.go

@@ -72,6 +72,27 @@ func (f *Filesystem) SetNilSecretsIfEmpty() {
 	}
 }
 
+// IsEqual returns true if the fs is equal to other
+func (f *Filesystem) IsEqual(other *Filesystem) bool {
+	if f.Provider != other.Provider {
+		return false
+	}
+	switch f.Provider {
+	case S3FilesystemProvider:
+		return f.S3Config.isEqual(&other.S3Config)
+	case GCSFilesystemProvider:
+		return f.GCSConfig.isEqual(&other.GCSConfig)
+	case AzureBlobFilesystemProvider:
+		return f.AzBlobConfig.isEqual(&other.AzBlobConfig)
+	case CryptedFilesystemProvider:
+		return f.CryptConfig.isEqual(&other.CryptConfig)
+	case SFTPFilesystemProvider:
+		return f.SFTPConfig.isEqual(&other.SFTPConfig)
+	default:
+		return true
+	}
+}
+
 // GetACopy returns a copy
 func (f *Filesystem) GetACopy() Filesystem {
 	f.SetEmptySecretsIfNil()

+ 3 - 5
vfs/gcsfs.go

@@ -71,11 +71,9 @@ func NewGCSFs(connectionID, localTempDir, mountPath string, config GCSFsConfig)
 	if fs.config.AutomaticCredentials > 0 {
 		fs.svc, err = storage.NewClient(ctx)
 	} else if !fs.config.Credentials.IsEmpty() {
-		if fs.config.Credentials.IsEncrypted() {
-			err = fs.config.Credentials.Decrypt()
-			if err != nil {
-				return fs, err
-			}
+		err = fs.config.Credentials.TryDecrypt()
+		if err != nil {
+			return fs, err
 		}
 		fs.svc, err = storage.NewClient(ctx, option.WithCredentialsJSON([]byte(fs.config.Credentials.GetPayload())))
 	} else {

+ 2 - 5
vfs/s3fs.go

@@ -68,11 +68,8 @@ func NewS3Fs(connectionID, localTempDir, mountPath string, config S3FsConfig) (F
 	}
 
 	if !fs.config.AccessSecret.IsEmpty() {
-		if fs.config.AccessSecret.IsEncrypted() {
-			err := fs.config.AccessSecret.Decrypt()
-			if err != nil {
-				return fs, err
-			}
+		if err := fs.config.AccessSecret.TryDecrypt(); err != nil {
+			return fs, err
 		}
 		awsConfig.Credentials = credentials.NewStaticCredentials(fs.config.AccessKey, fs.config.AccessSecret.GetPayload(), "")
 	}

+ 40 - 4
vfs/sftpfs.go

@@ -44,6 +44,35 @@ type SFTPFsConfig struct {
 	DisableCouncurrentReads bool `json:"disable_concurrent_reads,omitempty"`
 }
 
+func (c *SFTPFsConfig) isEqual(other *SFTPFsConfig) bool {
+	if c.Endpoint != other.Endpoint {
+		return false
+	}
+	if c.Username != other.Username {
+		return false
+	}
+	if c.Prefix != other.Prefix {
+		return false
+	}
+	if c.DisableCouncurrentReads != other.DisableCouncurrentReads {
+		return false
+	}
+	if len(c.Fingerprints) != len(other.Fingerprints) {
+		return false
+	}
+	for _, fp := range c.Fingerprints {
+		if !utils.IsStringInSlice(fp, other.Fingerprints) {
+			return false
+		}
+	}
+	c.setEmptyCredentialsIfNil()
+	other.setEmptyCredentialsIfNil()
+	if !c.Password.IsEqual(other.Password) {
+		return false
+	}
+	return c.PrivateKey.IsEqual(other.PrivateKey)
+}
+
 func (c *SFTPFsConfig) setEmptyCredentialsIfNil() {
 	if c.Password == nil {
 		c.Password = kms.NewEmptySecret()
@@ -123,13 +152,13 @@ func NewSFTPFs(connectionID, mountPath string, config SFTPFsConfig) (Fs, error)
 	if err := config.Validate(); err != nil {
 		return nil, err
 	}
-	if !config.Password.IsEmpty() && config.Password.IsEncrypted() {
-		if err := config.Password.Decrypt(); err != nil {
+	if !config.Password.IsEmpty() {
+		if err := config.Password.TryDecrypt(); err != nil {
 			return nil, err
 		}
 	}
-	if !config.PrivateKey.IsEmpty() && config.PrivateKey.IsEncrypted() {
-		if err := config.PrivateKey.Decrypt(); err != nil {
+	if !config.PrivateKey.IsEmpty() {
+		if err := config.PrivateKey.TryDecrypt(); err != nil {
 			return nil, err
 		}
 	}
@@ -339,6 +368,13 @@ func (*SFTPFs) IsNotSupported(err error) bool {
 
 // CheckRootPath creates the specified local root directory if it does not exists
 func (fs *SFTPFs) CheckRootPath(username string, uid int, gid int) bool {
+	if fs.config.Prefix == "/" {
+		return true
+	}
+	if err := fs.MkdirAll(fs.config.Prefix, uid, gid); err != nil {
+		fsLog(fs, logger.LevelDebug, "error creating root directory %#v for user %#v: %v", fs.config.Prefix, username, err)
+		return false
+	}
 	return true
 }
 

+ 103 - 0
vfs/vfs.go

@@ -147,6 +147,40 @@ type S3FsConfig struct {
 	UploadConcurrency int `json:"upload_concurrency,omitempty"`
 }
 
+func (c *S3FsConfig) isEqual(other *S3FsConfig) bool {
+	if c.Bucket != other.Bucket {
+		return false
+	}
+	if c.KeyPrefix != other.KeyPrefix {
+		return false
+	}
+	if c.Region != other.Region {
+		return false
+	}
+	if c.AccessKey != other.AccessKey {
+		return false
+	}
+	if c.Endpoint != other.Endpoint {
+		return false
+	}
+	if c.StorageClass != other.StorageClass {
+		return false
+	}
+	if c.UploadPartSize != other.UploadPartSize {
+		return false
+	}
+	if c.UploadConcurrency != other.UploadConcurrency {
+		return false
+	}
+	if c.AccessSecret == nil {
+		c.AccessSecret = kms.NewEmptySecret()
+	}
+	if other.AccessSecret == nil {
+		other.AccessSecret = kms.NewEmptySecret()
+	}
+	return c.AccessSecret.IsEqual(other.AccessSecret)
+}
+
 func (c *S3FsConfig) checkCredentials() error {
 	if c.AccessKey == "" && !c.AccessSecret.IsEmpty() {
 		return errors.New("access_key cannot be empty with access_secret not empty")
@@ -224,6 +258,28 @@ type GCSFsConfig struct {
 	StorageClass         string `json:"storage_class,omitempty"`
 }
 
+func (c *GCSFsConfig) isEqual(other *GCSFsConfig) bool {
+	if c.Bucket != other.Bucket {
+		return false
+	}
+	if c.KeyPrefix != other.KeyPrefix {
+		return false
+	}
+	if c.AutomaticCredentials != other.AutomaticCredentials {
+		return false
+	}
+	if c.StorageClass != other.StorageClass {
+		return false
+	}
+	if c.Credentials == nil {
+		c.Credentials = kms.NewEmptySecret()
+	}
+	if other.Credentials == nil {
+		other.Credentials = kms.NewEmptySecret()
+	}
+	return c.Credentials.IsEqual(other.Credentials)
+}
+
 // Validate returns an error if the configuration is not valid
 func (c *GCSFsConfig) Validate(credentialsFilePath string) error {
 	if c.Credentials == nil {
@@ -293,6 +349,43 @@ type AzBlobFsConfig struct {
 	AccessTier string `json:"access_tier,omitempty"`
 }
 
+func (c *AzBlobFsConfig) isEqual(other *AzBlobFsConfig) bool {
+	if c.Container != other.Container {
+		return false
+	}
+	if c.AccountName != other.AccountName {
+		return false
+	}
+	if c.Endpoint != other.Endpoint {
+		return false
+	}
+	if c.SASURL != other.SASURL {
+		return false
+	}
+	if c.KeyPrefix != other.KeyPrefix {
+		return false
+	}
+	if c.UploadPartSize != other.UploadPartSize {
+		return false
+	}
+	if c.UploadConcurrency != other.UploadConcurrency {
+		return false
+	}
+	if c.UseEmulator != other.UseEmulator {
+		return false
+	}
+	if c.AccessTier != other.AccessTier {
+		return false
+	}
+	if c.AccountKey == nil {
+		c.AccountKey = kms.NewEmptySecret()
+	}
+	if other.AccountKey == nil {
+		other.AccountKey = kms.NewEmptySecret()
+	}
+	return c.AccountKey.IsEqual(other.AccountKey)
+}
+
 // EncryptCredentials encrypts access secret if it is in plain text
 func (c *AzBlobFsConfig) EncryptCredentials(additionalData string) error {
 	if c.AccountKey.IsPlain() {
@@ -355,6 +448,16 @@ type CryptFsConfig struct {
 	Passphrase *kms.Secret `json:"passphrase,omitempty"`
 }
 
+func (c *CryptFsConfig) isEqual(other *CryptFsConfig) bool {
+	if c.Passphrase == nil {
+		c.Passphrase = kms.NewEmptySecret()
+	}
+	if other.Passphrase == nil {
+		other.Passphrase = kms.NewEmptySecret()
+	}
+	return c.Passphrase.IsEqual(other.Passphrase)
+}
+
 // EncryptCredentials encrypts access secret if it is in plain text
 func (c *CryptFsConfig) EncryptCredentials(additionalData string) error {
 	if c.Passphrase.IsPlain() {

+ 40 - 31
webdavd/internal_test.go

@@ -895,6 +895,7 @@ func TestBasicUsersCache(t *testing.T) {
 			},
 		},
 	}
+	dataprovider.InitializeWebDAVUserCache(c.Cache.Users.MaxSize)
 	server := webDavServer{
 		config:  c,
 		binding: c.Bindings[0],
@@ -915,10 +916,8 @@ func TestBasicUsersCache(t *testing.T) {
 	assert.False(t, isCached)
 	assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod)
 	// now the user should be cached
-	var cachedUser *dataprovider.CachedUser
-	result, ok := dataprovider.GetCachedWebDAVUser(username)
+	cachedUser, ok := dataprovider.GetCachedWebDAVUser(username)
 	if assert.True(t, ok) {
-		cachedUser = result.(*dataprovider.CachedUser)
 		assert.False(t, cachedUser.IsExpired())
 		assert.True(t, cachedUser.Expiration.After(now.Add(time.Duration(c.Cache.Users.ExpirationTime)*time.Minute)))
 		// authenticate must return the cached user now
@@ -935,10 +934,9 @@ func TestBasicUsersCache(t *testing.T) {
 
 	// force cached user expiration
 	cachedUser.Expiration = now
-	dataprovider.CacheWebDAVUser(cachedUser, c.Cache.Users.MaxSize)
-	result, ok = dataprovider.GetCachedWebDAVUser(username)
+	dataprovider.CacheWebDAVUser(cachedUser)
+	cachedUser, ok = dataprovider.GetCachedWebDAVUser(username)
 	if assert.True(t, ok) {
-		cachedUser = result.(*dataprovider.CachedUser)
 		assert.True(t, cachedUser.IsExpired())
 	}
 	// now authenticate should get the user from the data provider and update the cache
@@ -946,12 +944,24 @@ func TestBasicUsersCache(t *testing.T) {
 	assert.NoError(t, err)
 	assert.False(t, isCached)
 	assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod)
-	result, ok = dataprovider.GetCachedWebDAVUser(username)
+	cachedUser, ok = dataprovider.GetCachedWebDAVUser(username)
 	if assert.True(t, ok) {
-		cachedUser = result.(*dataprovider.CachedUser)
 		assert.False(t, cachedUser.IsExpired())
 	}
-	// cache is invalidated after a user modification
+	// cache is not invalidated after a user modification if the fs does not change
+	err = dataprovider.UpdateUser(&user)
+	assert.NoError(t, err)
+	_, ok = dataprovider.GetCachedWebDAVUser(username)
+	assert.True(t, ok)
+	folderName := "testFolder"
+	user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{
+		BaseVirtualFolder: vfs.BaseVirtualFolder{
+			Name:       folderName,
+			MappedPath: filepath.Join(os.TempDir(), "mapped"),
+		},
+		VirtualPath: "/vdir",
+	})
+
 	err = dataprovider.UpdateUser(&user)
 	assert.NoError(t, err)
 	_, ok = dataprovider.GetCachedWebDAVUser(username)
@@ -969,6 +979,9 @@ func TestBasicUsersCache(t *testing.T) {
 	_, ok = dataprovider.GetCachedWebDAVUser(username)
 	assert.False(t, ok)
 
+	err = dataprovider.DeleteFolder(folderName)
+	assert.NoError(t, err)
+
 	err = os.RemoveAll(u.GetHomeDir())
 	assert.NoError(t, err)
 }
@@ -1011,6 +1024,7 @@ func TestCachedUserWithFolders(t *testing.T) {
 			},
 		},
 	}
+	dataprovider.InitializeWebDAVUserCache(c.Cache.Users.MaxSize)
 	server := webDavServer{
 		config:  c,
 		binding: c.Bindings[0],
@@ -1031,10 +1045,8 @@ func TestCachedUserWithFolders(t *testing.T) {
 	assert.False(t, isCached)
 	assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod)
 	// now the user should be cached
-	var cachedUser *dataprovider.CachedUser
-	result, ok := dataprovider.GetCachedWebDAVUser(username)
+	cachedUser, ok := dataprovider.GetCachedWebDAVUser(username)
 	if assert.True(t, ok) {
-		cachedUser = result.(*dataprovider.CachedUser)
 		assert.False(t, cachedUser.IsExpired())
 		assert.True(t, cachedUser.Expiration.After(now.Add(time.Duration(c.Cache.Users.ExpirationTime)*time.Minute)))
 		// authenticate must return the cached user now
@@ -1054,9 +1066,8 @@ func TestCachedUserWithFolders(t *testing.T) {
 	assert.NoError(t, err)
 	assert.False(t, isCached)
 	assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod)
-	result, ok = dataprovider.GetCachedWebDAVUser(username)
+	cachedUser, ok = dataprovider.GetCachedWebDAVUser(username)
 	if assert.True(t, ok) {
-		cachedUser = result.(*dataprovider.CachedUser)
 		assert.False(t, cachedUser.IsExpired())
 	}
 
@@ -1067,9 +1078,8 @@ func TestCachedUserWithFolders(t *testing.T) {
 	assert.NoError(t, err)
 	assert.False(t, isCached)
 	assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod)
-	result, ok = dataprovider.GetCachedWebDAVUser(username)
+	cachedUser, ok = dataprovider.GetCachedWebDAVUser(username)
 	if assert.True(t, ok) {
-		cachedUser = result.(*dataprovider.CachedUser)
 		assert.False(t, cachedUser.IsExpired())
 	}
 
@@ -1133,6 +1143,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
 			},
 		},
 	}
+	dataprovider.InitializeWebDAVUserCache(c.Cache.Users.MaxSize)
 	server := webDavServer{
 		config:  c,
 		binding: c.Bindings[0],
@@ -1240,6 +1251,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
 	assert.True(t, ok)
 
 	// now remove user1 after an update
+	user1.HomeDir += "_mod"
 	err = dataprovider.UpdateUser(&user1)
 	assert.NoError(t, err)
 	_, ok = dataprovider.GetCachedWebDAVUser(user1.Username)
@@ -1283,6 +1295,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
 }
 
 func TestUserCacheIsolation(t *testing.T) {
+	dataprovider.InitializeWebDAVUserCache(10)
 	username := "webdav_internal_cache_test"
 	password := "dav_pwd"
 	u := dataprovider.User{
@@ -1307,31 +1320,27 @@ func TestUserCacheIsolation(t *testing.T) {
 	cachedUser.User.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("test secret")
 	err = cachedUser.User.FsConfig.S3Config.AccessSecret.Encrypt()
 	assert.NoError(t, err)
-
-	dataprovider.CacheWebDAVUser(cachedUser, 10)
-	result, ok := dataprovider.GetCachedWebDAVUser(username)
+	dataprovider.CacheWebDAVUser(cachedUser)
+	cachedUser, ok := dataprovider.GetCachedWebDAVUser(username)
 
 	if assert.True(t, ok) {
-		cachedUser := result.(*dataprovider.CachedUser).User
-		_, err = cachedUser.GetFilesystem("")
+		_, err = cachedUser.User.GetFilesystem("")
 		assert.NoError(t, err)
 		// the filesystem is now cached
 	}
-	result, ok = dataprovider.GetCachedWebDAVUser(username)
+	cachedUser, ok = dataprovider.GetCachedWebDAVUser(username)
 	if assert.True(t, ok) {
-		cachedUser := result.(*dataprovider.CachedUser).User
-		assert.True(t, cachedUser.FsConfig.S3Config.AccessSecret.IsEncrypted())
-		err = cachedUser.FsConfig.S3Config.AccessSecret.Decrypt()
+		assert.True(t, cachedUser.User.FsConfig.S3Config.AccessSecret.IsEncrypted())
+		err = cachedUser.User.FsConfig.S3Config.AccessSecret.Decrypt()
 		assert.NoError(t, err)
-		cachedUser.FsConfig.Provider = vfs.S3FilesystemProvider
-		_, err = cachedUser.GetFilesystem("")
+		cachedUser.User.FsConfig.Provider = vfs.S3FilesystemProvider
+		_, err = cachedUser.User.GetFilesystem("")
 		assert.Error(t, err, "we don't have to get the previously cached filesystem!")
 	}
-	result, ok = dataprovider.GetCachedWebDAVUser(username)
+	cachedUser, ok = dataprovider.GetCachedWebDAVUser(username)
 	if assert.True(t, ok) {
-		cachedUser := result.(*dataprovider.CachedUser).User
-		assert.Equal(t, vfs.LocalFilesystemProvider, cachedUser.FsConfig.Provider)
-		assert.False(t, cachedUser.FsConfig.S3Config.AccessSecret.IsEncrypted())
+		assert.Equal(t, vfs.LocalFilesystemProvider, cachedUser.User.FsConfig.Provider)
+		assert.False(t, cachedUser.User.FsConfig.S3Config.AccessSecret.IsEncrypted())
 	}
 
 	err = dataprovider.DeleteUser(username)

+ 5 - 4
webdavd/server.go

@@ -171,6 +171,8 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 
 	connectionID, err := s.validateUser(&user, r, loginMethod)
 	if err != nil {
+		// remove the cached user, we have not yet validated its filesystem
+		dataprovider.RemoveCachedWebDAVUser(user.Username)
 		updateLoginMetrics(&user, ipAddr, loginMethod, err)
 		http.Error(w, err.Error(), http.StatusForbidden)
 		return
@@ -246,9 +248,8 @@ func (s *webDavServer) authenticate(r *http.Request, ip string) (dataprovider.Us
 	if !ok {
 		return user, false, nil, loginMethod, err401
 	}
-	result, ok := dataprovider.GetCachedWebDAVUser(username)
+	cachedUser, ok := dataprovider.GetCachedWebDAVUser(username)
 	if ok {
-		cachedUser := result.(*dataprovider.CachedUser)
 		if cachedUser.IsExpired() {
 			dataprovider.RemoveCachedWebDAVUser(username)
 		} else {
@@ -272,7 +273,7 @@ func (s *webDavServer) authenticate(r *http.Request, ip string) (dataprovider.Us
 		return user, false, nil, loginMethod, err
 	}
 	lockSystem := webdav.NewMemLS()
-	cachedUser := &dataprovider.CachedUser{
+	cachedUser = &dataprovider.CachedUser{
 		User:       user,
 		Password:   password,
 		LockSystem: lockSystem,
@@ -280,7 +281,7 @@ func (s *webDavServer) authenticate(r *http.Request, ip string) (dataprovider.Us
 	if s.config.Cache.Users.ExpirationTime > 0 {
 		cachedUser.Expiration = time.Now().Add(time.Duration(s.config.Cache.Users.ExpirationTime) * time.Minute)
 	}
-	dataprovider.CacheWebDAVUser(cachedUser, s.config.Cache.Users.MaxSize)
+	dataprovider.CacheWebDAVUser(cachedUser)
 	return user, false, lockSystem, loginMethod, nil
 }
 

+ 2 - 0
webdavd/webdavd.go

@@ -8,6 +8,7 @@ import (
 	"github.com/go-chi/chi/v5/middleware"
 
 	"github.com/drakkan/sftpgo/common"
+	"github.com/drakkan/sftpgo/dataprovider"
 	"github.com/drakkan/sftpgo/logger"
 	"github.com/drakkan/sftpgo/utils"
 )
@@ -178,6 +179,7 @@ func (c *Configuration) Initialize(configDir string) error {
 		certMgr = mgr
 	}
 	compressor := middleware.NewCompressor(5, "text/*")
+	dataprovider.InitializeWebDAVUserCache(c.Cache.Users.MaxSize)
 
 	serviceStatus = ServiceStatus{
 		Bindings: nil,

+ 6 - 0
webdavd/webdavd_test.go

@@ -808,11 +808,15 @@ func TestPreLoginHook(t *testing.T) {
 	err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, true), os.ModePerm)
 	assert.NoError(t, err)
 	// update the user to remove it from the cache
+	user.FsConfig.Provider = vfs.CryptedFilesystemProvider
+	user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(defaultPassword)
 	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
 	assert.NoError(t, err)
 	client = getWebDavClient(user, true, nil)
 	assert.Error(t, checkBasicFunc(client))
 	// update the user to remove it from the cache
+	user.FsConfig.Provider = vfs.LocalFilesystemProvider
+	user.FsConfig.CryptConfig.Passphrase = nil
 	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
 	assert.NoError(t, err)
 	user.Status = 0
@@ -2037,11 +2041,13 @@ func TestPreLoginHookWithClientCert(t *testing.T) {
 	err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, true), os.ModePerm)
 	assert.NoError(t, err)
 	// update the user to remove it from the cache
+	user.Password = defaultPassword
 	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
 	assert.NoError(t, err)
 	client = getWebDavClient(user, true, tlsConfig)
 	assert.Error(t, checkBasicFunc(client))
 	// update the user to remove it from the cache
+	user.Password = defaultPassword
 	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
 	assert.NoError(t, err)
 	user.Status = 0