Nate Brown 3 mēneši atpakaļ
vecāks
revīzija
4870bb680d
4 mainītis faili ar 177 papildinājumiem un 11 dzēšanām
  1. 5 0
      udp/errors.go
  2. 169 9
      udp/udp_darwin.go
  3. 2 1
      udp/udp_generic.go
  4. 1 1
      udp/udp_linux.go

+ 5 - 0
udp/errors.go

@@ -0,0 +1,5 @@
+package udp
+
+import "errors"
+
+var ErrInvalidIPv6RemoteForSocket = errors.New("listener is IPv4, but writing to IPv6 remote")

+ 169 - 9
udp/udp_darwin.go

@@ -6,17 +6,63 @@ package udp
 // Darwin support is primarily implemented in udp_generic, besides NewListenConfig
 
 import (
+	"context"
+	"encoding/binary"
+	"errors"
 	"fmt"
 	"net"
 	"net/netip"
 	"syscall"
+	"unsafe"
 
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
 	"golang.org/x/sys/unix"
 )
 
+type StdConn struct {
+	*net.UDPConn
+	isV4  bool
+	sysFd uintptr
+	l     *logrus.Logger
+}
+
+var _ Conn = &StdConn{}
+
 func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
-	return NewGenericListener(l, ip, port, multi, batch)
+	lc := NewListenConfig(multi)
+	pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
+	if err != nil {
+		return nil, err
+	}
+
+	if uc, ok := pc.(*net.UDPConn); ok {
+		c := &StdConn{UDPConn: uc, l: l}
+
+		rc, err := uc.SyscallConn()
+		if err != nil {
+			return nil, fmt.Errorf("failed to open udp socket: %w", err)
+		}
+
+		err = rc.Control(func(fd uintptr) {
+			c.sysFd = fd
+		})
+		if err != nil {
+			return nil, fmt.Errorf("failed to get udp fd: %w", err)
+		}
+
+		la, err := c.LocalAddr()
+		if err != nil {
+			return nil, err
+		}
+		c.isV4 = la.Addr().Is4()
+
+		return c, nil
+	}
+
+	return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc)
 }
 
 func NewListenConfig(multi bool) net.ListenConfig {
@@ -43,16 +89,130 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 }
 
-func (u *GenericConn) Rebind() error {
-	rc, err := u.UDPConn.SyscallConn()
-	if err != nil {
-		return err
+//go:linkname sendto golang.org/x/sys/unix.sendto
+//go:noescape
+func sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen int32) (err error)
+
+func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
+	var sa unsafe.Pointer
+	var addrLen int32
+
+	if u.isV4 {
+		if ap.Addr().Is6() {
+			return ErrInvalidIPv6RemoteForSocket
+		}
+
+		var rsa unix.RawSockaddrInet6
+		rsa.Family = unix.AF_INET6
+		rsa.Addr = ap.Addr().As16()
+		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
+		sa = unsafe.Pointer(&rsa)
+		addrLen = syscall.SizeofSockaddrInet4
+	} else {
+		var rsa unix.RawSockaddrInet6
+		rsa.Family = unix.AF_INET6
+		rsa.Addr = ap.Addr().As16()
+		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
+		sa = unsafe.Pointer(&rsa)
+		addrLen = syscall.SizeofSockaddrInet6
 	}
 
-	return rc.Control(func(fd uintptr) {
-		err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
+	// Golang stdlib doesn't handle EAGAIN correctly in some situations so we do writes ourselves
+	// See https://github.com/golang/go/issues/73919
+	for {
+		//_, _, err := unix.Syscall6(unix.SYS_SENDTO, u.sysFd, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), 0, sa, addrLen)
+		err := sendto(int(u.sysFd), b, 0, sa, addrLen)
+		if err == nil {
+			// Written, get out before the error handling
+			return nil
+		}
+
+		if errors.Is(err, syscall.EINTR) {
+			// Write was interrupted, retry
+			continue
+		}
+
+		if errors.Is(err, syscall.EAGAIN) {
+			return &net.OpError{Op: "sendto", Err: unix.EWOULDBLOCK}
+		}
+
+		if errors.Is(err, syscall.EBADF) {
+			return net.ErrClosed
+		}
+
+		return &net.OpError{Op: "sendto", Err: err}
+	}
+}
+
+func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
+	a := u.UDPConn.LocalAddr()
+
+	switch v := a.(type) {
+	case *net.UDPAddr:
+		addr, ok := netip.AddrFromSlice(v.IP)
+		if !ok {
+			return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)
+		}
+		return netip.AddrPortFrom(addr, uint16(v.Port)), nil
+
+	default:
+		return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)
+	}
+}
+
+func (u *StdConn) ReloadConfig(c *config.C) {
+	// TODO
+}
+
+func NewUDPStatsEmitter(udpConns []Conn) func() {
+	// No UDP stats for non-linux
+	return func() {}
+}
+
+func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
+	plaintext := make([]byte, MTU)
+	buffer := make([]byte, MTU)
+	h := &header.H{}
+	fwPacket := &firewall.Packet{}
+	nb := make([]byte, 12, 12)
+
+	for {
+		// Just read one packet at a time
+		n, rua, err := u.ReadFromUDPAddrPort(buffer)
 		if err != nil {
-			u.l.WithError(err).Error("Failed to rebind udp socket")
+			if errors.Is(err, net.ErrClosed) {
+				u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
+				return
+			}
+
+			u.l.WithError(err).Error("unexpected udp socket receive error")
 		}
-	})
+
+		r(
+			netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()),
+			plaintext[:0],
+			buffer[:n],
+			h,
+			fwPacket,
+			lhf,
+			nb,
+			q,
+			cache.Get(u.l),
+		)
+	}
+}
+
+func (u *StdConn) Rebind() error {
+	var err error
+	if u.isV4 {
+		err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, 0)
+	} else {
+		err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, 0)
+	}
+
+	if err != nil {
+		u.l.WithError(err).Error("Failed to rebind udp socket")
+	}
+
+	return nil
 }

+ 2 - 1
udp/udp_generic.go

@@ -1,6 +1,7 @@
-//go:build (!linux || android) && !e2e_testing
+//go:build (!linux || android) && !e2e_testing && !darwin
 // +build !linux android
 // +build !e2e_testing
+// +build !darwin
 
 // udp_generic implements the nebula UDP interface in pure Go stdlib. This
 // means it can be used on platforms like Darwin and Windows.

+ 1 - 1
udp/udp_linux.go

@@ -243,7 +243,7 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
 
 func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
 	if !ip.Addr().Is4() {
-		return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
+		return ErrInvalidIPv6RemoteForSocket
 	}
 
 	var rsa unix.RawSockaddrInet4