Forráskód Böngészése

net/netutil: only check Linux sysctls w/ procfs, assume absent means false

Fixes #7217

Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 3 éve
szülő
commit
2477fc4952
2 módosított fájl, 27 hozzáadás és 13 törlés
  1. 13 13
      net/netutil/ip_forward.go
  2. 14 0
      net/netutil/netutil_test.go

+ 13 - 13
net/netutil/ip_forward.go

@@ -9,7 +9,6 @@ import (
 	"fmt"
 	"net/netip"
 	"os"
-	"os/exec"
 	"path/filepath"
 	"runtime"
 	"strconv"
@@ -194,27 +193,28 @@ const (
 // given interface.
 // The iface param determines which interface to check against, "" means to check
 // global config.
-// It tries to lookup the value directly from `/proc/sys`, and falls back to
-// using `sysctl` on failure.
+// This is Linux-specific: it only reads from /proc/sys and doesn't shell out to
+// sysctl (which on Linux just reads from /proc/sys anyway).
 func ipForwardingEnabledLinux(p protocol, iface string) (bool, error) {
 	k := ipForwardSysctlKey(slashFormat, p, iface)
 	bs, err := os.ReadFile(filepath.Join("/proc/sys", k))
 	if err != nil {
-		// Fallback to using sysctl.
-		// Sysctl accepts `/` as separator.
-		bs, err = exec.Command("sysctl", "-n", k).Output()
-		if err != nil {
-			// But in case it doesn't.
-			k := ipForwardSysctlKey(dotFormat, p, iface)
-			bs, err = exec.Command("sysctl", "-n", k).Output()
-			if err != nil {
-				return false, fmt.Errorf("couldn't check %s (%v)", k, err)
+		if os.IsNotExist(err) {
+			// If IPv6 is disabled, sysctl keys like "net.ipv6.conf.all.forwarding" just don't
+			// exist on disk. But first diagnose whether procfs is even mounted before assuming
+			// absence means false.
+			if fi, err := os.Stat("/proc/sys"); err != nil {
+				return false, fmt.Errorf("failed to check sysctl %v; no procfs? %w", k, err)
+			} else if !fi.IsDir() {
+				return false, fmt.Errorf("failed to check sysctl %v; /proc/sys isn't a directory, is %v", k, fi.Mode())
 			}
+			return false, nil
 		}
+		return false, err
 	}
 	on, err := strconv.ParseBool(string(bytes.TrimSpace(bs)))
 	if err != nil {
-		return false, fmt.Errorf("couldn't parse %s (%v)", k, err)
+		return false, fmt.Errorf("couldn't parse %s: %w", k, err)
 	}
 	return on, nil
 }

+ 14 - 0
net/netutil/netutil_test.go

@@ -6,6 +6,7 @@ package netutil
 import (
 	"io"
 	"net"
+	"runtime"
 	"testing"
 )
 
@@ -51,3 +52,16 @@ func TestOneConnListener(t *testing.T) {
 		t.Errorf("nil Addr")
 	}
 }
+
+func TestIPForwardingEnabledLinux(t *testing.T) {
+	if runtime.GOOS != "linux" {
+		t.Skipf("skipping on %s", runtime.GOOS)
+	}
+	got, err := ipForwardingEnabledLinux(ipv4, "some-not-found-interface")
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got {
+		t.Errorf("got true; want false")
+	}
+}