Browse Source

drive: rewrite LOCK paths

Fixes #12097

Signed-off-by: Percy Wegmann <[email protected]>
Percy Wegmann 1 year ago
parent
commit
59848fe14b

+ 11 - 5
drive/driveimpl/compositedav/compositedav.go

@@ -93,8 +93,15 @@ var cacheInvalidatingMethods = map[string]bool{
 
 // ServeHTTP implements http.Handler.
 func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	if r.Method == "PROPFIND" {
-		h.handlePROPFIND(w, r)
+	pathComponents := shared.CleanAndSplit(r.URL.Path)
+	mpl := h.maxPathLength(r)
+
+	switch r.Method {
+	case "PROPFIND":
+		h.handlePROPFIND(w, r, pathComponents, mpl)
+		return
+	case "LOCK":
+		h.handleLOCK(w, r, pathComponents, mpl)
 		return
 	}
 
@@ -107,9 +114,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		h.StatCache.invalidate()
 	}
 
-	mpl := h.maxPathLength(r)
-	pathComponents := shared.CleanAndSplit(r.URL.Path)
-
 	if len(pathComponents) >= mpl {
 		h.delegate(mpl, pathComponents[mpl-1:], w, r)
 		return
@@ -141,6 +145,8 @@ func (h *Handler) handle(w http.ResponseWriter, r *http.Request) {
 
 // delegate sends the request to the Child WebDAV server.
 func (h *Handler) delegate(mpl int, pathComponents []string, w http.ResponseWriter, r *http.Request) {
+	rewriteIfHeader(r, pathComponents, mpl)
+
 	dest := r.Header.Get("Destination")
 	if dest != "" {
 		// Rewrite destination header

+ 0 - 77
drive/driveimpl/compositedav/propfind.go

@@ -1,77 +0,0 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package compositedav
-
-import (
-	"bytes"
-	"fmt"
-	"math"
-	"net/http"
-	"regexp"
-
-	"tailscale.com/drive/driveimpl/shared"
-)
-
-var (
-	hrefRegex = regexp.MustCompile(`(?s)<D:href>/?([^<]*)/?</D:href>`)
-)
-
-func (h *Handler) handlePROPFIND(w http.ResponseWriter, r *http.Request) {
-	pathComponents := shared.CleanAndSplit(r.URL.Path)
-	mpl := h.maxPathLength(r)
-	if !shared.IsRoot(r.URL.Path) && len(pathComponents)+getDepth(r) > mpl {
-		// Delegate to a Child.
-		depth := getDepth(r)
-
-		status, result := h.StatCache.getOr(r.URL.Path, depth, func() (int, []byte) {
-			// Use a buffering ResponseWriter so that we can manipulate the result.
-			// The only thing we use from the original ResponseWriter is Header().
-			bw := &bufferingResponseWriter{ResponseWriter: w}
-
-			mpl := h.maxPathLength(r)
-			h.delegate(mpl, pathComponents[mpl-1:], bw, r)
-
-			// Fixup paths to add the requested path as a prefix.
-			pathPrefix := shared.Join(pathComponents[0:mpl]...)
-			b := hrefRegex.ReplaceAll(bw.buf.Bytes(), []byte(fmt.Sprintf("<D:href>%s/$1</D:href>", pathPrefix)))
-
-			return bw.status, b
-		})
-
-		w.Header().Del("Content-Length")
-		w.WriteHeader(status)
-		if result != nil {
-			w.Write(result)
-		}
-		return
-	}
-
-	h.handle(w, r)
-}
-
-func getDepth(r *http.Request) int {
-	switch r.Header.Get("Depth") {
-	case "0":
-		return 0
-	case "1":
-		return 1
-	case "infinity":
-		return math.MaxInt
-	}
-	return 0
-}
-
-type bufferingResponseWriter struct {
-	http.ResponseWriter
-	status int
-	buf    bytes.Buffer
-}
-
-func (bw *bufferingResponseWriter) WriteHeader(statusCode int) {
-	bw.status = statusCode
-}
-
-func (bw *bufferingResponseWriter) Write(p []byte) (int, error) {
-	return bw.buf.Write(p)
-}

+ 122 - 0
drive/driveimpl/compositedav/rewriting.go

@@ -0,0 +1,122 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package compositedav
+
+import (
+	"bytes"
+	"fmt"
+	"math"
+	"net/http"
+	"regexp"
+	"strings"
+
+	"tailscale.com/drive/driveimpl/shared"
+)
+
+var (
+	responseHrefRegex = regexp.MustCompile(`(?s)(<D:(response|lockroot)>)<D:href>/?([^<]*)/?</D:href>`)
+	ifHrefRegex       = regexp.MustCompile(`^<(https?://[^/]+)?([^>]+)>`)
+)
+
+func (h *Handler) handlePROPFIND(w http.ResponseWriter, r *http.Request, pathComponents []string, mpl int) {
+	if shouldDelegateToChild(r, pathComponents, mpl) {
+		// Delegate to a Child.
+		depth := getDepth(r)
+
+		status, result := h.StatCache.getOr(r.URL.Path, depth, func() (int, []byte) {
+			return h.delegateRewriting(w, r, pathComponents, mpl)
+		})
+
+		respondRewritten(w, status, result)
+		return
+	}
+
+	h.handle(w, r)
+}
+
+func (h *Handler) handleLOCK(w http.ResponseWriter, r *http.Request, pathComponents []string, mpl int) {
+	if shouldDelegateToChild(r, pathComponents, mpl) {
+		// Delegate to a Child.
+		status, result := h.delegateRewriting(w, r, pathComponents, mpl)
+		respondRewritten(w, status, result)
+		return
+	}
+
+	http.Error(w, "locking of top level directories is not allowed", http.StatusMethodNotAllowed)
+}
+
+// shouldDelegateToChild decides whether a request should be delegated to a
+// child filesystem, as opposed to being handled by this filesystem. It checks
+// the depth of the requested path, and if it's deeper than the portion of the
+// tree that's handled by the parent, returns true.
+func shouldDelegateToChild(r *http.Request, pathComponents []string, mpl int) bool {
+	return !shared.IsRoot(r.URL.Path) && len(pathComponents)+getDepth(r) > mpl
+}
+
+func (h *Handler) delegateRewriting(w http.ResponseWriter, r *http.Request, pathComponents []string, mpl int) (int, []byte) {
+	// Use a buffering ResponseWriter so that we can manipulate the result.
+	// The only thing we use from the original ResponseWriter is Header().
+	bw := &bufferingResponseWriter{ResponseWriter: w}
+
+	h.delegate(mpl, pathComponents[mpl-1:], bw, r)
+
+	// Fixup paths to add the requested path as a prefix, escaped for inclusion in XML.
+	pp := shared.EscapeForXML(shared.Join(pathComponents[0:mpl]...))
+	b := responseHrefRegex.ReplaceAll(bw.buf.Bytes(), []byte(fmt.Sprintf("$1<D:href>%s/$3</D:href>", pp)))
+	return bw.status, b
+}
+
+func respondRewritten(w http.ResponseWriter, status int, result []byte) {
+	w.Header().Del("Content-Length")
+	w.WriteHeader(status)
+	if result != nil {
+		w.Write(result)
+	}
+}
+
+func getDepth(r *http.Request) int {
+	switch r.Header.Get("Depth") {
+	case "0":
+		return 0
+	case "1":
+		return 1
+	case "infinity":
+		return math.MaxInt16 // a really large number, but not infinity (avoids wrapping when we do arithmetic with this)
+	}
+	return 0
+}
+
+type bufferingResponseWriter struct {
+	http.ResponseWriter
+	status int
+	buf    bytes.Buffer
+}
+
+func (bw *bufferingResponseWriter) WriteHeader(statusCode int) {
+	bw.status = statusCode
+}
+
+func (bw *bufferingResponseWriter) Write(p []byte) (int, error) {
+	return bw.buf.Write(p)
+}
+
+// rewriteIfHeader rewrites URLs in the If header by removing the host and the
+// portion of the path that corresponds to this composite filesystem. This way,
+// when we delegate requests to child filesystems, the If header will reference
+// a path that makes sense on those filesystems.
+//
+// See http://www.webdav.org/specs/rfc4918.html#HEADER_If
+func rewriteIfHeader(r *http.Request, pathComponents []string, mpl int) {
+	ih := r.Header.Get("If")
+	if ih == "" {
+		return
+	}
+	matches := ifHrefRegex.FindStringSubmatch(ih)
+	if len(matches) == 3 {
+		pp := shared.JoinEscaped(pathComponents[0:mpl]...)
+		p := strings.Replace(shared.JoinEscaped(pathComponents...), pp, "", 1)
+		nih := ifHrefRegex.ReplaceAllString(ih, fmt.Sprintf("<%s>", p))
+		r.Header.Set("If", nih)
+	}
+}

+ 228 - 5
drive/driveimpl/drive_test.go

@@ -14,6 +14,8 @@ import (
 	"os"
 	"path"
 	"path/filepath"
+	"regexp"
+	"runtime"
 	"slices"
 	"strings"
 	"sync"
@@ -30,14 +32,29 @@ import (
 const (
 	domain = `test$%domain.com`
 
-	remote1 = `rem ote$%1`
-	remote2 = `_rem ote$%2`
-	share11 = `sha re$%11`
-	share12 = `_sha re$%12`
-	file111 = `fi le$%111.txt`
+	remote1 = `rem ote$%<>1`
+	remote2 = `_rem ote$%<>2`
+	share11 = `sha re$%<>11`
+	share12 = `_sha re$%<>12`
 	file112 = `file112.txt`
 )
 
+var (
+	file111 = `fi le$%<>111.txt`
+)
+
+func init() {
+	if runtime.GOOS == "windows" {
+		// file with less than and greater than doesn't work on Windows
+		file111 = `fi le$%111.txt`
+	}
+}
+
+var (
+	lockRootRegex  = regexp.MustCompile(`<D:lockroot><D:href>/?([^<]*)/?</D:href>`)
+	lockTokenRegex = regexp.MustCompile(`<D:locktoken><D:href>([0-9]+)/?</D:href>`)
+)
+
 func init() {
 	// set AllowShareAs() to false so that we don't try to use sub-processes
 	// for access files on disk.
@@ -145,6 +162,206 @@ func TestSecretTokenAuth(t *testing.T) {
 	}
 }
 
+func TestLOCK(t *testing.T) {
+	s := newSystem(t)
+
+	s.addRemote(remote1)
+	s.addShare(remote1, share11, drive.PermissionReadWrite)
+	s.writeFile("writing file to read/write remote should succeed", remote1, share11, file111, "hello world", true)
+
+	client := &http.Client{
+		Transport: &http.Transport{DisableKeepAlives: true},
+	}
+
+	u := fmt.Sprintf("http://%s/%s/%s/%s/%s",
+		s.local.l.Addr(),
+		url.PathEscape(domain),
+		url.PathEscape(remote1),
+		url.PathEscape(share11),
+		url.PathEscape(file111))
+
+	// First acquire a lock with a short timeout
+	req, err := http.NewRequest("LOCK", u, strings.NewReader(lockBody))
+	if err != nil {
+		t.Fatal(err)
+	}
+	req.Header.Set("Depth", "infinity")
+	req.Header.Set("Timeout", "Second-1")
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != 200 {
+		t.Fatalf("expected LOCK to succeed, but got status %d", resp.StatusCode)
+	}
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	submatches := lockRootRegex.FindStringSubmatch(string(body))
+	if len(submatches) != 2 {
+		t.Fatal("failed to find lockroot")
+	}
+	want := shared.EscapeForXML(pathTo(remote1, share11, file111))
+	got := submatches[1]
+	if got != want {
+		t.Fatalf("want lockroot %q, got %q", want, got)
+	}
+
+	submatches = lockTokenRegex.FindStringSubmatch(string(body))
+	if len(submatches) != 2 {
+		t.Fatal("failed to find locktoken")
+	}
+	lockToken := submatches[1]
+	ifHeader := fmt.Sprintf("<%s> (<%s>)", u, lockToken)
+
+	// Then refresh the lock with a longer timeout
+	req, err = http.NewRequest("LOCK", u, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	req.Header.Set("Depth", "infinity")
+	req.Header.Set("Timeout", "Second-600")
+	req.Header.Set("If", ifHeader)
+	resp, err = client.Do(req)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != 200 {
+		t.Fatalf("expected LOCK refresh to succeed, but got status %d", resp.StatusCode)
+	}
+	body, err = io.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	submatches = lockRootRegex.FindStringSubmatch(string(body))
+	if len(submatches) != 2 {
+		t.Fatal("failed to find lockroot after refresh")
+	}
+	want = shared.EscapeForXML(pathTo(remote1, share11, file111))
+	got = submatches[1]
+	if got != want {
+		t.Fatalf("want lockroot after refresh %q, got %q", want, got)
+	}
+
+	submatches = lockTokenRegex.FindStringSubmatch(string(body))
+	if len(submatches) != 2 {
+		t.Fatal("failed to find locktoken after refresh")
+	}
+	if submatches[1] != lockToken {
+		t.Fatalf("on refresh, lock token changed from %q to %q", lockToken, submatches[1])
+	}
+
+	// Then wait past the original timeout, then try to delete without the lock
+	// (should fail)
+	time.Sleep(1 * time.Second)
+	req, err = http.NewRequest("DELETE", u, nil)
+	if err != nil {
+		log.Fatal(err)
+	}
+	resp, err = client.Do(req)
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != 423 {
+		t.Fatalf("deleting without lock token should fail with 423, but got %d", resp.StatusCode)
+	}
+
+	// Then delete with the lock (should succeed)
+	req, err = http.NewRequest("DELETE", u, nil)
+	if err != nil {
+		log.Fatal(err)
+	}
+	req.Header.Set("If", ifHeader)
+	resp, err = client.Do(req)
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != 204 {
+		t.Fatalf("deleting with lock token should have succeeded with 204, but got %d", resp.StatusCode)
+	}
+}
+
+func TestUNLOCK(t *testing.T) {
+	s := newSystem(t)
+
+	s.addRemote(remote1)
+	s.addShare(remote1, share11, drive.PermissionReadWrite)
+	s.writeFile("writing file to read/write remote should succeed", remote1, share11, file111, "hello world", true)
+
+	client := &http.Client{
+		Transport: &http.Transport{DisableKeepAlives: true},
+	}
+
+	u := fmt.Sprintf("http://%s/%s/%s/%s/%s",
+		s.local.l.Addr(),
+		url.PathEscape(domain),
+		url.PathEscape(remote1),
+		url.PathEscape(share11),
+		url.PathEscape(file111))
+
+	// Acquire a lock
+	req, err := http.NewRequest("LOCK", u, strings.NewReader(lockBody))
+	if err != nil {
+		t.Fatal(err)
+	}
+	req.Header.Set("Depth", "infinity")
+	req.Header.Set("Timeout", "Second-600")
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != 200 {
+		t.Fatalf("expected LOCK to succeed, but got status %d", resp.StatusCode)
+	}
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	submatches := lockTokenRegex.FindStringSubmatch(string(body))
+	if len(submatches) != 2 {
+		t.Fatal("failed to find locktoken")
+	}
+	lockToken := submatches[1]
+
+	// Release the lock
+	req, err = http.NewRequest("UNLOCK", u, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	req.Header.Set("Lock-Token", fmt.Sprintf("<%s>", lockToken))
+	resp, err = client.Do(req)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != 204 {
+		t.Fatalf("expected UNLOCK to succeed with a 204, but got status %d", resp.StatusCode)
+	}
+
+	// Then delete without the lock (should succeed)
+	req, err = http.NewRequest("DELETE", u, nil)
+	if err != nil {
+		log.Fatal(err)
+	}
+	resp, err = client.Do(req)
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != 204 {
+		t.Fatalf("deleting without lock should have succeeded with 204, but got %d", resp.StatusCode)
+	}
+}
+
 type local struct {
 	l  net.Listener
 	fs *FileSystemForLocal
@@ -486,3 +703,9 @@ func (a *noopAuthenticator) Clone() gowebdav.Authenticator {
 func (a *noopAuthenticator) Close() error {
 	return nil
 }
+
+const lockBody = `<?xml version="1.0" encoding="utf-8" ?> 
+<D:lockinfo xmlns:D='DAV:'> 
+  <D:lockscope><D:exclusive/></D:lockscope> 
+  <D:locktype><D:write/></D:locktype> 
+</D:lockinfo>`

+ 3 - 0
drive/driveimpl/fileserver.go

@@ -151,6 +151,9 @@ func (s *FileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		w.WriteHeader(http.StatusNotFound)
 		return
 	}
+	// WebDAV's locking code compares the lock resources with the request's
+	// host header, set this to empty to avoid mismatches.
+	r.Host = ""
 	h.ServeHTTP(w, r)
 }
 

+ 16 - 0
drive/driveimpl/shared/xml.go

@@ -0,0 +1,16 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package shared
+
+import (
+	"bytes"
+	"encoding/xml"
+)
+
+// EscapeForXML escapes the given string for use in XML text.
+func EscapeForXML(s string) string {
+	result := bytes.NewBuffer(nil)
+	xml.Escape(result, []byte(s))
+	return result.String()
+}