Browse Source

httpd/webdav: use a custom listener with read and write deadlines

Nicola Murino 4 years ago
parent
commit
4ea4202b99

+ 5 - 5
.github/workflows/development.yml

@@ -57,11 +57,11 @@ jobs:
       - name: Run test cases using bolt provider
         run: |
           go test -v -p 1 -timeout 2m ./config -covermode=atomic
-          go test -v -p 1 -timeout 2m ./common -covermode=atomic
-          go test -v -p 1 -timeout 3m ./httpd -covermode=atomic
+          go test -v -p 1 -timeout 5m ./common -covermode=atomic
+          go test -v -p 1 -timeout 5m ./httpd -covermode=atomic
           go test -v -p 1 -timeout 8m ./sftpd -covermode=atomic
-          go test -v -p 1 -timeout 2m ./ftpd -covermode=atomic
-          go test -v -p 1 -timeout 2m ./webdavd -covermode=atomic
+          go test -v -p 1 -timeout 5m ./ftpd -covermode=atomic
+          go test -v -p 1 -timeout 5m ./webdavd -covermode=atomic
           go test -v -p 1 -timeout 2m ./telemetry -covermode=atomic
         env:
           SFTPGO_DATA_PROVIDER__DRIVER: bolt
@@ -302,7 +302,7 @@ jobs:
         with:
           go-version: 1.16
       - name: Run golangci-lint
-        uses: golangci/golangci-lint-action@v2.5.2
+        uses: golangci/golangci-lint-action@v2
         with:
           version: latest
           skip-go-installation: true

+ 2 - 2
dataprovider/user.go

@@ -853,9 +853,9 @@ func (u *User) isFilePatternAllowed(virtualPath string) bool {
 	return true
 }
 
-// CanManahePublicKeys return true if this user is allowed to manage public keys
+// CanManagePublicKeys return true if this user is allowed to manage public keys
 // from the web client
-func (u *User) CanManahePublicKeys() bool {
+func (u *User) CanManagePublicKeys() bool {
 	return !utils.IsStringInSlice(WebClientPubKeyChangeDisabled, u.Filters.WebClient)
 }
 

+ 4 - 3
httpd/httpd.go

@@ -79,9 +79,10 @@ const (
 	webClientLogoutPathDefault      = "/web/client/logout"
 	webStaticFilesPathDefault       = "/static"
 	// MaxRestoreSize defines the max size for the loaddata input file
-	MaxRestoreSize = 10485760 // 10 MB
-	maxRequestSize = 1048576  // 1MB
-	osWindows      = "windows"
+	MaxRestoreSize   = 10485760 // 10 MB
+	maxRequestSize   = 1048576  // 1MB
+	maxLoginPostSize = 262144   // 256 KB
+	osWindows        = "windows"
 )
 
 var (

+ 5 - 7
httpd/server.go

@@ -48,14 +48,12 @@ func (s *httpdServer) listenAndServe() error {
 	httpServer := &http.Server{
 		Handler:           s.router,
 		ReadHeaderTimeout: 30 * time.Second,
-		IdleTimeout:       120 * time.Second,
+		ReadTimeout:       60 * time.Second,
+		WriteTimeout:      60 * time.Second,
+		IdleTimeout:       60 * time.Second,
 		MaxHeaderBytes:    1 << 16, // 64KB
 		ErrorLog:          log.New(&logger.StdLoggerWrapper{Sender: logSender}, "", 0),
 	}
-	if !s.binding.EnableWebClient {
-		httpServer.ReadTimeout = 60 * time.Second
-		httpServer.WriteTimeout = 90 * time.Second
-	}
 	if certMgr != nil && s.binding.EnableHTTPS {
 		config := &tls.Config{
 			GetCertificate:           certMgr.GetCertificateFunc(),
@@ -111,7 +109,7 @@ func (s *httpdServer) refreshCookie(next http.Handler) http.Handler {
 }
 
 func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) {
-	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
+	r.Body = http.MaxBytesReader(w, r.Body, maxLoginPostSize)
 	common.Connections.AddNetworkConnection()
 	defer common.Connections.RemoveNetworkConnection()
 
@@ -185,7 +183,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
 }
 
 func (s *httpdServer) handleWebAdminLoginPost(w http.ResponseWriter, r *http.Request) {
-	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
+	r.Body = http.MaxBytesReader(w, r.Body, maxLoginPostSize)
 	if err := r.ParseForm(); err != nil {
 		renderLoginPage(w, err.Error())
 		return

+ 1 - 7
httpd/webclient.go

@@ -582,10 +582,7 @@ func parseRangeRequest(bytesRange string, size int64) (int64, int64, error) {
 			// we have something like -500
 			start = size - end
 			size = end
-			// this can't happen, we did end = size -1 above
-			/*if start < 0 {
-				return 0, 0, fmt.Errorf("unacceptable range %#v", bytesRange)
-			}*/
+			// start cannit be < 0 here, we did end = size -1 above
 		} else {
 			// we have something like 500-600
 			size = end - start + 1
@@ -595,9 +592,6 @@ func parseRangeRequest(bytesRange string, size int64) (int64, int64, error) {
 		}
 		return start, size, nil
 	}
-	/*if start == -1 {
-		return 0, 0, fmt.Errorf("unacceptable range %#v", bytesRange)
-	}*/
 	// we have something like 500-
 	size -= start
 	if size < 0 {

+ 7 - 6
telemetry/telemetry.go

@@ -89,12 +89,13 @@ func (c Conf) Initialize(configDir string) error {
 	certificateKeyFile := getConfigPath(c.CertificateKeyFile, configDir)
 	initializeRouter(c.EnableProfiler)
 	httpServer := &http.Server{
-		Handler:        router,
-		ReadTimeout:    60 * time.Second,
-		WriteTimeout:   60 * time.Second,
-		IdleTimeout:    120 * time.Second,
-		MaxHeaderBytes: 1 << 14, // 16KB
-		ErrorLog:       log.New(&logger.StdLoggerWrapper{Sender: logSender}, "", 0),
+		Handler:           router,
+		ReadHeaderTimeout: 30 * time.Second,
+		ReadTimeout:       60 * time.Second,
+		WriteTimeout:      60 * time.Second,
+		IdleTimeout:       60 * time.Second,
+		MaxHeaderBytes:    1 << 14, // 16KB
+		ErrorLog:          log.New(&logger.StdLoggerWrapper{Sender: logSender}, "", 0),
 	}
 	if certificateFile != "" && certificateKeyFile != "" {
 		certMgr, err = common.NewCertManager(certificateFile, certificateKeyFile, configDir, logSender)

+ 1 - 1
templates/webclient/credentials.html

@@ -41,7 +41,7 @@
         </form>
     </div>
 </div>
-{{if .LoggedUser.CanManahePublicKeys}}
+{{if .LoggedUser.CanManagePublicKeys}}
 <div class="card shadow mb-4">
     <div class="card-header py-3">
         <h6 class="m-0 font-weight-bold text-primary">Manage public keys</h6>

+ 1 - 1
templates/webclient/files.html

@@ -94,7 +94,7 @@
             "scrollY": false,
             "responsive": true,
             "language": {
-                "emptyTable": "You have no files or folders"
+                "emptyTable": "No files or folders"
             },
             /*"select": {
                 "style":    'single',

+ 86 - 0
utils/timeoutlistener.go

@@ -0,0 +1,86 @@
+package utils
+
+import (
+	"net"
+	"sync/atomic"
+	"time"
+)
+
+type listener struct {
+	net.Listener
+	ReadTimeout  time.Duration
+	WriteTimeout time.Duration
+}
+
+func (l *listener) Accept() (net.Conn, error) {
+	c, err := l.Listener.Accept()
+	if err != nil {
+		return nil, err
+	}
+	tc := &Conn{
+		Conn:                     c,
+		ReadTimeout:              l.ReadTimeout,
+		WriteTimeout:             l.WriteTimeout,
+		ReadThreshold:            int32((l.ReadTimeout * 1024) / time.Second),
+		WriteThreshold:           int32((l.WriteTimeout * 1024) / time.Second),
+		BytesReadFromDeadline:    0,
+		BytesWrittenFromDeadline: 0,
+	}
+	return tc, nil
+}
+
+// Conn wraps a net.Conn, and sets a deadline for every read
+// and write operation.
+type Conn struct {
+	net.Conn
+	ReadTimeout              time.Duration
+	WriteTimeout             time.Duration
+	ReadThreshold            int32
+	WriteThreshold           int32
+	BytesReadFromDeadline    int32
+	BytesWrittenFromDeadline int32
+}
+
+func (c *Conn) Read(b []byte) (n int, err error) {
+	if atomic.LoadInt32(&c.BytesReadFromDeadline) > c.ReadThreshold {
+		atomic.StoreInt32(&c.BytesReadFromDeadline, 0)
+		// we set both read and write deadlines here otherwise after the request
+		// is read writing the response fails with an i/o timeout error
+		err = c.Conn.SetDeadline(time.Now().Add(c.ReadTimeout))
+		if err != nil {
+			return 0, err
+		}
+	}
+	n, err = c.Conn.Read(b)
+	atomic.AddInt32(&c.BytesReadFromDeadline, int32(n))
+	return
+}
+
+func (c *Conn) Write(b []byte) (n int, err error) {
+	if atomic.LoadInt32(&c.BytesWrittenFromDeadline) > c.WriteThreshold {
+		atomic.StoreInt32(&c.BytesWrittenFromDeadline, 0)
+		// we extend the read deadline too, not sure it's necessary,
+		// but it doesn't hurt
+		err = c.Conn.SetDeadline(time.Now().Add(c.WriteTimeout))
+		if err != nil {
+			return
+		}
+	}
+	n, err = c.Conn.Write(b)
+	atomic.AddInt32(&c.BytesWrittenFromDeadline, int32(n))
+	return
+}
+
+func newListener(network, addr string, readTimeout, writeTimeout time.Duration) (net.Listener, error) {
+	l, err := net.Listen(network, addr)
+	if err != nil {
+		return nil, err
+	}
+
+	tl := &listener{
+		Listener:     l,
+		ReadTimeout:  readTimeout,
+		WriteTimeout: writeTimeout,
+	}
+	return tl, nil
+}

+ 2 - 3
utils/utils.go

@@ -426,11 +426,10 @@ func HTTPListenAndServe(srv *http.Server, address string, port int, isTLS bool,
 			logger.Error(logSender, "", "error creating Unix-domain socket parent dir: %v", err)
 		}
 		os.Remove(address)
-
-		listener, err = net.Listen("unix", address)
+		listener, err = newListener("unix", address, srv.ReadTimeout, srv.WriteTimeout)
 	} else {
 		CheckTCP4Port(port)
-		listener, err = net.Listen("tcp", fmt.Sprintf("%s:%d", address, port))
+		listener, err = newListener("tcp", fmt.Sprintf("%s:%d", address, port), srv.ReadTimeout, srv.WriteTimeout)
 	}
 	if err != nil {
 		return err

+ 5 - 8
webdavd/server.go

@@ -40,9 +40,10 @@ type webDavServer struct {
 func (s *webDavServer) listenAndServe(compressor *middleware.Compressor) error {
 	handler := compressor.Handler(s)
 	httpServer := &http.Server{
-		Addr:              s.binding.GetAddress(),
 		ReadHeaderTimeout: 30 * time.Second,
-		IdleTimeout:       120 * time.Second,
+		ReadTimeout:       60 * time.Second,
+		WriteTimeout:      60 * time.Second,
+		IdleTimeout:       60 * time.Second,
 		MaxHeaderBytes:    1 << 16, // 64KB
 		ErrorLog:          log.New(&logger.StdLoggerWrapper{Sender: logSender}, "", 0),
 	}
@@ -79,15 +80,11 @@ func (s *webDavServer) listenAndServe(compressor *middleware.Compressor) error {
 				httpServer.TLSConfig.ClientAuth = tls.VerifyClientCertIfGiven
 			}
 		}
-		logger.Info(logSender, "", "starting HTTPS serving, binding: %v", s.binding.GetAddress())
-		utils.CheckTCP4Port(s.binding.Port)
-		return httpServer.ListenAndServeTLS("", "")
+		return utils.HTTPListenAndServe(httpServer, s.binding.Address, s.binding.Port, true, logSender)
 	}
 	s.binding.EnableHTTPS = false
 	serviceStatus.Bindings = append(serviceStatus.Bindings, s.binding)
-	logger.Info(logSender, "", "starting HTTP serving, binding: %v", s.binding.GetAddress())
-	utils.CheckTCP4Port(s.binding.Port)
-	return httpServer.ListenAndServe()
+	return utils.HTTPListenAndServe(httpServer, s.binding.Address, s.binding.Port, false, logSender)
 }
 
 func (s *webDavServer) verifyTLSConnection(state tls.ConnectionState) error {

+ 14 - 5
webdavd/webdavd_test.go

@@ -2348,6 +2348,17 @@ func checkBasicFunc(client *gowebdav.Client) error {
 	return err
 }
 
+func checkFileSize(remoteDestPath string, expectedSize int64, client *gowebdav.Client) error {
+	info, err := client.Stat(remoteDestPath)
+	if err != nil {
+		return err
+	}
+	if info.Size() != expectedSize {
+		return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", info.Size(), expectedSize)
+	}
+	return nil
+}
+
 func uploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *gowebdav.Client) error {
 	srcFile, err := os.Open(localSourcePath)
 	if err != nil {
@@ -2359,12 +2370,10 @@ func uploadFile(localSourcePath string, remoteDestPath string, expectedSize int6
 		return err
 	}
 	if expectedSize > 0 {
-		info, err := client.Stat(remoteDestPath)
+		err = checkFileSize(remoteDestPath, expectedSize, client)
 		if err != nil {
-			return err
-		}
-		if info.Size() != expectedSize {
-			return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", info.Size(), expectedSize)
+			time.Sleep(1 * time.Second)
+			return checkFileSize(remoteDestPath, expectedSize, client)
 		}
 	}
 	return nil