Răsfoiți Sursa

net/netns: add Windows support for bind-to-interface-by-route

This is implemented via GetBestInterfaceEx. Should we encounter errors
or fail to resolve a valid, non-Tailscale interface, we fall back to
returning the index for the default interface instead.

Fixes #12551

Signed-off-by: Aaron Klotz <[email protected]>
Aaron Klotz 1 an în urmă
părinte
comite
7dd76c3411

+ 2 - 2
cmd/derper/depaware.txt

@@ -99,7 +99,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
         tailscale.com/net/netaddr                                    from tailscale.com/ipn+
         tailscale.com/net/netknob                                    from tailscale.com/net/netns
      💣 tailscale.com/net/netmon                                     from tailscale.com/derp/derphttp+
-        tailscale.com/net/netns                                      from tailscale.com/derp/derphttp
+     💣 tailscale.com/net/netns                                      from tailscale.com/derp/derphttp
         tailscale.com/net/netutil                                    from tailscale.com/client/tailscale
         tailscale.com/net/sockstats                                  from tailscale.com/derp/derphttp
         tailscale.com/net/stun                                       from tailscale.com/net/stunserver
@@ -114,7 +114,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
         tailscale.com/syncs                                          from tailscale.com/cmd/derper+
         tailscale.com/tailcfg                                        from tailscale.com/client/tailscale+
         tailscale.com/tka                                            from tailscale.com/client/tailscale+
-   W    tailscale.com/tsconst                                        from tailscale.com/net/netmon
+   W    tailscale.com/tsconst                                        from tailscale.com/net/netmon+
         tailscale.com/tstime                                         from tailscale.com/derp+
         tailscale.com/tstime/mono                                    from tailscale.com/tstime/rate
         tailscale.com/tstime/rate                                    from tailscale.com/derp

+ 2 - 2
cmd/tailscale/depaware.txt

@@ -103,7 +103,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
         tailscale.com/net/neterror                                   from tailscale.com/net/netcheck+
         tailscale.com/net/netknob                                    from tailscale.com/net/netns
      💣 tailscale.com/net/netmon                                     from tailscale.com/cmd/tailscale/cli+
-        tailscale.com/net/netns                                      from tailscale.com/derp/derphttp+
+     💣 tailscale.com/net/netns                                      from tailscale.com/derp/derphttp+
         tailscale.com/net/netutil                                    from tailscale.com/client/tailscale+
         tailscale.com/net/packet                                     from tailscale.com/wgengine/capture
         tailscale.com/net/ping                                       from tailscale.com/net/netcheck
@@ -121,7 +121,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
         tailscale.com/tailcfg                                        from tailscale.com/client/tailscale+
         tailscale.com/tempfork/spf13/cobra                           from tailscale.com/cmd/tailscale/cli/ffcomplete+
         tailscale.com/tka                                            from tailscale.com/client/tailscale+
-   W    tailscale.com/tsconst                                        from tailscale.com/net/netmon
+   W    tailscale.com/tsconst                                        from tailscale.com/net/netmon+
         tailscale.com/tstime                                         from tailscale.com/control/controlhttp+
         tailscale.com/tstime/mono                                    from tailscale.com/tstime/rate
         tailscale.com/tstime/rate                                    from tailscale.com/cmd/tailscale/cli+

+ 2 - 2
cmd/tailscaled/depaware.txt

@@ -303,7 +303,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
         tailscale.com/net/netkernelconf                              from tailscale.com/ipn/ipnlocal
         tailscale.com/net/netknob                                    from tailscale.com/logpolicy+
      💣 tailscale.com/net/netmon                                     from tailscale.com/cmd/tailscaled+
-        tailscale.com/net/netns                                      from tailscale.com/cmd/tailscaled+
+     💣 tailscale.com/net/netns                                      from tailscale.com/cmd/tailscaled+
    W 💣 tailscale.com/net/netstat                                    from tailscale.com/portlist
         tailscale.com/net/netutil                                    from tailscale.com/client/tailscale+
         tailscale.com/net/packet                                     from tailscale.com/net/connstats+
@@ -335,7 +335,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
   LD    tailscale.com/tempfork/gliderlabs/ssh                        from tailscale.com/ssh/tailssh
         tailscale.com/tempfork/heap                                  from tailscale.com/wgengine/magicsock
         tailscale.com/tka                                            from tailscale.com/client/tailscale+
-   W    tailscale.com/tsconst                                        from tailscale.com/net/netmon
+   W    tailscale.com/tsconst                                        from tailscale.com/net/netmon+
         tailscale.com/tsd                                            from tailscale.com/cmd/tailscaled+
         tailscale.com/tstime                                         from tailscale.com/control/controlclient+
         tailscale.com/tstime/mono                                    from tailscale.com/net/tstun+

+ 9 - 0
net/netns/mksyscall.go

@@ -0,0 +1,9 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package netns
+
+//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
+//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go
+
+//sys getBestInterfaceEx(sockaddr *winipcfg.RawSockaddrInet, bestIfaceIndex *uint32) (ret error) = iphlpapi.GetBestInterfaceEx

+ 1 - 1
net/netns/netns.go

@@ -38,7 +38,7 @@ var bindToInterfaceByRoute atomic.Bool
 // route information to bind to a particular interface. It is the same as
 // setting the TS_BIND_TO_INTERFACE_BY_ROUTE.
 //
-// Currently, this only changes the behaviour on macOS.
+// Currently, this only changes the behaviour on macOS and Windows.
 func SetBindToInterfaceByRoute(v bool) {
 	bindToInterfaceByRoute.Store(v)
 }

+ 2 - 8
net/netns/netns_darwin.go

@@ -89,16 +89,10 @@ func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string)
 		return defaultIdx()
 	}
 
-	host, _, err := net.SplitHostPort(address)
-	if err != nil {
-		// No port number; use the string directly.
-		host = address
-	}
-
 	// If the address doesn't parse, use the default index.
-	addr, err := netip.ParseAddr(host)
+	addr, err := parseAddress(address)
 	if err != nil {
-		logf("[unexpected] netns: error parsing address %q: %v", host, err)
+		logf("[unexpected] netns: error parsing address %q: %v", address, err)
 		return defaultIdx()
 	}
 

+ 21 - 0
net/netns/netns_dw.go

@@ -0,0 +1,21 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build darwin || windows
+
+package netns
+
+import (
+	"net"
+	"net/netip"
+)
+
+func parseAddress(address string) (addr netip.Addr, err error) {
+	host, _, err := net.SplitHostPort(address)
+	if err != nil {
+		// error means the string didn't contain a port number, so use the string directly
+		host = address
+	}
+
+	return netip.ParseAddr(host)
+}

+ 109 - 13
net/netns/netns_windows.go

@@ -4,14 +4,18 @@
 package netns
 
 import (
+	"fmt"
 	"math/bits"
+	"net/netip"
 	"strings"
 	"syscall"
 
 	"golang.org/x/sys/cpu"
 	"golang.org/x/sys/windows"
 	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+	"tailscale.com/envknob"
 	"tailscale.com/net/netmon"
+	"tailscale.com/tsconst"
 	"tailscale.com/types/logger"
 )
 
@@ -26,20 +30,34 @@ func interfaceIndex(iface *winipcfg.IPAdapterAddresses) uint32 {
 	return iface.IfIndex
 }
 
-func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error {
-	return controlC
+func defaultInterfaceIndex(family winipcfg.AddressFamily) (uint32, error) {
+	iface, err := netmon.GetWindowsDefault(family)
+	if err != nil {
+		return 0, err
+	}
+
+	return interfaceIndex(iface), nil
+}
+
+func control(logf logger.Logf, _ *netmon.Monitor) func(network, address string, c syscall.RawConn) error {
+	return func(network, address string, c syscall.RawConn) error {
+		return controlC(logf, network, address, c)
+	}
 }
 
+var bindToInterfaceByRouteEnv = envknob.RegisterBool("TS_BIND_TO_INTERFACE_BY_ROUTE")
+
 // controlC binds c to the Windows interface that holds a default
 // route, and is not the Tailscale WinTun interface.
-func controlC(network, address string, c syscall.RawConn) error {
-	if strings.HasPrefix(address, "127.") {
+func controlC(logf logger.Logf, network, address string, c syscall.RawConn) (err error) {
+	if isLocalhost(address) {
 		// Don't bind to an interface for localhost connections,
 		// otherwise we get:
 		//   connectex: The requested address is not valid in its context
 		// (The derphttp tests were failing)
 		return nil
 	}
+
 	canV4, canV6 := false, false
 	switch network {
 	case "tcp", "udp":
@@ -50,29 +68,107 @@ func controlC(network, address string, c syscall.RawConn) error {
 		canV6 = true
 	}
 
+	var defIfaceIdxV4, defIfaceIdxV6 uint32
 	if canV4 {
-		iface, err := netmon.GetWindowsDefault(windows.AF_INET)
+		defIfaceIdxV4, err = defaultInterfaceIndex(windows.AF_INET)
 		if err != nil {
-			return err
-		}
-		if err := bindSocket4(c, interfaceIndex(iface)); err != nil {
-			return err
+			return fmt.Errorf("defaultInterfaceIndex(AF_INET): %w", err)
 		}
 	}
 
 	if canV6 {
-		iface, err := netmon.GetWindowsDefault(windows.AF_INET6)
+		defIfaceIdxV6, err = defaultInterfaceIndex(windows.AF_INET6)
 		if err != nil {
-			return err
+			return fmt.Errorf("defaultInterfaceIndex(AF_INET6): %w", err)
 		}
-		if err := bindSocket6(c, interfaceIndex(iface)); err != nil {
-			return err
+	}
+
+	var ifaceIdxV4, ifaceIdxV6 uint32
+	if useRoute := bindToInterfaceByRoute.Load() || bindToInterfaceByRouteEnv(); useRoute {
+		addr, err := parseAddress(address)
+		if err != nil {
+			return fmt.Errorf("parseAddress: %w", err)
+		}
+
+		if canV4 && (addr.Is4() || addr.Is4In6()) {
+			addrV4 := addr.Unmap()
+			ifaceIdxV4, err = getInterfaceIndex(logf, addrV4, defIfaceIdxV4)
+			if err != nil {
+				return fmt.Errorf("getInterfaceIndex(%v): %w", addrV4, err)
+			}
+		}
+
+		if canV6 && addr.Is6() {
+			ifaceIdxV6, err = getInterfaceIndex(logf, addr, defIfaceIdxV6)
+			if err != nil {
+				return fmt.Errorf("getInterfaceIndex(%v): %w", addr, err)
+			}
+		}
+	} else {
+		ifaceIdxV4, ifaceIdxV6 = defIfaceIdxV4, defIfaceIdxV6
+	}
+
+	if canV4 {
+		if err := bindSocket4(c, ifaceIdxV4); err != nil {
+			return fmt.Errorf("bindSocket4(%d): %w", ifaceIdxV4, err)
+		}
+	}
+
+	if canV6 {
+		if err := bindSocket6(c, ifaceIdxV6); err != nil {
+			return fmt.Errorf("bindSocket6(%d): %w", ifaceIdxV6, err)
 		}
 	}
 
 	return nil
 }
 
+func getInterfaceIndex(logf logger.Logf, addr netip.Addr, defaultIdx uint32) (idx uint32, err error) {
+	idx, err = interfaceIndexFor(addr)
+	if err != nil {
+		return defaultIdx, fmt.Errorf("interfaceIndexFor: %w", err)
+	}
+
+	isTS, err := isTailscaleInterface(idx)
+	if err != nil {
+		return defaultIdx, fmt.Errorf("isTailscaleInterface: %w", err)
+	}
+	if isTS {
+		return defaultIdx, nil
+	}
+	return idx, nil
+}
+
+func isTailscaleInterface(ifaceIdx uint32) (bool, error) {
+	ifaceLUID, err := winipcfg.LUIDFromIndex(ifaceIdx)
+	if err != nil {
+		return false, err
+	}
+
+	iface, err := ifaceLUID.Interface()
+	if err != nil {
+		return false, err
+	}
+
+	result := iface.Type == winipcfg.IfTypePropVirtual &&
+		strings.Contains(iface.Description(), tsconst.WintunInterfaceDesc)
+	return result, nil
+}
+
+func interfaceIndexFor(addr netip.Addr) (uint32, error) {
+	var sockaddr winipcfg.RawSockaddrInet
+	if err := sockaddr.SetAddr(addr); err != nil {
+		return 0, err
+	}
+
+	var idx uint32
+	if err := getBestInterfaceEx(&sockaddr, &idx); err != nil {
+		return 0, err
+	}
+
+	return idx, nil
+}
+
 // sockoptBoundInterface is the value of IP_UNICAST_IF and IPV6_UNICAST_IF.
 //
 // See https://docs.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options

+ 112 - 0
net/netns/netns_windows_test.go

@@ -0,0 +1,112 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package netns
+
+import (
+	"strings"
+	"testing"
+
+	"golang.org/x/sys/windows"
+	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+	"tailscale.com/tsconst"
+)
+
+func TestGetInterfaceIndex(t *testing.T) {
+	oldVal := bindToInterfaceByRoute.Load()
+	t.Cleanup(func() { bindToInterfaceByRoute.Store(oldVal) })
+	bindToInterfaceByRoute.Store(true)
+
+	defIfaceIdxV4, err := defaultInterfaceIndex(windows.AF_INET)
+	if err != nil {
+		t.Fatalf("defaultInterfaceIndex(AF_INET) failed: %v", err)
+	}
+
+	tests := []struct {
+		name string
+		addr string
+		err  string
+	}{
+		{
+			name: "IP_and_port",
+			addr: "8.8.8.8:53",
+		},
+		{
+			name: "bare_ip",
+			addr: "8.8.8.8",
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			addr, err := parseAddress(tc.addr)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			idx, err := getInterfaceIndex(t.Logf, addr, defIfaceIdxV4)
+			if err != nil {
+				if tc.err == "" {
+					t.Fatalf("got unexpected error: %v", err)
+				}
+				if errstr := err.Error(); errstr != tc.err {
+					t.Errorf("expected error %q, got %q", errstr, tc.err)
+				}
+			} else {
+				t.Logf("getInterfaceIndex(%q) = %d", tc.addr, idx)
+				if tc.err != "" {
+					t.Fatalf("wanted error %q", tc.err)
+				}
+			}
+		})
+	}
+
+	t.Run("NoTailscale", func(t *testing.T) {
+		tsIdx, ok, err := tailscaleInterfaceIndex()
+		if err != nil {
+			t.Fatal(err)
+		}
+		if !ok {
+			t.Skip("no tailscale interface on this machine")
+		}
+
+		defaultIdx, err := defaultInterfaceIndex(windows.AF_INET)
+		if err != nil {
+			t.Fatalf("defaultInterfaceIndex(AF_INET) failed: %v", err)
+		}
+
+		addr, err := parseAddress("100.100.100.100:53")
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		idx, err := getInterfaceIndex(t.Logf, addr, defaultIdx)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		t.Logf("tailscaleIdx=%d defaultIdx=%d idx=%d", tsIdx, defaultIdx, idx)
+
+		if idx == tsIdx {
+			t.Fatalf("got idx=%d; wanted not Tailscale interface", idx)
+		} else if idx != defaultIdx {
+			t.Fatalf("got idx=%d, want %d", idx, defaultIdx)
+		}
+	})
+}
+
+func tailscaleInterfaceIndex() (idx uint32, found bool, err error) {
+	ifs, err := winipcfg.GetAdaptersAddresses(windows.AF_INET, winipcfg.GAAFlagIncludeAllInterfaces)
+	if err != nil {
+		return idx, false, err
+	}
+
+	for _, iface := range ifs {
+		if iface.IfType != winipcfg.IfTypePropVirtual {
+			continue
+		}
+		if strings.Contains(iface.Description(), tsconst.WintunInterfaceDesc) {
+			return iface.IfIndex, true, nil
+		}
+	}
+	return idx, false, nil
+}

+ 53 - 0
net/netns/zsyscall_windows.go

@@ -0,0 +1,53 @@
+// Code generated by 'go generate'; DO NOT EDIT.
+
+package netns
+
+import (
+	"syscall"
+	"unsafe"
+
+	"golang.org/x/sys/windows"
+	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+)
+
+var _ unsafe.Pointer
+
+// Do the interface allocations only once for common
+// Errno values.
+const (
+	errnoERROR_IO_PENDING = 997
+)
+
+var (
+	errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
+	errERROR_EINVAL     error = syscall.EINVAL
+)
+
+// errnoErr returns common boxed Errno values, to prevent
+// allocations at runtime.
+func errnoErr(e syscall.Errno) error {
+	switch e {
+	case 0:
+		return errERROR_EINVAL
+	case errnoERROR_IO_PENDING:
+		return errERROR_IO_PENDING
+	}
+	// TODO: add more here, after collecting data on the common
+	// error values see on Windows. (perhaps when running
+	// all.bat?)
+	return e
+}
+
+var (
+	modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
+
+	procGetBestInterfaceEx = modiphlpapi.NewProc("GetBestInterfaceEx")
+)
+
+func getBestInterfaceEx(sockaddr *winipcfg.RawSockaddrInet, bestIfaceIndex *uint32) (ret error) {
+	r0, _, _ := syscall.Syscall(procGetBestInterfaceEx.Addr(), 2, uintptr(unsafe.Pointer(sockaddr)), uintptr(unsafe.Pointer(bestIfaceIndex)), 0)
+	if r0 != 0 {
+		ret = syscall.Errno(r0)
+	}
+	return
+}