浏览代码

lib/fs: Treat Windows junctions as normal directories (#6606)

Fixes #1830, presumably.
xarx00 5 年之前
父节点
当前提交
ee445e35a0

+ 1 - 1
lib/fs/basicfs.go

@@ -297,7 +297,7 @@ func (f *BasicFilesystem) SameFile(fi1, fi2 FileInfo) bool {
 		return false
 	}
 
-	return os.SameFile(f1.FileInfo, f2.FileInfo)
+	return os.SameFile(f1.osFileInfo(), f2.osFileInfo())
 }
 
 // basicFile implements the fs.File interface on top of an os.File

+ 10 - 1
lib/fs/basicfs_fileinfo_unix.go

@@ -8,7 +8,10 @@
 
 package fs
 
-import "syscall"
+import (
+	"os"
+	"syscall"
+)
 
 func (e basicFileInfo) Mode() FileMode {
 	return FileMode(e.FileInfo.Mode())
@@ -27,3 +30,9 @@ func (e basicFileInfo) Group() int {
 	}
 	return -1
 }
+
+// fileStat converts e to os.FileInfo that is suitable
+// to be passed to os.SameFile. Non-trivial on Windows.
+func (e *basicFileInfo) osFileInfo() os.FileInfo {
+	return e.FileInfo
+}

+ 10 - 0
lib/fs/basicfs_fileinfo_windows.go

@@ -56,3 +56,13 @@ func (e basicFileInfo) Owner() int {
 func (e basicFileInfo) Group() int {
 	return -1
 }
+
+// osFileInfo converts e to os.FileInfo that is suitable
+// to be passed to os.SameFile.
+func (e *basicFileInfo) osFileInfo() os.FileInfo {
+	fi := e.FileInfo
+	if fi, ok := fi.(*dirJunctFileInfo); ok {
+		return fi.FileInfo
+	}
+	return fi
+}

+ 12 - 0
lib/fs/basicfs_test.go

@@ -577,3 +577,15 @@ func TestBasicWalkSkipSymlink(t *testing.T) {
 	defer os.RemoveAll(dir)
 	testWalkSkipSymlink(t, FilesystemTypeBasic, dir)
 }
+
+func TestWalkTraverseDirJunct(t *testing.T) {
+	_, dir := setup(t)
+	defer os.RemoveAll(dir)
+	testWalkTraverseDirJunct(t, FilesystemTypeBasic, dir)
+}
+
+func TestWalkInfiniteRecursion(t *testing.T) {
+	_, dir := setup(t)
+	defer os.RemoveAll(dir)
+	testWalkInfiniteRecursion(t, FilesystemTypeBasic, dir)
+}

+ 1 - 1
lib/fs/lstat_regular.go

@@ -4,7 +4,7 @@
 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 // You can obtain one at https://mozilla.org/MPL/2.0/.
 
-// +build !linux,!android
+// +build !linux,!android,!windows
 
 package fs
 

+ 80 - 0
lib/fs/lstat_windows.go

@@ -0,0 +1,80 @@
+// Copyright (C) 2015 The Syncthing Authors.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at https://mozilla.org/MPL/2.0/.
+
+// +build windows
+
+package fs
+
+import (
+	"fmt"
+	"os"
+	"syscall"
+	"unsafe"
+
+	"golang.org/x/sys/windows"
+)
+
+func isDirectoryJunction(path string) (bool, error) {
+	namep, err := syscall.UTF16PtrFromString(path)
+	if err != nil {
+		return false, fmt.Errorf("syscall.UTF16PtrFromString failed with: %s", err)
+	}
+	attrs := uint32(syscall.FILE_FLAG_BACKUP_SEMANTICS | syscall.FILE_FLAG_OPEN_REPARSE_POINT)
+	h, err := syscall.CreateFile(namep, 0, 0, nil, syscall.OPEN_EXISTING, attrs, 0)
+	if err != nil {
+		return false, fmt.Errorf("syscall.CreateFile failed with: %s", err)
+	}
+	defer syscall.CloseHandle(h)
+
+	//https://docs.microsoft.com/en-us/windows/win32/api/winbase/ns-winbase-file_attribute_tag_info
+	const fileAttributeTagInfo = 9
+	type FILE_ATTRIBUTE_TAG_INFO struct {
+		FileAttributes uint32
+		ReparseTag     uint32
+	}
+
+	var ti FILE_ATTRIBUTE_TAG_INFO
+	err = windows.GetFileInformationByHandleEx(windows.Handle(h), fileAttributeTagInfo, (*byte)(unsafe.Pointer(&ti)), uint32(unsafe.Sizeof(ti)))
+	if err != nil {
+		if errno, ok := err.(syscall.Errno); ok && errno == windows.ERROR_INVALID_PARAMETER {
+			// It appears calling GetFileInformationByHandleEx with
+			// FILE_ATTRIBUTE_TAG_INFO fails on FAT file system with
+			// ERROR_INVALID_PARAMETER. Clear ti.ReparseTag in that
+			// instance to indicate no symlinks are possible.
+			ti.ReparseTag = 0
+		} else {
+			return false, fmt.Errorf("windows.GetFileInformationByHandleEx failed with: %s", err)
+		}
+	}
+	return ti.ReparseTag == windows.IO_REPARSE_TAG_MOUNT_POINT, nil
+}
+
+type dirJunctFileInfo struct {
+	os.FileInfo
+}
+
+func (fi *dirJunctFileInfo) Mode() os.FileMode {
+	return fi.FileInfo.Mode() ^ os.ModeSymlink | os.ModeDir
+}
+
+func (fi *dirJunctFileInfo) IsDir() bool {
+	return true
+}
+
+func underlyingLstat(name string) (os.FileInfo, error) {
+	var fi, err = os.Lstat(name)
+
+	// NTFS directory junctions are treated as ordinary directories,
+	// see https://forum.syncthing.net/t/option-to-follow-directory-junctions-symbolic-links/14750
+	if err == nil && fi.Mode()&os.ModeSymlink != 0 {
+		var isJunct bool
+		isJunct, err = isDirectoryJunction(name)
+		if err == nil && isJunct {
+			return &dirJunctFileInfo{fi}, nil
+		}
+	}
+	return fi, err
+}

+ 44 - 4
lib/fs/walkfs.go

@@ -10,7 +10,37 @@
 
 package fs
 
-import "path/filepath"
+import (
+	"path/filepath"
+)
+
+type ancestorDirList struct {
+	list []FileInfo
+	fs   Filesystem
+}
+
+func (ancestors *ancestorDirList) Push(info FileInfo) {
+	l.Debugf("ancestorDirList: Push '%s'", info.Name())
+	ancestors.list = append(ancestors.list, info)
+}
+
+func (ancestors *ancestorDirList) Pop() FileInfo {
+	aLen := len(ancestors.list)
+	info := ancestors.list[aLen-1]
+	l.Debugf("ancestorDirList: Pop '%s'", info.Name())
+	ancestors.list = ancestors.list[:aLen-1]
+	return info
+}
+
+func (ancestors *ancestorDirList) Contains(info FileInfo) bool {
+	l.Debugf("ancestorDirList: Contains '%s'", info.Name())
+	for _, ancestor := range ancestors.list {
+		if ancestors.fs.SameFile(info, ancestor) {
+			return true
+		}
+	}
+	return false
+}
 
 // WalkFunc is the type of the function called for each file or directory
 // visited by Walk. The path argument contains the argument to Walk as a
@@ -37,7 +67,8 @@ func NewWalkFilesystem(next Filesystem) Filesystem {
 }
 
 // walk recursively descends path, calling walkFn.
-func (f *walkFilesystem) walk(path string, info FileInfo, walkFn WalkFunc) error {
+func (f *walkFilesystem) walk(path string, info FileInfo, walkFn WalkFunc, ancestors *ancestorDirList) error {
+	l.Debugf("walk: path=%s", path)
 	path, err := Canonicalize(path)
 	if err != nil {
 		return err
@@ -55,6 +86,14 @@ func (f *walkFilesystem) walk(path string, info FileInfo, walkFn WalkFunc) error
 		return nil
 	}
 
+	if !ancestors.Contains(info) {
+		ancestors.Push(info)
+		defer ancestors.Pop()
+	} else {
+		l.Warnf("Infinite filesystem recursion detected on path '%s', not walking further down", path)
+		return nil
+	}
+
 	names, err := f.DirNames(path)
 	if err != nil {
 		return walkFn(path, info, err)
@@ -68,7 +107,7 @@ func (f *walkFilesystem) walk(path string, info FileInfo, walkFn WalkFunc) error
 				return err
 			}
 		} else {
-			err = f.walk(filename, fileInfo, walkFn)
+			err = f.walk(filename, fileInfo, walkFn, ancestors)
 			if err != nil {
 				if !fileInfo.IsDir() || err != SkipDir {
 					return err
@@ -90,5 +129,6 @@ func (f *walkFilesystem) Walk(root string, walkFn WalkFunc) error {
 	if err != nil {
 		return walkFn(root, nil, err)
 	}
-	return f.walk(root, info, walkFn)
+	ancestors := &ancestorDirList{fs: f.Filesystem}
+	return f.walk(root, info, walkFn, ancestors)
 }

+ 84 - 1
lib/fs/walkfs_test.go

@@ -7,13 +7,16 @@
 package fs
 
 import (
+	"fmt"
+	osexec "os/exec"
+	"path/filepath"
 	"runtime"
 	"testing"
 )
 
 func testWalkSkipSymlink(t *testing.T, fsType FilesystemType, uri string) {
 	if runtime.GOOS == "windows" {
-		t.Skip("Symlinks on windows")
+		t.Skip("Symlinks skipping is not tested on windows")
 	}
 
 	fs := NewFilesystem(fsType, uri)
@@ -39,3 +42,83 @@ func testWalkSkipSymlink(t *testing.T, fsType FilesystemType, uri string) {
 		t.Fatal(err)
 	}
 }
+
+func createDirJunct(target string, name string) error {
+	output, err := osexec.Command("cmd", "/c", "mklink", "/J", name, target).CombinedOutput()
+	if err != nil {
+		return fmt.Errorf("Failed to run mklink %v %v: %v %q", name, target, err, output)
+	}
+	return nil
+}
+
+func testWalkTraverseDirJunct(t *testing.T, fsType FilesystemType, uri string) {
+	if runtime.GOOS != "windows" {
+		t.Skip("Directory junctions are available and tested on windows only")
+	}
+
+	fs := NewFilesystem(fsType, uri)
+
+	if err := fs.MkdirAll("target/foo", 0); err != nil {
+		t.Fatal(err)
+	}
+	if err := fs.Mkdir("towalk", 0); err != nil {
+		t.Fatal(err)
+	}
+	if err := createDirJunct(filepath.Join(uri, "target"), filepath.Join(uri, "towalk/dirjunct")); err != nil {
+		t.Fatal(err)
+	}
+	traversed := false
+	if err := fs.Walk("towalk", func(path string, info FileInfo, err error) error {
+		if err != nil {
+			t.Fatal(err)
+		}
+		if info.Name() == "foo" {
+			traversed = true
+		}
+		return nil
+	}); err != nil {
+		t.Fatal(err)
+	}
+	if !traversed {
+		t.Fatal("Directory junction was not traversed")
+	}
+}
+
+func testWalkInfiniteRecursion(t *testing.T, fsType FilesystemType, uri string) {
+	if runtime.GOOS != "windows" {
+		t.Skip("Infinite recursion detection is tested on windows only")
+	}
+
+	fs := NewFilesystem(fsType, uri)
+
+	if err := fs.MkdirAll("target/foo", 0); err != nil {
+		t.Fatal(err)
+	}
+	if err := fs.Mkdir("towalk", 0); err != nil {
+		t.Fatal(err)
+	}
+	if err := createDirJunct(filepath.Join(uri, "target"), filepath.Join(uri, "towalk/dirjunct")); err != nil {
+		t.Fatal(err)
+	}
+	if err := createDirJunct(filepath.Join(uri, "towalk"), filepath.Join(uri, "target/foo/recurse")); err != nil {
+		t.Fatal(err)
+	}
+	dirjunctCnt := 0
+	fooCnt := 0
+	if err := fs.Walk("towalk", func(path string, info FileInfo, err error) error {
+		if err != nil {
+			t.Fatal(err)
+		}
+		if info.Name() == "dirjunct" {
+			dirjunctCnt++
+		} else if info.Name() == "foo" {
+			fooCnt++
+		}
+		return nil
+	}); err != nil {
+		t.Fatal(err)
+	}
+	if dirjunctCnt != 2 || fooCnt != 1 {
+		t.Fatal("Infinite recursion not detected correctly")
+	}
+}