| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382 | 
							- package webdavd
 
- import (
 
- 	"context"
 
- 	"crypto/tls"
 
- 	"crypto/x509"
 
- 	"errors"
 
- 	"fmt"
 
- 	"log"
 
- 	"net"
 
- 	"net/http"
 
- 	"path"
 
- 	"path/filepath"
 
- 	"runtime/debug"
 
- 	"time"
 
- 	"github.com/go-chi/chi/v5/middleware"
 
- 	"github.com/rs/cors"
 
- 	"github.com/rs/xid"
 
- 	"golang.org/x/net/webdav"
 
- 	"github.com/drakkan/sftpgo/common"
 
- 	"github.com/drakkan/sftpgo/dataprovider"
 
- 	"github.com/drakkan/sftpgo/logger"
 
- 	"github.com/drakkan/sftpgo/metrics"
 
- 	"github.com/drakkan/sftpgo/utils"
 
- )
 
// err401 is the error returned by authenticate when a request carries no
// usable credentials.
var err401 = errors.New("Unauthorized")
 
- type webDavServer struct {
 
- 	config  *Configuration
 
- 	binding Binding
 
- }
 
- func (s *webDavServer) listenAndServe(compressor *middleware.Compressor) error {
 
- 	handler := compressor.Handler(s)
 
- 	httpServer := &http.Server{
 
- 		ReadHeaderTimeout: 30 * time.Second,
 
- 		ReadTimeout:       60 * time.Second,
 
- 		WriteTimeout:      60 * time.Second,
 
- 		IdleTimeout:       60 * time.Second,
 
- 		MaxHeaderBytes:    1 << 16, // 64KB
 
- 		ErrorLog:          log.New(&logger.StdLoggerWrapper{Sender: logSender}, "", 0),
 
- 	}
 
- 	if s.config.Cors.Enabled {
 
- 		c := cors.New(cors.Options{
 
- 			AllowedOrigins:     s.config.Cors.AllowedOrigins,
 
- 			AllowedMethods:     s.config.Cors.AllowedMethods,
 
- 			AllowedHeaders:     s.config.Cors.AllowedHeaders,
 
- 			ExposedHeaders:     s.config.Cors.ExposedHeaders,
 
- 			MaxAge:             s.config.Cors.MaxAge,
 
- 			AllowCredentials:   s.config.Cors.AllowCredentials,
 
- 			OptionsPassthrough: true,
 
- 		})
 
- 		handler = c.Handler(handler)
 
- 	}
 
- 	httpServer.Handler = handler
 
- 	if certMgr != nil && s.binding.EnableHTTPS {
 
- 		serviceStatus.Bindings = append(serviceStatus.Bindings, s.binding)
 
- 		httpServer.TLSConfig = &tls.Config{
 
- 			GetCertificate:           certMgr.GetCertificateFunc(),
 
- 			MinVersion:               tls.VersionTLS12,
 
- 			CipherSuites:             utils.GetTLSCiphersFromNames(s.binding.TLSCipherSuites),
 
- 			PreferServerCipherSuites: true,
 
- 		}
 
- 		logger.Debug(logSender, "", "configured TLS cipher suites for binding %#v: %v", s.binding.GetAddress(),
 
- 			httpServer.TLSConfig.CipherSuites)
 
- 		if s.binding.isMutualTLSEnabled() {
 
- 			httpServer.TLSConfig.ClientCAs = certMgr.GetRootCAs()
 
- 			httpServer.TLSConfig.VerifyConnection = s.verifyTLSConnection
 
- 			switch s.binding.ClientAuthType {
 
- 			case 1:
 
- 				httpServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
 
- 			case 2:
 
- 				httpServer.TLSConfig.ClientAuth = tls.VerifyClientCertIfGiven
 
- 			}
 
- 		}
 
- 		return utils.HTTPListenAndServe(httpServer, s.binding.Address, s.binding.Port, true, logSender)
 
- 	}
 
- 	s.binding.EnableHTTPS = false
 
- 	serviceStatus.Bindings = append(serviceStatus.Bindings, s.binding)
 
- 	return utils.HTTPListenAndServe(httpServer, s.binding.Address, s.binding.Port, false, logSender)
 
- }
 
- func (s *webDavServer) verifyTLSConnection(state tls.ConnectionState) error {
 
- 	if certMgr != nil {
 
- 		var clientCrt *x509.Certificate
 
- 		var clientCrtName string
 
- 		if len(state.PeerCertificates) > 0 {
 
- 			clientCrt = state.PeerCertificates[0]
 
- 			clientCrtName = clientCrt.Subject.String()
 
- 		}
 
- 		if len(state.VerifiedChains) == 0 {
 
- 			if s.binding.ClientAuthType == 2 {
 
- 				return nil
 
- 			}
 
- 			logger.Warn(logSender, "", "TLS connection cannot be verified: unable to get verification chain")
 
- 			return errors.New("TLS connection cannot be verified: unable to get verification chain")
 
- 		}
 
- 		for _, verifiedChain := range state.VerifiedChains {
 
- 			var caCrt *x509.Certificate
 
- 			if len(verifiedChain) > 0 {
 
- 				caCrt = verifiedChain[len(verifiedChain)-1]
 
- 			}
 
- 			if certMgr.IsRevoked(clientCrt, caCrt) {
 
- 				logger.Debug(logSender, "", "tls handshake error, client certificate %#v has been revoked", clientCrtName)
 
- 				return common.ErrCrtRevoked
 
- 			}
 
- 		}
 
- 	}
 
- 	return nil
 
- }
 
- // returns true if we have to handle a HEAD response, for a directory, ourself
 
- func (s *webDavServer) checkRequestMethod(ctx context.Context, r *http.Request, connection *Connection) bool {
 
- 	// see RFC4918, section 9.4
 
- 	if r.Method == http.MethodGet || r.Method == http.MethodHead {
 
- 		p := path.Clean(r.URL.Path)
 
- 		info, err := connection.Stat(ctx, p)
 
- 		if err == nil && info.IsDir() {
 
- 			if r.Method == http.MethodHead {
 
- 				return true
 
- 			}
 
- 			r.Method = "PROPFIND"
 
- 			if r.Header.Get("Depth") == "" {
 
- 				r.Header.Add("Depth", "1")
 
- 			}
 
- 		}
 
- 	}
 
- 	return false
 
- }
 
- // ServeHTTP implements the http.Handler interface
 
- func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 
- 	defer func() {
 
- 		if r := recover(); r != nil {
 
- 			logger.Error(logSender, "", "panic in ServeHTTP: %#v stack strace: %v", r, string(debug.Stack()))
 
- 			http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError)
 
- 		}
 
- 	}()
 
- 	ipAddr := s.checkRemoteAddress(r)
 
- 	common.Connections.AddClientConnection(ipAddr)
 
- 	defer common.Connections.RemoveClientConnection(ipAddr)
 
- 	if !common.Connections.IsNewConnectionAllowed(ipAddr) {
 
- 		logger.Log(logger.LevelDebug, common.ProtocolWebDAV, "", "connection refused, configured limit reached")
 
- 		http.Error(w, common.ErrConnectionDenied.Error(), http.StatusServiceUnavailable)
 
- 		return
 
- 	}
 
- 	if common.IsBanned(ipAddr) {
 
- 		http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
 
- 		return
 
- 	}
 
- 	delay, err := common.LimitRate(common.ProtocolWebDAV, ipAddr)
 
- 	if err != nil {
 
- 		delay += 499999999 * time.Nanosecond
 
- 		w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds()))
 
- 		w.Header().Set("X-Retry-In", delay.String())
 
- 		http.Error(w, err.Error(), http.StatusTooManyRequests)
 
- 		return
 
- 	}
 
- 	if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolWebDAV); err != nil {
 
- 		http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
 
- 		return
 
- 	}
 
- 	user, isCached, lockSystem, loginMethod, err := s.authenticate(r, ipAddr)
 
- 	if err != nil {
 
- 		w.Header().Set("WWW-Authenticate", "Basic realm=\"SFTPGo WebDAV\"")
 
- 		http.Error(w, fmt.Sprintf("Authentication error: %v", err), http.StatusUnauthorized)
 
- 		return
 
- 	}
 
- 	connectionID, err := s.validateUser(&user, r, loginMethod)
 
- 	if err != nil {
 
- 		// remove the cached user, we have not yet validated its filesystem
 
- 		dataprovider.RemoveCachedWebDAVUser(user.Username)
 
- 		updateLoginMetrics(&user, ipAddr, loginMethod, err)
 
- 		http.Error(w, err.Error(), http.StatusForbidden)
 
- 		return
 
- 	}
 
- 	if !isCached {
 
- 		err = user.CheckFsRoot(connectionID)
 
- 	} else {
 
- 		_, err = user.GetFilesystem(connectionID)
 
- 	}
 
- 	if err != nil {
 
- 		errClose := user.CloseFs()
 
- 		logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose)
 
- 		updateLoginMetrics(&user, ipAddr, loginMethod, err)
 
- 		http.Error(w, err.Error(), http.StatusInternalServerError)
 
- 		return
 
- 	}
 
- 	updateLoginMetrics(&user, ipAddr, loginMethod, err)
 
- 	ctx := context.WithValue(r.Context(), requestIDKey, connectionID)
 
- 	ctx = context.WithValue(ctx, requestStartKey, time.Now())
 
- 	connection := &Connection{
 
- 		BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolWebDAV, r.RemoteAddr, user),
 
- 		request:        r,
 
- 	}
 
- 	common.Connections.Add(connection)
 
- 	defer common.Connections.Remove(connection.GetID())
 
- 	dataprovider.UpdateLastLogin(&user) //nolint:errcheck
 
- 	if s.checkRequestMethod(ctx, r, connection) {
 
- 		w.Header().Set("Content-Type", "text/xml; charset=utf-8")
 
- 		w.WriteHeader(http.StatusMultiStatus)
 
- 		w.Write([]byte("")) //nolint:errcheck
 
- 		return
 
- 	}
 
- 	handler := webdav.Handler{
 
- 		Prefix:     s.binding.Prefix,
 
- 		FileSystem: connection,
 
- 		LockSystem: lockSystem,
 
- 		Logger:     writeLog,
 
- 	}
 
- 	handler.ServeHTTP(w, r.WithContext(ctx))
 
- }
 
- func (s *webDavServer) getCredentialsAndLoginMethod(r *http.Request) (string, string, string, *x509.Certificate, bool) {
 
- 	var tlsCert *x509.Certificate
 
- 	loginMethod := dataprovider.LoginMethodPassword
 
- 	username, password, ok := r.BasicAuth()
 
- 	if s.binding.isMutualTLSEnabled() && r.TLS != nil {
 
- 		if len(r.TLS.PeerCertificates) > 0 {
 
- 			tlsCert = r.TLS.PeerCertificates[0]
 
- 			if ok {
 
- 				loginMethod = dataprovider.LoginMethodTLSCertificateAndPwd
 
- 			} else {
 
- 				loginMethod = dataprovider.LoginMethodTLSCertificate
 
- 				username = tlsCert.Subject.CommonName
 
- 				password = ""
 
- 			}
 
- 			ok = true
 
- 		}
 
- 	}
 
- 	return username, password, loginMethod, tlsCert, ok
 
- }
 
- func (s *webDavServer) authenticate(r *http.Request, ip string) (dataprovider.User, bool, webdav.LockSystem, string, error) {
 
- 	var user dataprovider.User
 
- 	var err error
 
- 	username, password, loginMethod, tlsCert, ok := s.getCredentialsAndLoginMethod(r)
 
- 	if !ok {
 
- 		return user, false, nil, loginMethod, err401
 
- 	}
 
- 	cachedUser, ok := dataprovider.GetCachedWebDAVUser(username)
 
- 	if ok {
 
- 		if cachedUser.IsExpired() {
 
- 			dataprovider.RemoveCachedWebDAVUser(username)
 
- 		} else {
 
- 			if !cachedUser.User.IsTLSUsernameVerificationEnabled() {
 
- 				// for backward compatibility with 2.0.x we only check the password
 
- 				tlsCert = nil
 
- 				loginMethod = dataprovider.LoginMethodPassword
 
- 			}
 
- 			if err := dataprovider.CheckCachedUserCredentials(cachedUser, password, loginMethod, common.ProtocolWebDAV, tlsCert); err == nil {
 
- 				return cachedUser.User, true, cachedUser.LockSystem, loginMethod, nil
 
- 			}
 
- 			updateLoginMetrics(&cachedUser.User, ip, loginMethod, dataprovider.ErrInvalidCredentials)
 
- 			return user, false, nil, loginMethod, dataprovider.ErrInvalidCredentials
 
- 		}
 
- 	}
 
- 	user, loginMethod, err = dataprovider.CheckCompositeCredentials(username, password, ip, loginMethod,
 
- 		common.ProtocolWebDAV, tlsCert)
 
- 	if err != nil {
 
- 		user.Username = username
 
- 		updateLoginMetrics(&user, ip, loginMethod, err)
 
- 		return user, false, nil, loginMethod, dataprovider.ErrInvalidCredentials
 
- 	}
 
- 	lockSystem := webdav.NewMemLS()
 
- 	cachedUser = &dataprovider.CachedUser{
 
- 		User:       user,
 
- 		Password:   password,
 
- 		LockSystem: lockSystem,
 
- 	}
 
- 	if s.config.Cache.Users.ExpirationTime > 0 {
 
- 		cachedUser.Expiration = time.Now().Add(time.Duration(s.config.Cache.Users.ExpirationTime) * time.Minute)
 
- 	}
 
- 	dataprovider.CacheWebDAVUser(cachedUser)
 
- 	return user, false, lockSystem, loginMethod, nil
 
- }
 
- func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, loginMethod string) (string, error) {
 
- 	connID := xid.New().String()
 
- 	connectionID := fmt.Sprintf("%v_%v", common.ProtocolWebDAV, connID)
 
- 	if !filepath.IsAbs(user.HomeDir) {
 
- 		logger.Warn(logSender, connectionID, "user %#v has an invalid home dir: %#v. Home dir must be an absolute path, login not allowed",
 
- 			user.Username, user.HomeDir)
 
- 		return connID, fmt.Errorf("cannot login user with invalid home dir: %#v", user.HomeDir)
 
- 	}
 
- 	if utils.IsStringInSlice(common.ProtocolWebDAV, user.Filters.DeniedProtocols) {
 
- 		logger.Debug(logSender, connectionID, "cannot login user %#v, protocol DAV is not allowed", user.Username)
 
- 		return connID, fmt.Errorf("protocol DAV is not allowed for user %#v", user.Username)
 
- 	}
 
- 	if !user.IsLoginMethodAllowed(loginMethod, nil) {
 
- 		logger.Debug(logSender, connectionID, "cannot login user %#v, %v login method is not allowed", user.Username, loginMethod)
 
- 		return connID, fmt.Errorf("login method %v is not allowed for user %#v", loginMethod, user.Username)
 
- 	}
 
- 	if user.MaxSessions > 0 {
 
- 		activeSessions := common.Connections.GetActiveSessions(user.Username)
 
- 		if activeSessions >= user.MaxSessions {
 
- 			logger.Debug(logSender, connID, "authentication refused for user: %#v, too many open sessions: %v/%v", user.Username,
 
- 				activeSessions, user.MaxSessions)
 
- 			return connID, fmt.Errorf("too many open sessions: %v", activeSessions)
 
- 		}
 
- 	}
 
- 	if !user.IsLoginFromAddrAllowed(r.RemoteAddr) {
 
- 		logger.Debug(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v", user.Username, r.RemoteAddr)
 
- 		return connID, fmt.Errorf("login for user %#v is not allowed from this address: %v", user.Username, r.RemoteAddr)
 
- 	}
 
- 	return connID, nil
 
- }
 
- func (s *webDavServer) checkRemoteAddress(r *http.Request) string {
 
- 	ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
 
- 	ip := net.ParseIP(ipAddr)
 
- 	if ip != nil {
 
- 		for _, allow := range s.binding.allowHeadersFrom {
 
- 			if allow(ip) {
 
- 				parsedIP := utils.GetRealIP(r)
 
- 				if parsedIP != "" {
 
- 					ipAddr = parsedIP
 
- 					r.RemoteAddr = ipAddr
 
- 				}
 
- 				break
 
- 			}
 
- 		}
 
- 	}
 
- 	return ipAddr
 
- }
 
- func writeLog(r *http.Request, err error) {
 
- 	scheme := "http"
 
- 	if r.TLS != nil {
 
- 		scheme = "https"
 
- 	}
 
- 	fields := map[string]interface{}{
 
- 		"remote_addr": r.RemoteAddr,
 
- 		"proto":       r.Proto,
 
- 		"method":      r.Method,
 
- 		"user_agent":  r.UserAgent(),
 
- 		"uri":         fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI)}
 
- 	if reqID, ok := r.Context().Value(requestIDKey).(string); ok {
 
- 		fields["request_id"] = reqID
 
- 	}
 
- 	if reqStart, ok := r.Context().Value(requestStartKey).(time.Time); ok {
 
- 		fields["elapsed_ms"] = time.Since(reqStart).Nanoseconds() / 1000000
 
- 	}
 
- 	logger.GetLogger().Info().
 
- 		Timestamp().
 
- 		Str("sender", logSender).
 
- 		Fields(fields).
 
- 		Err(err).
 
- 		Send()
 
- }
 
- func updateLoginMetrics(user *dataprovider.User, ip, loginMethod string, err error) {
 
- 	metrics.AddLoginAttempt(loginMethod)
 
- 	if err != nil {
 
- 		logger.ConnectionFailedLog(user.Username, ip, loginMethod, common.ProtocolWebDAV, err.Error())
 
- 		event := common.HostEventLoginFailed
 
- 		if _, ok := err.(*dataprovider.RecordNotFoundError); ok {
 
- 			event = common.HostEventUserNotFound
 
- 		}
 
- 		common.AddDefenderEvent(ip, event)
 
- 	}
 
- 	metrics.AddLoginResult(loginMethod, err)
 
- 	dataprovider.ExecutePostLoginHook(user, loginMethod, ip, common.ProtocolWebDAV, err)
 
- }
 
 
  |