Browse Source

util/dirwalk, metrics, portlist: add new package for fast directory walking

This is similar to the golang.org/x/tools/internal/fastwalk package I'd
previously written, but it is not recursive and it uses mem.RO.

The metrics package already had some Linux-specific directory reading
code in it. Move that out to a new general package that can be reused
by portlist too, which helps its scanning of all /proc files:

    name                old time/op    new time/op    delta
    FindProcessNames-8    2.79ms ± 6%    2.45ms ± 7%  -12.11%  (p=0.000 n=10+10)

    name                old alloc/op   new alloc/op   delta
    FindProcessNames-8    62.9kB ± 0%    33.5kB ± 0%  -46.76%  (p=0.000 n=9+10)

    name                old allocs/op  new allocs/op  delta
    FindProcessNames-8     2.25k ± 0%     0.38k ± 0%  -82.98%  (p=0.000 n=9+10)

Change-Id: I75db393032c328f12d95c39f71c9742c375f207a
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 3 years ago
parent
commit
db2cc393af

+ 2 - 1
cmd/derper/depaware.txt

@@ -34,7 +34,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
         tailscale.com/hostinfo                                       from tailscale.com/net/interfaces+
         tailscale.com/ipn                                            from tailscale.com/client/tailscale
         tailscale.com/ipn/ipnstate                                   from tailscale.com/client/tailscale+
-     💣 tailscale.com/metrics                                        from tailscale.com/cmd/derper+
+        tailscale.com/metrics                                        from tailscale.com/cmd/derper+
         tailscale.com/net/dnscache                                   from tailscale.com/derp/derphttp
         tailscale.com/net/flowtrack                                  from tailscale.com/net/packet+
      💣 tailscale.com/net/interfaces                                 from tailscale.com/net/netns+
@@ -72,6 +72,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
    W    tailscale.com/util/clientmetric                              from tailscale.com/net/tshttpproxy
         tailscale.com/util/cloudenv                                  from tailscale.com/hostinfo+
    W    tailscale.com/util/cmpver                                    from tailscale.com/net/tshttpproxy
+   L 💣 tailscale.com/util/dirwalk                                   from tailscale.com/metrics
         tailscale.com/util/dnsname                                   from tailscale.com/hostinfo+
    W    tailscale.com/util/endian                                    from tailscale.com/net/netns
         tailscale.com/util/lineread                                  from tailscale.com/hostinfo+

+ 2 - 1
cmd/tailscale/depaware.txt

@@ -52,7 +52,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
         tailscale.com/hostinfo                                       from tailscale.com/net/interfaces+
         tailscale.com/ipn                                            from tailscale.com/cmd/tailscale/cli+
         tailscale.com/ipn/ipnstate                                   from tailscale.com/cmd/tailscale/cli+
-     💣 tailscale.com/metrics                                        from tailscale.com/derp
+        tailscale.com/metrics                                        from tailscale.com/derp
         tailscale.com/net/dnscache                                   from tailscale.com/derp/derphttp+
         tailscale.com/net/dnsfallback                                from tailscale.com/control/controlhttp
         tailscale.com/net/flowtrack                                  from tailscale.com/wgengine/filter+
@@ -95,6 +95,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
         tailscale.com/util/clientmetric                              from tailscale.com/net/netcheck+
         tailscale.com/util/cloudenv                                  from tailscale.com/net/dnscache+
    W    tailscale.com/util/cmpver                                    from tailscale.com/net/tshttpproxy
+   L 💣 tailscale.com/util/dirwalk                                   from tailscale.com/metrics
         tailscale.com/util/dnsname                                   from tailscale.com/cmd/tailscale/cli+
    W    tailscale.com/util/endian                                    from tailscale.com/net/netns
         tailscale.com/util/groupmember                               from tailscale.com/cmd/tailscale/cli

+ 2 - 1
cmd/tailscaled/depaware.txt

@@ -212,7 +212,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
         tailscale.com/logtail                                        from tailscale.com/control/controlclient+
         tailscale.com/logtail/backoff                                from tailscale.com/control/controlclient+
         tailscale.com/logtail/filch                                  from tailscale.com/logpolicy
-     💣 tailscale.com/metrics                                        from tailscale.com/derp+
+        tailscale.com/metrics                                        from tailscale.com/derp+
         tailscale.com/net/dns                                        from tailscale.com/ipn/ipnlocal+
         tailscale.com/net/dns/publicdns                              from tailscale.com/net/dns/resolver+
         tailscale.com/net/dns/resolvconffile                         from tailscale.com/net/dns+
@@ -275,6 +275,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
         tailscale.com/util/cloudenv                                  from tailscale.com/net/dns/resolver+
   LW    tailscale.com/util/cmpver                                    from tailscale.com/net/dns+
      💣 tailscale.com/util/deephash                                  from tailscale.com/ipn/ipnlocal+
+   L 💣 tailscale.com/util/dirwalk                                   from tailscale.com/metrics+
         tailscale.com/util/dnsname                                   from tailscale.com/hostinfo+
   LW    tailscale.com/util/endian                                    from tailscale.com/net/dns+
         tailscale.com/util/goroutines                                from tailscale.com/control/controlclient+

+ 25 - 92
metrics/fds_linux.go

@@ -5,105 +5,38 @@
 package metrics
 
 import (
-	"fmt"
-	"log"
-	"syscall"
-	"unsafe"
+	"io/fs"
+	"sync"
 
-	"golang.org/x/sys/unix"
+	"go4.org/mem"
+	"tailscale.com/util/dirwalk"
 )
 
-func currentFDs() int {
-	fd, err := openProcSelfFD()
-	if err != nil {
-		return 0
-	}
-	defer syscall.Close(fd)
-
-	count := 0
+// counter is a reusable counter for counting file descriptors.
+type counter struct {
+	n int
 
-	const blockSize = 8 << 10
-	buf := make([]byte, blockSize) // stack-allocated; doesn't escape
-	bufp := 0                      // starting read position in buf
-	nbuf := 0                      // end valid data in buf
-	dirent := &syscall.Dirent{}
-	for {
-		if bufp >= nbuf {
-			bufp = 0
-			nbuf, err = readDirent(fd, buf)
-			if err != nil {
-				log.Printf("currentFDs: readDirent: %v", err)
-				return 0
-			}
-			if nbuf <= 0 {
-				return count
-			}
-		}
-		consumed, name := parseDirEnt(dirent, buf[bufp:nbuf])
-		bufp += consumed
-		if len(name) == 0 || string(name) == "." || string(name) == ".." {
-			continue
-		}
-		count++
-	}
+	// cb is the (*counter).count method value. Creating it allocates,
+	// so we have to save it away and use a sync.Pool to keep currentFDs
+	// amortized alloc-free.
+	cb func(name mem.RO, de fs.DirEntry) error
 }
 
-func direntNamlen(dirent *syscall.Dirent) int {
-	const fixedHdr = uint16(unsafe.Offsetof(syscall.Dirent{}.Name))
-	limit := dirent.Reclen - fixedHdr
-	const dirNameLen = 256 // sizeof syscall.Dirent.Name
-	if limit > dirNameLen {
-		limit = dirNameLen
-	}
-	for i := uint16(0); i < limit; i++ {
-		if dirent.Name[i] == 0 {
-			return int(i)
-		}
-	}
-	panic("failed to find terminating 0 byte in dirent")
-}
+var counterPool = &sync.Pool{New: func() any {
+	c := new(counter)
+	c.cb = c.count
+	return c
+}}
 
-func parseDirEnt(dirent *syscall.Dirent, buf []byte) (consumed int, name []byte) {
-	// golang.org/issue/37269
-	copy(unsafe.Slice((*byte)(unsafe.Pointer(dirent)), unsafe.Sizeof(syscall.Dirent{})), buf)
-	if v := unsafe.Offsetof(dirent.Reclen) + unsafe.Sizeof(dirent.Reclen); uintptr(len(buf)) < v {
-		panic(fmt.Sprintf("buf size of %d smaller than dirent header size %d", len(buf), v))
-	}
-	if len(buf) < int(dirent.Reclen) {
-		panic(fmt.Sprintf("buf size %d < record length %d", len(buf), dirent.Reclen))
-	}
-	consumed = int(dirent.Reclen)
-	if dirent.Ino == 0 { // File absent in directory.
-		return
-	}
-	name = unsafe.Slice((*byte)(unsafe.Pointer(&dirent.Name[0])), direntNamlen(dirent))
-	return
+func (c *counter) count(name mem.RO, de fs.DirEntry) error {
+	c.n++
+	return nil
 }
 
-var procSelfFDName = []byte("/proc/self/fd\x00")
-
-func openProcSelfFD() (fd int, err error) {
-	var dirfd int = unix.AT_FDCWD
-	for {
-		r0, _, e1 := syscall.Syscall(unix.SYS_OPENAT, uintptr(dirfd),
-			uintptr(unsafe.Pointer(&procSelfFDName[0])), 0)
-		if e1 == 0 {
-			return int(r0), nil
-		}
-		if e1 == syscall.EINTR {
-			// Since https://golang.org/doc/go1.14#runtime we
-			// need to loop on EINTR on more places.
-			continue
-		}
-		return 0, syscall.Errno(e1)
-	}
-}
-
-func readDirent(fd int, buf []byte) (n int, err error) {
-	for {
-		nbuf, err := syscall.ReadDirent(fd, buf)
-		if err != syscall.EINTR {
-			return nbuf, err
-		}
-	}
+func currentFDs() int {
+	c := counterPool.Get().(*counter)
+	defer counterPool.Put(c)
+	c.n = 0
+	dirwalk.WalkShallow(mem.S("/proc/self/fd"), c.cb)
+	return c.n
 }

+ 60 - 80
portlist/portlist_linux.go

@@ -10,11 +10,11 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"io/fs"
 	"log"
 	"os"
 	"path/filepath"
 	"runtime"
-	"strconv"
 	"strings"
 	"syscall"
 	"time"
@@ -22,6 +22,7 @@ import (
 
 	"go4.org/mem"
 	"golang.org/x/sys/unix"
+	"tailscale.com/util/dirwalk"
 	"tailscale.com/util/mak"
 )
 
@@ -32,7 +33,8 @@ func init() {
 }
 
 type linuxImpl struct {
-	procNetFiles []*os.File // seeked to start & reused between calls
+	procNetFiles    []*os.File // seeked to start & reused between calls
+	readlinkPathBuf []byte
 
 	known map[string]*portMeta // inode string => metadata
 	br    *bufio.Reader
@@ -270,71 +272,59 @@ func (li *linuxImpl) findProcessNames(need map[string]*portMeta) error {
 		}
 	}()
 
-	var pathBuf []byte
-
-	err := foreachPID(func(pid string) error {
-		fdPath := fmt.Sprintf("/proc/%s/fd", pid)
+	err := foreachPID(func(pid mem.RO) error {
+		var procBuf [128]byte
+		fdPath := mem.Append(procBuf[:0], mem.S("/proc/"))
+		fdPath = mem.Append(fdPath, pid)
+		fdPath = mem.Append(fdPath, mem.S("/fd"))
 
 		// Android logs a bunch of audit violations in logcat
 		// if we try to open things we don't have access
 		// to. So on Android only, ask if we have permission
 		// rather than just trying it to determine whether we
 		// have permission.
-		if runtime.GOOS == "android" && syscall.Access(fdPath, unix.R_OK) != nil {
-			return nil
-		}
-
-		fdDir, err := os.Open(fdPath)
-		if err != nil {
-			// Can't open fd list for this pid. Maybe
-			// don't have access. Ignore it.
+		if runtime.GOOS == "android" && syscall.Access(string(fdPath), unix.R_OK) != nil {
 			return nil
 		}
-		defer fdDir.Close()
 
-		targetBuf := make([]byte, 64) // plenty big for "socket:[165614651]"
-		for {
-			fds, err := fdDir.Readdirnames(100)
-			if err == io.EOF {
+		dirwalk.WalkShallow(mem.B(fdPath), func(fd mem.RO, de fs.DirEntry) error {
+			targetBuf := make([]byte, 64) // plenty big for "socket:[165614651]"
+
+			linkPath := li.readlinkPathBuf[:0]
+			linkPath = fmt.Appendf(linkPath, "/proc/")
+			linkPath = mem.Append(linkPath, pid)
+			linkPath = append(linkPath, "/fd/"...)
+			linkPath = mem.Append(linkPath, fd)
+			linkPath = append(linkPath, 0) // terminating NUL
+			li.readlinkPathBuf = linkPath  // to reuse its buffer next time
+			n, ok := readlink(linkPath, targetBuf)
+			if !ok {
+				// Not a symlink or no permission.
+				// Skip it.
 				return nil
 			}
-			if os.IsNotExist(err) {
-				// This can happen if the directory we're
-				// reading disappears during the run. No big
-				// deal.
+
+			pe := need[string(targetBuf[:n])] // m[string([]byte)] avoids alloc
+			if pe == nil {
 				return nil
 			}
+			bs, err := os.ReadFile(fmt.Sprintf("/proc/%s/cmdline", pid.StringCopy()))
 			if err != nil {
-				return fmt.Errorf("addProcesses.readDir: %w", err)
+				// Usually shouldn't happen. One possibility is
+				// the process has gone away, so let's skip it.
+				return nil
 			}
-			for _, fd := range fds {
-				pathBuf = fmt.Appendf(pathBuf[:0], "/proc/%s/fd/%s\x00", pid, fd)
-				n, ok := readlink(pathBuf, targetBuf)
-				if !ok {
-					// Not a symlink or no permission.
-					// Skip it.
-					continue
-				}
-
-				pe := need[string(targetBuf[:n])] // m[string([]byte)] avoids alloc
-				if pe != nil {
-					bs, err := os.ReadFile(fmt.Sprintf("/proc/%s/cmdline", pid))
-					if err != nil {
-						// Usually shouldn't happen. One possibility is
-						// the process has gone away, so let's skip it.
-						continue
-					}
-
-					argv := strings.Split(strings.TrimSuffix(string(bs), "\x00"), "\x00")
-					pe.port.Process = argvSubject(argv...)
-					pe.needsProcName = false
-					delete(need, string(targetBuf[:n]))
-					if len(need) == 0 {
-						return errDone
-					}
-				}
+
+			argv := strings.Split(strings.TrimSuffix(string(bs), "\x00"), "\x00")
+			pe.port.Process = argvSubject(argv...)
+			pe.needsProcName = false
+			delete(need, string(targetBuf[:n]))
+			if len(need) == 0 {
+				return errDone
 			}
-		}
+			return nil
+		})
+		return nil
 	})
 	if err == errDone {
 		return nil
@@ -342,40 +332,30 @@ func (li *linuxImpl) findProcessNames(need map[string]*portMeta) error {
 	return err
 }
 
-func foreachPID(fn func(pidStr string) error) error {
-	pdir, err := os.Open("/proc")
-	if err != nil {
-		return err
-	}
-	defer pdir.Close()
-
-	for {
-		pids, err := pdir.Readdirnames(100)
-		if err == io.EOF {
-			return nil
-		}
-		if os.IsNotExist(err) {
-			// This can happen if the directory we're
-			// reading disappears during the run. No big
-			// deal.
+func foreachPID(fn func(pidStr mem.RO) error) error {
+	err := dirwalk.WalkShallow(mem.S("/proc"), func(name mem.RO, de fs.DirEntry) error {
+		if !isNumeric(name) {
 			return nil
 		}
-		if err != nil {
-			return fmt.Errorf("foreachPID.readdir: %w", err)
-		}
+		return fn(name)
+	})
+	if os.IsNotExist(err) {
+		// This can happen if the directory we're
+		// reading disappears during the run. No big
+		// deal.
+		return nil
+	}
+	return err
+}
 
-		for _, pid := range pids {
-			_, err := strconv.ParseInt(pid, 10, 64)
-			if err != nil {
-				// not a pid, ignore it.
-				// /proc has lots of non-pid stuff in it.
-				continue
-			}
-			if err := fn(pid); err != nil {
-				return err
-			}
+func isNumeric(s mem.RO) bool {
+	for i, n := 0, s.Len(); i < n; i++ {
+		b := s.At(i)
+		if b < '0' || b > '9' {
+			return false
 		}
 	}
+	return s.Len() > 0
 }
 
 // fieldIndex returns the offset in line where the Nth field (0-based) begins, or -1

+ 13 - 0
portlist/portlist_linux_test.go

@@ -136,3 +136,16 @@ func BenchmarkParsePorts(b *testing.B) {
 		}
 	}
 }
+
+// BenchmarkFindProcessNames measures findProcessNames when asked for a
+// single key that is not expected to match anything, so the need map stays
+// populated across iterations.
+func BenchmarkFindProcessNames(b *testing.B) {
+	impl := new(linuxImpl)
+	need := map[string]*portMeta{
+		"something-we'll-never-find": new(portMeta),
+	}
+	b.ReportAllocs()
+	for n := 0; n < b.N; n++ {
+		err := impl.findProcessNames(need)
+		if err != nil {
+			b.Fatal(err)
+		}
+	}
+}

+ 54 - 0
util/dirwalk/dirwalk.go

@@ -0,0 +1,54 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package dirwalk contains code to walk a directory.
+package dirwalk
+
+import (
+	"io"
+	"io/fs"
+	"os"
+
+	"go4.org/mem"
+)
+
+var osWalkShallow func(name mem.RO, fn WalkFunc) error // OS-specific implementation; when nil, WalkShallow uses its portable fallback
+
+// WalkFunc is the callback type used with WalkShallow.
+//
+// The name and de arguments are only valid for the duration of the
+// callback and must not be retained after it returns.
+type WalkFunc func(name mem.RO, de fs.DirEntry) error
+
+// WalkShallow reads the entries in the named directory and calls fn for each.
+// It does not recurse into subdirectories.
+//
+// If fn returns an error, iteration stops and WalkShallow returns that value.
+//
+// On Linux, WalkShallow does not allocate, so long as certain methods on the
+// WalkFunc's DirEntry are not called which necessarily allocate.
+func WalkShallow(dirName mem.RO, fn WalkFunc) error {
+	if f := osWalkShallow; f != nil {
+		return f(dirName, fn)
+	}
+	of, err := os.Open(dirName.StringCopy())
+	if err != nil {
+		return err
+	}
+	defer of.Close()
+	for {
+		fis, err := of.ReadDir(100)
+		for _, de := range fis {
+			if err := fn(mem.S(de.Name()), de); err != nil {
+				return err
+			}
+		}
+		if err != nil {
+			if err == io.EOF {
+				return nil
+			}
+			return err
+		}
+	}
+}

+ 168 - 0
util/dirwalk/dirwalk_linux.go

@@ -0,0 +1,168 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dirwalk
+
+import (
+	"fmt"
+	"io/fs"
+	"os"
+	"path/filepath"
+	"sync"
+	"syscall"
+	"unsafe"
+
+	"go4.org/mem"
+	"golang.org/x/sys/unix"
+)
+
+func init() {
+	osWalkShallow = linuxWalkShallow // register the Linux implementation as the OS-specific walker
+}
+
+var dirEntPool = &sync.Pool{New: func() any { return new(linuxDirEnt) }} // pool of reusable *linuxDirEnt values
+
+// linuxWalkShallow is the Linux implementation of WalkShallow. It opens
+// dirName via sysOpen, reads batches of directory records with readDirent
+// and invokes fn for every entry other than "." and "..".
+//
+// A single pooled linuxDirEnt is reused for every callback, so the name and
+// DirEntry handed to fn are overwritten when the next entry is parsed.
+func linuxWalkShallow(dirName mem.RO, fn WalkFunc) error {
+	const blockSize = 8 << 10
+	buf := make([]byte, blockSize) // stack-allocated; doesn't escape
+
+	// Borrow buf to build the NUL-terminated path for the open call; it is
+	// reused for directory data once the open has returned.
+	path := append(mem.Append(buf[:0], dirName), 0)
+
+	fd, err := sysOpen(path)
+	if err != nil {
+		return err
+	}
+	defer syscall.Close(fd)
+
+	ent := dirEntPool.Get().(*linuxDirEnt)
+	defer ent.cleanAndPutInPool()
+	ent.root = dirName
+
+	// buf[pos:end] holds the directory records not yet handed to fn.
+	pos, end := 0, 0
+	for {
+		if pos >= end {
+			// Buffer drained; fetch the next batch of records.
+			pos = 0
+			if end, err = readDirent(fd, buf); err != nil {
+				return err
+			}
+			if end <= 0 {
+				return nil // end of directory
+			}
+		}
+		used, name := parseDirEnt(&ent.d, buf[pos:end])
+		pos += used
+		switch string(name) {
+		case "", ".", "..":
+			// Absent entry, or one of the self/parent links.
+			continue
+		}
+		ent.name = mem.B(name)
+		if err := fn(ent.name, ent); err != nil {
+			return err
+		}
+	}
+}
+
+type linuxDirEnt struct {
+	root mem.RO         // name of the directory being walked; Info joins it with name
+	d    syscall.Dirent // raw directory record most recently filled in by parseDirEnt
+	name mem.RO         // name of the current entry, as passed to the WalkFunc
+}
+
+// cleanAndPutInPool drops de's references to the walked directory name and
+// to the last entry name, then hands de back to dirEntPool for reuse.
+func (de *linuxDirEnt) cleanAndPutInPool() {
+	de.root, de.name = mem.RO{}, mem.RO{}
+	dirEntPool.Put(de)
+}
+
+// Name returns a string copy of the current entry's name.
+func (de *linuxDirEnt) Name() string {
+	return de.name.StringCopy()
+}
+// Info lstats the entry, addressed as the walked directory joined with the
+// entry's name, and returns the result.
+func (de *linuxDirEnt) Info() (fs.FileInfo, error) {
+	dir, base := de.root.StringCopy(), de.name.StringCopy()
+	return os.Lstat(filepath.Join(dir, base))
+}
+// IsDir reports whether the raw directory record marks the entry as a
+// directory (type DT_DIR).
+func (de *linuxDirEnt) IsDir() bool {
+	return de.Type().IsDir()
+}
+// Type maps the type byte of the raw directory record to the matching
+// fs.FileMode type bits, using only data already held in the record.
+//
+// Any type value not listed below is reported as fs.ModeIrregular.
+func (de *linuxDirEnt) Type() fs.FileMode {
+	switch de.d.Type {
+	case syscall.DT_BLK:
+		return fs.ModeDevice
+	case syscall.DT_CHR:
+		// fs.ModeCharDevice is only meaningful together with
+		// fs.ModeDevice, which is also what os.ReadDir (used by the
+		// portable WalkShallow) reports for character devices.
+		return fs.ModeDevice | fs.ModeCharDevice
+	case syscall.DT_DIR:
+		return fs.ModeDir
+	case syscall.DT_FIFO:
+		return fs.ModeNamedPipe
+	case syscall.DT_LNK:
+		return fs.ModeSymlink
+	case syscall.DT_REG:
+		return 0
+	case syscall.DT_SOCK:
+		return fs.ModeSocket
+	default:
+		return fs.ModeIrregular // shrug
+	}
+}
+
+func direntNamlen(dirent *syscall.Dirent) int { // length of dirent's NUL-terminated name; panics if no NUL is found
+	const fixedHdr = uint16(unsafe.Offsetof(syscall.Dirent{}.Name)) // bytes of the record that precede Name
+	limit := dirent.Reclen - fixedHdr                               // bytes of this record left for the name
+	const dirNameLen = 256 // sizeof syscall.Dirent.Name
+	if limit > dirNameLen {
+		limit = dirNameLen // never scan past the fixed-size Name array
+	}
+	for i := uint16(0); i < limit; i++ {
+		if dirent.Name[i] == 0 {
+			return int(i)
+		}
+	}
+	panic("failed to find terminating 0 byte in dirent")
+}
+
+func parseDirEnt(dirent *syscall.Dirent, buf []byte) (consumed int, name []byte) {
+	// Copy the record into dirent rather than reinterpreting buf in place; see golang.org/issue/37269.
+	copy(unsafe.Slice((*byte)(unsafe.Pointer(dirent)), unsafe.Sizeof(syscall.Dirent{})), buf)
+	if v := unsafe.Offsetof(dirent.Reclen) + unsafe.Sizeof(dirent.Reclen); uintptr(len(buf)) < v { // buf must cover the fields up to and including Reclen
+		panic(fmt.Sprintf("buf size of %d smaller than dirent header size %d", len(buf), v))
+	}
+	if len(buf) < int(dirent.Reclen) { // the whole record must be present in buf
+		panic(fmt.Sprintf("buf size %d < record length %d", len(buf), dirent.Reclen))
+	}
+	consumed = int(dirent.Reclen) // the caller advances its read position by this many bytes
+	if dirent.Ino == 0 { // File absent in directory.
+		return
+	}
+	name = unsafe.Slice((*byte)(unsafe.Pointer(&dirent.Name[0])), direntNamlen(dirent)) // aliases dirent.Name; only valid until dirent is overwritten
+	return
+}
+
+// sysOpen opens the file named by name read-only, relative to the current
+// working directory, using a raw openat system call, and returns the new
+// file descriptor. The call is retried for as long as it fails with EINTR.
+//
+// name must be NUL-terminated; otherwise sysOpen returns EINVAL without
+// making the system call. Failures are returned as a bare syscall.Errno.
+func sysOpen(name []byte) (fd int, err error) {
+	if len(name) == 0 || name[len(name)-1] != 0 {
+		return 0, syscall.EINVAL
+	}
+	// O_CLOEXEC keeps the descriptor from being inherited by a child
+	// process that is started while the directory is being walked.
+	const flags = unix.O_RDONLY | unix.O_CLOEXEC
+	// AT_FDCWD is negative, so it has to go through a variable before it
+	// can be converted to uintptr.
+	var dirfd int = unix.AT_FDCWD
+	for {
+		r0, _, e1 := syscall.Syscall(unix.SYS_OPENAT, uintptr(dirfd),
+			uintptr(unsafe.Pointer(&name[0])), flags)
+		if e1 == 0 {
+			return int(r0), nil
+		}
+		if e1 == syscall.EINTR {
+			// Since https://golang.org/doc/go1.14#runtime we
+			// need to loop on EINTR on more places.
+			continue
+		}
+		return 0, syscall.Errno(e1)
+	}
+}
+
+// readDirent calls syscall.ReadDirent on fd, retrying for as long as the
+// call is interrupted (EINTR), and returns the first other outcome.
+func readDirent(fd int, buf []byte) (n int, err error) {
+	for {
+		n, err = syscall.ReadDirent(fd, buf)
+		if err == syscall.EINTR {
+			continue
+		}
+		return n, err
+	}
+}

+ 93 - 0
util/dirwalk/dirwalk_test.go

@@ -0,0 +1,93 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dirwalk
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+	"reflect"
+	"runtime"
+	"sort"
+	"testing"
+
+	"go4.org/mem"
+)
+
+func TestWalkShallowOSSpecific(t *testing.T) {
+	if osWalkShallow == nil { // no implementation was registered for this OS
+		t.Skip("no OS-specific implementation")
+	}
+	testWalkShallow(t, false) // portable=false leaves the registered implementation in place
+}
+
+func TestWalkShallowPortable(t *testing.T) {
+	testWalkShallow(t, true) // portable=true makes testWalkShallow clear osWalkShallow for the run
+}
+
+// testWalkShallow exercises WalkShallow against a fresh temporary directory.
+// When portable is true, osWalkShallow is cleared for the duration of the
+// call so that the generic fallback inside WalkShallow is what gets tested.
+func testWalkShallow(t *testing.T, portable bool) {
+	if portable {
+		saved := osWalkShallow
+		osWalkShallow = nil
+		defer func() { osWalkShallow = saved }()
+	}
+	dir := t.TempDir()
+
+	t.Run("basics", func(t *testing.T) {
+		// Two regular files of known size plus one subdirectory.
+		files := []struct {
+			name string
+			data string
+			perm os.FileMode
+		}{
+			{"foo", "1", 0600},
+			{"bar", "22", 0400},
+		}
+		for _, f := range files {
+			if err := os.WriteFile(filepath.Join(dir, f.name), []byte(f.data), f.perm); err != nil {
+				t.Fatal(err)
+			}
+		}
+		if err := os.Mkdir(filepath.Join(dir, "baz"), 0777); err != nil {
+			t.Fatal(err)
+		}
+
+		var got []string
+		visit := func(name mem.RO, de os.DirEntry) error {
+			var size int64
+			fi, err := de.Info()
+			switch {
+			case err != nil:
+				t.Errorf("Info stat error on %q: %v", de.Name(), err)
+			case !fi.IsDir():
+				size = fi.Size()
+			}
+			got = append(got, fmt.Sprintf("%q %q dir=%v type=%d size=%v", name.StringCopy(), de.Name(), de.IsDir(), de.Type(), size))
+			return nil
+		}
+		if err := WalkShallow(mem.S(dir), visit); err != nil {
+			t.Fatal(err)
+		}
+		// Directory order is unspecified, so compare in sorted order.
+		sort.Strings(got)
+		want := []string{
+			`"bar" "bar" dir=false type=0 size=2`,
+			`"baz" "baz" dir=true type=2147483648 size=0`,
+			`"foo" "foo" dir=false type=0 size=1`,
+		}
+		if !reflect.DeepEqual(got, want) {
+			t.Errorf("mismatch:\n got %#q\nwant %#q", got, want)
+		}
+	})
+
+	t.Run("err_not_exist", func(t *testing.T) {
+		missing := filepath.Join(dir, "not_exist")
+		err := WalkShallow(mem.S(missing), func(name mem.RO, de os.DirEntry) error {
+			return nil
+		})
+		if !os.IsNotExist(err) {
+			t.Errorf("unexpected error: %v", err)
+		}
+	})
+
+	t.Run("allocs", func(t *testing.T) {
+		walkOnce := func() {
+			if err := WalkShallow(mem.S(dir), func(name mem.RO, de os.DirEntry) error { return nil }); err != nil {
+				t.Fatal(err)
+			}
+		}
+		allocs := int(testing.AllocsPerRun(1000, walkOnce))
+		t.Logf("allocs = %v", allocs)
+		// Only the Linux-specific implementation is required to be alloc-free.
+		if !portable && runtime.GOOS == "linux" && allocs != 0 {
+			t.Errorf("unexpected allocs: got %v, want 0", allocs)
+		}
+	})
+}