Browse Source

Support for multi proto tun device on NetBSD (#1492)

Nate Brown 1 week ago
parent
commit
eb89839d13
3 changed files with 406 additions and 87 deletions
  1. 24 0
      overlay/tun.go
  2. 0 24
      overlay/tun_freebsd.go
  3. 382 63
      overlay/tun_netbsd.go

+ 24 - 0
overlay/tun.go

@@ -81,3 +81,27 @@ func prefixToMask(prefix netip.Prefix) netip.Addr {
 	addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
 	return addr
 }
+
+// flipBytes inverts every bit of every byte in b. The slice is modified in
+// place and the same slice is returned for convenient chaining.
+func flipBytes(b []byte) []byte {
+	for i := range b {
+		b[i] = ^b[i]
+	}
+	return b
+}
+// orBytes returns a freshly allocated slice, len(a) long, holding the
+// byte-wise OR of a and b. Neither input is modified. b is indexed at every
+// position of a, so it must be at least as long as a.
+func orBytes(a []byte, b []byte) []byte {
+	merged := make([]byte, len(a))
+	for i, v := range a {
+		merged[i] = v | b[i]
+	}
+	return merged
+}
+
+// getBroadcast returns cidr's address with every bit outside the prefix mask
+// set to one, computed as address OR (NOT mask).
+func getBroadcast(cidr netip.Prefix) netip.Addr {
+	addrBytes := cidr.Addr().AsSlice()
+	hostBits := flipBytes(prefixToMask(cidr).AsSlice())
+
+	// The ok result is intentionally ignored; when the byte slice is not a
+	// valid address length the zero netip.Addr is returned.
+	broadcast, _ := netip.AddrFromSlice(orBytes(addrBytes, hostBits))
+	return broadcast
+}

+ 0 - 24
overlay/tun_freebsd.go

@@ -501,30 +501,6 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	return
 }
 
-func flipBytes(b []byte) []byte {
-	for i := 0; i < len(b); i++ {
-		b[i] ^= 0xFF
-	}
-	return b
-}
-func orBytes(a []byte, b []byte) []byte {
-	ret := make([]byte, len(a))
-	for i := 0; i < len(a); i++ {
-		ret[i] = a[i] | b[i]
-	}
-	return ret
-}
-
-func getBroadcast(cidr netip.Prefix) netip.Addr {
-	broadcast, _ := netip.AddrFromSlice(
-		orBytes(
-			cidr.Addr().AsSlice(),
-			flipBytes(prefixToMask(cidr).AsSlice()),
-		),
-	)
-	return broadcast
-}
-
 func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
 	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
 	if err != nil {

+ 382 - 63
overlay/tun_netbsd.go

@@ -4,13 +4,12 @@
 package overlay
 
 import (
+	"errors"
 	"fmt"
 	"io"
 	"net/netip"
 	"os"
-	"os/exec"
 	"regexp"
-	"strconv"
 	"sync/atomic"
 	"syscall"
 	"unsafe"
@@ -20,11 +19,42 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
+	netroute "golang.org/x/net/route"
+	"golang.org/x/sys/unix"
 )
 
-type ifreqDestroy struct {
-	Name [16]byte
-	pad  [16]byte
+// ioctl request codes used by this file and declared locally.
+const (
+	// SIOCAIFADDR_IN6 is the request addIp issues with an ifreqAlias6 to
+	// assign an IPv6 address to the device.
+	SIOCAIFADDR_IN6 = 0x8080696b
+	// TUNSIFHEAD is issued on the tun fd in Activate with the value 1. Read
+	// and Write then exchange a 4-byte protocol family header with every packet.
+	TUNSIFHEAD      = 0x80047442
+	// TUNSIFMODE is issued on the tun fd in Activate with IFF_BROADCAST.
+	TUNSIFMODE      = 0x80047458
+)
+
+// ifreqAlias4 is the request addIp fills in and passes by pointer to the
+// SIOCAIFADDR ioctl to assign an IPv4 address to the tun device.
+// NOTE(review): presumably mirrors the kernel's IPv4 alias request layout, so
+// field order and sizes must not change — confirm against the NetBSD headers.
+type ifreqAlias4 struct {
+	Name     [unix.IFNAMSIZ]byte    // interface name, from deviceBytes
+	Addr     unix.RawSockaddrInet4 // address being assigned
+	DstAddr  unix.RawSockaddrInet4 // addIp sets this to the same address as Addr
+	MaskAddr unix.RawSockaddrInet4 // netmask derived from the prefix length
+}
+
+// ifreqAlias6 is the request addIp fills in and passes by pointer to the
+// SIOCAIFADDR_IN6 ioctl to assign an IPv6 address to the tun device.
+// addIp populates Name, Addr, PrefixMask and Lifetime; DstAddr and Flags are
+// left at their zero values.
+// NOTE(review): presumably mirrors the kernel's IPv6 alias request layout, so
+// field order and sizes must not change — confirm against the NetBSD headers.
+type ifreqAlias6 struct {
+	Name       [unix.IFNAMSIZ]byte
+	Addr       unix.RawSockaddrInet6
+	DstAddr    unix.RawSockaddrInet6
+	PrefixMask unix.RawSockaddrInet6
+	Flags      uint32
+	Lifetime   addrLifetime
+}
+
+// ifreq is the generic by-name interface request used in this file: Close
+// passes it to SIOCIFDESTROY with only Name set, and doIoctlByName passes it
+// to a caller-chosen ioctl with an integer argument in data.
+// NOTE(review): this struct is 16 bytes of name plus one Go int; confirm it is
+// at least as large as the structure the kernel copies in for those requests.
+type ifreq struct {
+	Name [unix.IFNAMSIZ]byte
+	data int
+}
+
+// addrLifetime is the Lifetime member of ifreqAlias6. addIp sets Vltime and
+// Pltime to 0xffffffff and leaves Expire and Preferred at zero.
+// NOTE(review): 0xffffffff presumably means "never expires" — confirm.
+type addrLifetime struct {
+	Expire    uint64
+	Preferred uint64
+	Vltime    uint32
+	Pltime    uint32
 }
 
 type tun struct {
@@ -34,40 +64,18 @@ type tun struct {
 	Routes      atomic.Pointer[[]Route]
 	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	l           *logrus.Logger
-
-	io.ReadWriteCloser
+	f           *os.File
+	fd          int
 }
 
-func (t *tun) Close() error {
-	if t.ReadWriteCloser != nil {
-		if err := t.ReadWriteCloser.Close(); err != nil {
-			return err
-		}
-
-		s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
-		if err != nil {
-			return err
-		}
-		defer syscall.Close(s)
-
-		ifreq := ifreqDestroy{Name: t.deviceBytes()}
-
-		err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
-
-		return err
-	}
-	return nil
-}
+var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
 func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
 }
 
-var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
-
 func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	// Try to open tun device
-	var file *os.File
 	var err error
 	deviceName := c.GetString("tun.dev", "")
 	if deviceName == "" {
@@ -77,17 +85,23 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
 		return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
 	}
 
-	file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
+	fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
 	if err != nil {
 		return nil, err
 	}
 
+	err = unix.SetNonblock(fd, true)
+	if err != nil {
+		l.WithError(err).Warn("Failed to set the tun device as nonblocking")
+	}
+
 	t := &tun{
-		ReadWriteCloser: file,
-		Device:          deviceName,
-		vpnNetworks:     vpnNetworks,
-		MTU:             c.GetInt("tun.mtu", DefaultMTU),
-		l:               l,
+		f:           os.NewFile(uintptr(fd), ""),
+		fd:          fd,
+		Device:      deviceName,
+		vpnNetworks: vpnNetworks,
+		MTU:         c.GetInt("tun.mtu", DefaultMTU),
+		l:           l,
 	}
 
 	err = t.reload(c, true)
@@ -105,40 +119,225 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
 	return t, nil
 }
 
-func (t *tun) addIp(cidr netip.Prefix) error {
-	var err error
+// Close closes the tun file and then asks the kernel to destroy the
+// interface. It is a no-op returning nil when t.f is nil. If closing the file
+// fails, that error is returned and the interface is not destroyed.
+func (t *tun) Close() error {
+	if t.f != nil {
+		if err := t.f.Close(); err != nil {
+			return fmt.Errorf("error closing tun file: %w", err)
+		}

-	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
-	t.l.Debug("command: ", cmd.String())
-	if err = cmd.Run(); err != nil {
-		return fmt.Errorf("failed to run 'ifconfig': %s", err)
+		// t.f.Close should have handled it for us but let's be extra sure
+		// NOTE(review): t.f wraps this same descriptor number. If t.f.Close
+		// already released it, this second close targets a number that another
+		// goroutine may have been handed in the meantime — confirm this is safe.
+		_ = unix.Close(t.fd)
+
+		// A throwaway datagram socket is only a handle for the by-name ioctl.
+		s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
+		if err != nil {
+			return err
+		}
+		defer syscall.Close(s)
+
+		// Destroy the interface identified by the device name.
+		ifr := ifreq{Name: t.deviceBytes()}
+		err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifr)))
+		return err
 	}
+	return nil
+}
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
-	t.l.Debug("command: ", cmd.String())
-	if err = cmd.Run(); err != nil {
-		return fmt.Errorf("failed to run 'route add': %s", err)
+// Read reads one packet from the tun device into to and returns the number of
+// payload bytes stored. The device prefixes every packet with a 4-byte
+// protocol family header; it is read into a scratch buffer via readv and
+// excluded from the returned count.
+//
+// An empty to returns (0, nil) without touching the device. A read on a closed
+// device returns os.ErrClosed. A read shorter than the header returns (0, nil).
+func (t *tun) Read(to []byte) (int, error) {
+	if len(to) == 0 {
+		// Nothing can be stored, and taking &to[0] below would panic.
+		return 0, nil
+	}
+
+	rc, err := t.f.SyscallConn()
+	if err != nil {
+		return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
 	}

-	cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
-	t.l.Debug("command: ", cmd.String())
-	if err = cmd.Run(); err != nil {
-		return fmt.Errorf("failed to run 'ifconfig': %s", err)
+	var errno syscall.Errno
+	var n uintptr
+	err = rc.Read(func(fd uintptr) bool {
+		// first 4 bytes is protocol family, in network byte order
+		head := [4]byte{}
+		iovecs := []syscall.Iovec{
+			{&head[0], 4},
+			{&to[0], uint64(len(to))},
+		}
+
+		n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
+		if errno.Temporary() {
+			// We got an EAGAIN, EINTR, or EWOULDBLOCK, go again
+			return false
+		}
+		return true
+	})
+	if err != nil {
+		if err == syscall.EBADF || err.Error() == "use of closed file" {
+			// Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are
+			// https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121
+			return 0, os.ErrClosed
+		}
+		return 0, fmt.Errorf("failed to make read call for tun: %w", err)
 	}

-	// Unsafe path routes
-	return t.addRoutes(false)
+	if errno != 0 {
+		return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno)
+	}
+
+	// errno is zero here, so n is a valid byte count that includes the 4-byte
+	// header; report payload bytes only.
+	if n < 4 {
+		return 0, nil
+	}
+	return int(n) - 4, nil
+}
+
+// Write is only valid for single threaded use
+//
+// Write sends one IP packet to the tun device. The IP version is taken from
+// the high nibble of the first byte and used to build the 4-byte protocol
+// family header the device expects; header and packet are submitted together
+// with writev. The returned count excludes the header.
+//
+// Packets of one byte or less return syscall.EIO, and packets that are
+// neither IPv4 nor IPv6 return an error. If the device accepts fewer bytes
+// than the full packet, io.ErrShortWrite is returned with the clamped count.
+func (t *tun) Write(from []byte) (int, error) {
+	if len(from) <= 1 {
+		return 0, syscall.EIO
+	}
+
+	// first 4 bytes is protocol family, in network byte order
+	var head [4]byte
+	switch from[0] >> 4 {
+	case 4:
+		head[3] = syscall.AF_INET
+	case 6:
+		head[3] = syscall.AF_INET6
+	default:
+		return 0, fmt.Errorf("unable to determine IP version from packet")
+	}
+
+	rc, err := t.f.SyscallConn()
+	if err != nil {
+		return 0, err
+	}
+
+	var errno syscall.Errno
+	var n uintptr
+	err = rc.Write(func(fd uintptr) bool {
+		iovecs := []syscall.Iovec{
+			{&head[0], 4},
+			{&from[0], uint64(len(from))},
+		}
+
+		n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
+		// According to NetBSD documentation for TUN, writes will only return errors in which
+		// this packet will never be delivered so just go on living life.
+		return true
+	})
+	if err != nil {
+		return 0, err
+	}
+
+	if errno != 0 {
+		return 0, errno
+	}
+
+	// n includes the 4-byte header. Never report a negative count, and flag a
+	// short write as the io.Writer contract requires.
+	written := int(n) - 4
+	if written < 0 {
+		return 0, io.ErrShortWrite
+	}
+	if written < len(from) {
+		return written, io.ErrShortWrite
+	}
+	return written, nil
+}
+
+// addIp assigns cidr's address and prefix mask to the tun device.
+//
+// IPv4 addresses are installed with the SIOCAIFADDR ioctl and IPv6 addresses
+// with SIOCAIFADDR_IN6, each issued on a short-lived datagram socket of the
+// matching family. An address that is neither IPv4 nor IPv6 is an error.
+// Underlying errors are wrapped so callers can inspect them with errors.Is.
+func (t *tun) addIp(cidr netip.Prefix) error {
+	if cidr.Addr().Is4() {
+		var req ifreqAlias4
+		req.Name = t.deviceBytes()
+		req.Addr = unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   cidr.Addr().As4(),
+		}
+		// The destination slot carries the interface's own address.
+		req.DstAddr = unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   cidr.Addr().As4(),
+		}
+		req.MaskAddr = unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   prefixToMask(cidr).As4(),
+		}
+
+		s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
+		if err != nil {
+			return fmt.Errorf("failed to open socket to set tun address %s: %w", cidr.Addr(), err)
+		}
+		defer unix.Close(s)
+
+		if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
+			return fmt.Errorf("failed to set tun address %s: %w", cidr.Addr(), err)
+		}
+
+		return nil
+	}
+
+	if cidr.Addr().Is6() {
+		var req ifreqAlias6
+		req.Name = t.deviceBytes()
+		req.Addr = unix.RawSockaddrInet6{
+			Len:    unix.SizeofSockaddrInet6,
+			Family: unix.AF_INET6,
+			Addr:   cidr.Addr().As16(),
+		}
+		req.PrefixMask = unix.RawSockaddrInet6{
+			Len:    unix.SizeofSockaddrInet6,
+			Family: unix.AF_INET6,
+			Addr:   prefixToMask(cidr).As16(),
+		}
+		// Both lifetimes are set to the maximum 32-bit value.
+		req.Lifetime = addrLifetime{
+			Vltime: 0xffffffff,
+			Pltime: 0xffffffff,
+		}
+
+		s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
+		if err != nil {
+			return fmt.Errorf("failed to open socket to set tun address %s: %w", cidr.Addr(), err)
+		}
+		defer unix.Close(s)
+
+		if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
+			return fmt.Errorf("failed to set tun address %s: %w", cidr.Addr(), err)
+		}
+		return nil
+	}
+
+	return fmt.Errorf("unknown address type %v", cidr)
 }
 
 func (t *tun) Activate() error {
+	mode := int32(unix.IFF_BROADCAST)
+	err := ioctl(uintptr(t.fd), TUNSIFMODE, uintptr(unsafe.Pointer(&mode)))
+	if err != nil {
+		return fmt.Errorf("failed to set tun device mode: %w", err)
+	}
+
+	v := 1
+	err = ioctl(uintptr(t.fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&v)))
+	if err != nil {
+		return fmt.Errorf("failed to set tun device head: %w", err)
+	}
+
+	err = t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
+	if err != nil {
+		return fmt.Errorf("failed to set tun mtu: %w", err)
+	}
+
 	for i := range t.vpnNetworks {
-		err := t.addIp(t.vpnNetworks[i])
+		err = t.addIp(t.vpnNetworks[i])
 		if err != nil {
 			return err
 		}
 	}
-	return nil
+
+	return t.addRoutes(false)
+}
+
+// doIoctlByName issues the ioctl ctl against this device by name, passing
+// value as the request's integer argument. The request is sent over a
+// short-lived AF_INET datagram socket that is closed before returning.
+func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
+	sock, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
+	if err != nil {
+		return err
+	}
+	defer syscall.Close(sock)
+
+	req := ifreq{
+		Name: t.deviceBytes(),
+		data: int(value),
+	}
+	return ioctl(uintptr(sock), ctl, uintptr(unsafe.Pointer(&req)))
 }
 
 func (t *tun) reload(c *config.C, initial bool) error {
@@ -197,21 +396,23 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 
 func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
+
 	for _, r := range routes {
 		if len(r.Via) == 0 || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
 
-		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
-		t.l.Debug("command: ", cmd.String())
-		if err := cmd.Run(); err != nil {
-			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
+		err := addRoute(r.Cidr, t.vpnNetworks)
+		if err != nil {
+			retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
 			if logErrors {
 				retErr.Log(t.l)
 			} else {
 				return retErr
 			}
+		} else {
+			t.l.WithField("route", r).Info("Added route")
 		}
 	}
 
@@ -224,10 +425,8 @@ func (t *tun) removeRoutes(routes []Route) error {
 			continue
 		}
 
-		//TODO: CERT-V2 is this right?
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
-		t.l.Debug("command: ", cmd.String())
-		if err := cmd.Run(); err != nil {
+		err := delRoute(r.Cidr, t.vpnNetworks)
+		if err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 		} else {
 			t.l.WithField("route", r).Info("Removed route")
@@ -242,3 +441,123 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	}
 	return
 }
+
+// addRoute installs a gateway route for prefix through the routing socket.
+//
+// The gateway is the first entry of gateways whose address family matches
+// prefix (see selectGateway); an error is returned when none matches. If the
+// kernel reports that the route already exists, the same message is re-sent
+// as an RTM_CHANGE so the existing route is updated instead.
+// Underlying errors are wrapped so callers can inspect them with errors.Is.
+func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %w", err)
+	}
+	defer unix.Close(sock)
+
+	route := &netroute.RouteMessage{
+		Version: unix.RTM_VERSION,
+		Type:    unix.RTM_ADD,
+		Flags:   unix.RTF_UP | unix.RTF_GATEWAY,
+		Seq:     1,
+	}
+
+	gw, err := selectGateway(prefix, gateways)
+	if err != nil {
+		return err
+	}
+
+	// Destination, netmask and gateway must all use the same address family.
+	if prefix.Addr().Is4() {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
+		}
+	} else {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
+		}
+	}
+
+	data, err := route.Marshal()
+	if err != nil {
+		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
+	}
+
+	if _, err = unix.Write(sock, data); err == nil {
+		return nil
+	}
+	if !errors.Is(err, unix.EEXIST) {
+		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
+	}
+
+	// The route already exists: try to do a change instead.
+	route.Type = unix.RTM_CHANGE
+	data, err = route.Marshal()
+	if err != nil {
+		return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
+	}
+	if _, err = unix.Write(sock, data); err != nil {
+		return fmt.Errorf("failed to write route.RouteMessage change to socket: %w", err)
+	}
+
+	return nil
+}
+
+// delRoute removes the gateway route for prefix through the routing socket.
+//
+// The gateway is the first entry of gateways whose address family matches
+// prefix (see selectGateway); an error is returned when none matches.
+// Underlying errors are wrapped so callers can inspect them with errors.Is.
+func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %w", err)
+	}
+	defer unix.Close(sock)
+
+	route := netroute.RouteMessage{
+		Version: unix.RTM_VERSION,
+		Type:    unix.RTM_DELETE,
+		Seq:     1,
+	}
+
+	gw, err := selectGateway(prefix, gateways)
+	if err != nil {
+		return err
+	}
+
+	// Destination, netmask and gateway must all use the same address family.
+	if prefix.Addr().Is4() {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
+		}
+	} else {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
+		}
+	}
+
+	data, err := route.Marshal()
+	if err != nil {
+		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
+	}
+	if _, err = unix.Write(sock, data); err != nil {
+		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
+	}
+
+	return nil
+}
+
+// selectGateway returns the first prefix in gateways whose address is of the
+// same IP family (IPv4 or IPv6) as dest. When no entry matches, the zero
+// Prefix and an error naming dest are returned.
+func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
+	for _, candidate := range gateways {
+		switch {
+		case dest.Addr().Is4() && candidate.Addr().Is4(),
+			dest.Addr().Is6() && candidate.Addr().Is6():
+			return candidate, nil
+		}
+	}
+
+	return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
+}