Sfoglia il codice sorgente

HTTPD, WebDAV: use http.ResponseController

backport from Enterprise edition

Signed-off-by: Nicola Murino <[email protected]>
Nicola Murino 2 mesi fa
parent
commit
ddbe40cefa

+ 2 - 5
internal/httpd/api_http_user.go

@@ -53,11 +53,8 @@ func getUserConnection(w http.ResponseWriter, r *http.Request) (*Connection, err
 		sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
 		return nil, err
 	}
-	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r),
-			r.RemoteAddr, user),
-		request: r,
-	}
+	baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user)
+	connection := newConnection(baseConn, w, r)
 	if err = common.Connections.Add(connection); err != nil {
 		sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
 		return connection, err

+ 2 - 5
internal/httpd/api_shares.go

@@ -532,11 +532,8 @@ func (s *httpdServer) checkPublicShare(w http.ResponseWriter, r *http.Request, v
 		return share, nil, err
 	}
 	connID := xid.New().String()
-	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(connID, common.ProtocolHTTPShare, util.GetHTTPLocalAddress(r),
-			r.RemoteAddr, user),
-		request: r,
-	}
+	baseConn := common.NewBaseConnection(connID, common.ProtocolHTTPShare, util.GetHTTPLocalAddress(r), r.RemoteAddr, user)
+	connection := newConnection(baseConn, w, r)
 
 	return share, connection, nil
 }

+ 9 - 0
internal/httpd/api_utils.go

@@ -946,3 +946,12 @@ func hideConfidentialData(claims *jwtTokenClaims, r *http.Request) bool {
 	}
 	return r.URL.Query().Get("confidential_data") != "1"
 }
+
+func responseControllerDeadlines(rc *http.ResponseController, read, write time.Time) {
+	if err := rc.SetReadDeadline(read); err != nil {
+		logger.Error(logSender, "", "unable to set read timeout to %s: %v", read, err)
+	}
+	if err := rc.SetWriteDeadline(write); err != nil {
+		logger.Error(logSender, "", "unable to set write timeout to %s: %v", write, err)
+	}
+}

+ 14 - 0
internal/httpd/handler.go

@@ -35,6 +35,17 @@ import (
 type Connection struct {
 	*common.BaseConnection
 	request *http.Request
+	rc      *http.ResponseController
+}
+
+func newConnection(conn *common.BaseConnection, w http.ResponseWriter, r *http.Request) *Connection {
+	rc := http.NewResponseController(w)
+	responseControllerDeadlines(rc, time.Time{}, time.Time{})
+	return &Connection{
+		BaseConnection: conn,
+		request:        r,
+		rc:             rc,
+	}
 }
 
 // GetClientVersion returns the connected client's version.
@@ -60,6 +71,9 @@ func (c *Connection) GetRemoteAddress() string {
 
 // Disconnect closes the active transfer
 func (c *Connection) Disconnect() (err error) {
+	if c.rc != nil {
+		responseControllerDeadlines(c.rc, time.Now().Add(5*time.Second), time.Now().Add(5*time.Second))
+	}
 	return c.SignalTransfersAbort()
 }
 

+ 30 - 23
internal/httpd/internal_test.go

@@ -2686,10 +2686,11 @@ func TestCompressorAbortHandler(t *testing.T) {
 		assert.Equal(t, http.ErrAbortHandler, rcv)
 	}()
 
-	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", dataprovider.User{}),
-		request:        nil,
-	}
+	connection := newConnection(
+		common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", dataprovider.User{}),
+		nil,
+		nil,
+	)
 	share := &dataprovider.Share{}
 	renderCompressedFiles(&failingWriter{}, connection, "", nil, share)
 }
@@ -2711,10 +2712,11 @@ func TestZipErrors(t *testing.T) {
 	}
 	user.Permissions = make(map[string][]string)
 	user.Permissions["/"] = []string{dataprovider.PermAny}
-	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user),
-		request:        nil,
-	}
+	connection := newConnection(
+		common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user),
+		nil,
+		nil,
+	)
 
 	testDir := filepath.Join(os.TempDir(), "testDir")
 	err := os.MkdirAll(testDir, os.ModePerm)
@@ -2935,10 +2937,11 @@ func TestConnection(t *testing.T) {
 	}
 	user.Permissions = make(map[string][]string)
 	user.Permissions["/"] = []string{dataprovider.PermAny}
-	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user),
-		request:        nil,
-	}
+	connection := newConnection(
+		common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user),
+		nil,
+		nil,
+	)
 	assert.Empty(t, connection.GetClientVersion())
 	assert.Empty(t, connection.GetRemoteAddress())
 	assert.Empty(t, connection.GetCommand())
@@ -2959,10 +2962,11 @@ func TestGetFileWriterErrors(t *testing.T) {
 	}
 	user.Permissions = make(map[string][]string)
 	user.Permissions["/"] = []string{dataprovider.PermAny}
-	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user),
-		request:        nil,
-	}
+	connection := newConnection(
+		common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user),
+		nil,
+		nil,
+	)
 	_, err := connection.getFileWriter("name")
 	assert.Error(t, err)
 
@@ -2975,10 +2979,11 @@ func TestGetFileWriterErrors(t *testing.T) {
 		},
 		AccessSecret: kms.NewPlainSecret("secret"),
 	}
-	connection = &Connection{
-		BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user),
-		request:        nil,
-	}
+	connection = newConnection(
+		common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user),
+		nil,
+		nil,
+	)
 	_, err = connection.getFileWriter("/path")
 	assert.Error(t, err)
 }
@@ -3007,9 +3012,11 @@ func TestHTTPDFile(t *testing.T) {
 	}
 	user.Permissions = make(map[string][]string)
 	user.Permissions["/"] = []string{dataprovider.PermAny}
-	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user),
-	}
+	connection := newConnection(
+		common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user),
+		nil,
+		nil,
+	)
 
 	fs, err := user.GetFilesystem("")
 	assert.NoError(t, err)

+ 5 - 2
internal/httpd/server.go

@@ -103,8 +103,6 @@ func (s *httpdServer) listenAndServe() error {
 	httpServer := &http.Server{
 		Handler:           s.router,
 		ReadHeaderTimeout: 30 * time.Second,
-		ReadTimeout:       60 * time.Second,
-		WriteTimeout:      60 * time.Second,
 		IdleTimeout:       60 * time.Second,
 		MaxHeaderBytes:    1 << 16, // 64KB
 		ErrorLog:          log.New(&logger.StdLoggerWrapper{Sender: logSender}, "", 0),
@@ -1087,6 +1085,11 @@ func (s *httpdServer) updateContextFromCookie(r *http.Request) *http.Request {
 
 func (s *httpdServer) parseHeaders(next http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		responseControllerDeadlines(
+			http.NewResponseController(w),
+			time.Now().Add(60*time.Second),
+			time.Now().Add(60*time.Second),
+		)
 		w.Header().Set("Server", version.GetServerVersion("/", false))
 		ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
 		var ip net.IP

+ 10 - 25
internal/httpd/webclient.go

@@ -906,11 +906,8 @@ func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http.
 		s.renderClientForbiddenPage(w, r, err)
 		return
 	}
-	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r),
-			r.RemoteAddr, user),
-		request: r,
-	}
+	baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user)
+	connection := newConnection(baseConn, w, r)
 	if err = common.Connections.Add(connection); err != nil {
 		s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests,
 			util.NewI18nError(err, util.I18nError429Message), "")
@@ -1197,11 +1194,8 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http.
 		sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nErrorDirList403), http.StatusForbidden)
 		return
 	}
-	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r),
-			r.RemoteAddr, user),
-		request: r,
-	}
+	baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user)
+	connection := newConnection(baseConn, w, r)
 	if err = common.Connections.Add(connection); err != nil {
 		sendAPIResponse(w, r, err, util.I18nErrorDirList429, http.StatusTooManyRequests)
 		return
@@ -1287,11 +1281,8 @@ func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Reques
 		s.renderClientForbiddenPage(w, r, err)
 		return
 	}
-	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r),
-			r.RemoteAddr, user),
-		request: r,
-	}
+	baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user)
+	connection := newConnection(baseConn, w, r)
 	if err = common.Connections.Add(connection); err != nil {
 		s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests,
 			util.NewI18nError(err, util.I18nError429Message), "")
@@ -1348,11 +1339,8 @@ func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Reques
 		s.renderClientForbiddenPage(w, r, err)
 		return
 	}
-	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r),
-			r.RemoteAddr, user),
-		request: r,
-	}
+	baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user)
+	connection := newConnection(baseConn, w, r)
 	if err = common.Connections.Add(connection); err != nil {
 		s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests,
 			util.NewI18nError(err, util.I18nError429Message), "")
@@ -1844,11 +1832,8 @@ func (s *httpdServer) handleClientGetPDF(w http.ResponseWriter, r *http.Request)
 		s.renderClientForbiddenPage(w, r, err)
 		return
 	}
-	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r),
-			r.RemoteAddr, user),
-		request: r,
-	}
+	baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user)
+	connection := newConnection(baseConn, w, r)
 	if err = common.Connections.Add(connection); err != nil {
 		s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests,
 			util.NewI18nError(err, util.I18nError429Message), "")

+ 0 - 100
internal/util/timeoutlistener.go

@@ -1,100 +0,0 @@
-// Copyright (C) 2019 Nicola Murino
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Affero General Public License as published
-// by the Free Software Foundation, version 3.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Affero General Public License for more details.
-//
-// You should have received a copy of the GNU Affero General Public License
-// along with this program. If not, see <https://www.gnu.org/licenses/>.
-
-package util
-
-import (
-	"net"
-	"sync/atomic"
-	"time"
-)
-
-type listener struct {
-	net.Listener
-	ReadTimeout  time.Duration
-	WriteTimeout time.Duration
-}
-
-func (l *listener) Accept() (net.Conn, error) {
-	c, err := l.Listener.Accept()
-	if err != nil {
-		return nil, err
-	}
-	tc := &Conn{
-		Conn:           c,
-		ReadTimeout:    l.ReadTimeout,
-		WriteTimeout:   l.WriteTimeout,
-		ReadThreshold:  int32((l.ReadTimeout * 1024) / time.Second),
-		WriteThreshold: int32((l.WriteTimeout * 1024) / time.Second),
-	}
-	tc.BytesReadFromDeadline.Store(0)
-	tc.BytesWrittenFromDeadline.Store(0)
-	return tc, nil
-}
-
-// Conn wraps a net.Conn, and sets a deadline for every read
-// and write operation.
-type Conn struct {
-	net.Conn
-	ReadTimeout              time.Duration
-	WriteTimeout             time.Duration
-	ReadThreshold            int32
-	WriteThreshold           int32
-	BytesReadFromDeadline    atomic.Int32
-	BytesWrittenFromDeadline atomic.Int32
-}
-
-func (c *Conn) Read(b []byte) (n int, err error) {
-	if c.BytesReadFromDeadline.Load() > c.ReadThreshold {
-		c.BytesReadFromDeadline.Store(0)
-		// we set both read and write deadlines here otherwise after the request
-		// is read writing the response fails with an i/o timeout error
-		err = c.SetDeadline(time.Now().Add(c.ReadTimeout))
-		if err != nil {
-			return 0, err
-		}
-	}
-	n, err = c.Conn.Read(b)
-	c.BytesReadFromDeadline.Add(int32(n))
-	return
-}
-
-func (c *Conn) Write(b []byte) (n int, err error) {
-	if c.BytesWrittenFromDeadline.Load() > c.WriteThreshold {
-		c.BytesWrittenFromDeadline.Store(0)
-		// we extend the read deadline too, not sure it's necessary,
-		// but it doesn't hurt
-		err = c.SetDeadline(time.Now().Add(c.WriteTimeout))
-		if err != nil {
-			return
-		}
-	}
-	n, err = c.Conn.Write(b)
-	c.BytesWrittenFromDeadline.Add(int32(n))
-	return
-}
-
-func newListener(network, addr string, readTimeout, writeTimeout time.Duration) (net.Listener, error) {
-	l, err := net.Listen(network, addr)
-	if err != nil {
-		return nil, err
-	}
-
-	tl := &listener{
-		Listener:     l,
-		ReadTimeout:  readTimeout,
-		WriteTimeout: writeTimeout,
-	}
-	return tl, nil
-}

+ 2 - 2
internal/util/util.go

@@ -593,7 +593,7 @@ func HTTPListenAndServe(srv *http.Server, address string, port int, isTLS bool,
 			logger.Error(logSender, "", "error creating Unix-domain socket parent dir: %v", err)
 		}
 		os.Remove(address)
-		listener, err = newListener("unix", address, srv.ReadTimeout, srv.WriteTimeout)
+		listener, err = net.Listen("unix", address)
 		if err == nil {
 			// should a chmod err be fatal?
 			if errChmod := os.Chmod(address, 0770); errChmod != nil {
@@ -602,7 +602,7 @@ func HTTPListenAndServe(srv *http.Server, address string, port int, isTLS bool,
 		}
 	} else {
 		CheckTCP4Port(port)
-		listener, err = newListener("tcp", fmt.Sprintf("%s:%d", address, port), srv.ReadTimeout, srv.WriteTimeout)
+		listener, err = net.Listen("tcp", fmt.Sprintf("%s:%d", address, port))
 	}
 	if err != nil {
 		return err

+ 14 - 0
internal/webdavd/handler.go

@@ -36,6 +36,17 @@ import (
 type Connection struct {
 	*common.BaseConnection
 	request *http.Request
+	rc      *http.ResponseController
+}
+
+func newConnection(conn *common.BaseConnection, w http.ResponseWriter, r *http.Request) *Connection {
+	rc := http.NewResponseController(w)
+	responseControllerDeadlines(rc, time.Time{}, time.Time{})
+	return &Connection{
+		BaseConnection: conn,
+		request:        r,
+		rc:             rc,
+	}
 }
 
 func (c *Connection) getModificationTime() time.Time {
@@ -73,6 +84,9 @@ func (c *Connection) GetRemoteAddress() string {
 
 // Disconnect closes the active transfer
 func (c *Connection) Disconnect() error {
+	if c.rc != nil {
+		responseControllerDeadlines(c.rc, time.Now().Add(5*time.Second), time.Now().Add(5*time.Second))
+	}
 	return c.SignalTransfersAbort()
 }
 

+ 17 - 7
internal/webdavd/server.go

@@ -55,8 +55,6 @@ func (s *webDavServer) listenAndServe(compressor *middleware.Compressor) error {
 	handler := compressor.Handler(s)
 	httpServer := &http.Server{
 		ReadHeaderTimeout: 30 * time.Second,
-		ReadTimeout:       60 * time.Second,
-		WriteTimeout:      60 * time.Second,
 		IdleTimeout:       60 * time.Second,
 		MaxHeaderBytes:    1 << 16, // 64KB
 		ErrorLog:          log.New(&logger.StdLoggerWrapper{Sender: logSender}, "", 0),
@@ -170,6 +168,11 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		}
 	}()
 
+	responseControllerDeadlines(
+		http.NewResponseController(w),
+		time.Now().Add(60*time.Second),
+		time.Now().Add(60*time.Second),
+	)
 	w.Header().Set("Server", version.GetServerVersion("/", false))
 	ipAddr := s.checkRemoteAddress(r)
 
@@ -228,11 +231,9 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	connection := &Connection{
-		BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolWebDAV, util.GetHTTPLocalAddress(r),
-			r.RemoteAddr, user),
-		request: r,
-	}
+	baseConn := common.NewBaseConnection(connectionID, common.ProtocolWebDAV, util.GetHTTPLocalAddress(r),
+		r.RemoteAddr, user)
+	connection := newConnection(baseConn, w, r)
 	if err = common.Connections.Add(connection); err != nil {
 		errClose := user.CloseFs()
 		logger.Warn(logSender, connectionID, "unable add connection: %v close fs error: %v", err, errClose)
@@ -389,6 +390,15 @@ func (s *webDavServer) checkRemoteAddress(r *http.Request) string {
 	return ipAddr
 }
 
+func responseControllerDeadlines(rc *http.ResponseController, read, write time.Time) {
+	if err := rc.SetReadDeadline(read); err != nil {
+		logger.Error(logSender, "", "unable to set read timeout to %s: %v", read, err)
+	}
+	if err := rc.SetWriteDeadline(write); err != nil {
+		logger.Error(logSender, "", "unable to set write timeout to %s: %v", write, err)
+	}
+}
+
 func writeLog(r *http.Request, status int, err error) {
 	scheme := "http"
 	cipherSuite := ""