Răsfoiți Sursa

tstest/natlab/vnet: add qemu + Virtualization.framework protocol tests

To test how virtual machines connect to the natlab vnet code.

Updates #13038

Change-Id: Ia4fd4b0c1803580ee7d94cc9878d777ad4f24f82
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 1 an în urmă
părinte
comite
8b23ba7d05
2 a modificat fișierele cu 189 adăugiri și 15 ștergeri
  1. 23 8
      tstest/natlab/vnet/vnet.go
  2. 166 7
      tstest/natlab/vnet/vnet_test.go

+ 23 - 8
tstest/natlab/vnet/vnet.go

@@ -559,6 +559,17 @@ func (n *network) unregisterWriter(mac MAC) {
 	n.writers.Delete(mac)
 	n.writers.Delete(mac)
 }
 }
 
 
+// RegisteredWritersForTest returns the number of registered connections (VM
+// guests with a known MAC to whom a packet can be sent) there are to the
+// server. It exists for testing.
+func (s *Server) RegisteredWritersForTest() int {
+	num := 0
+	for n := range s.networks {
+		num += n.writers.Len()
+	}
+	return num
+}
+
 func (n *network) MACOfIP(ip netip.Addr) (_ MAC, ok bool) {
 func (n *network) MACOfIP(ip netip.Addr) (_ MAC, ok bool) {
 	if n.lanIP4.Addr() == ip {
 	if n.lanIP4.Addr() == ip {
 		return n.mac, true
 		return n.mac, true
@@ -776,12 +787,12 @@ func (s *Server) writeEthernetFrameToVM(c vmClient, ethPkt []byte, interfaceID i
 		s.scratch = binary.BigEndian.AppendUint32(s.scratch[:0], uint32(len(ethPkt)))
 		s.scratch = binary.BigEndian.AppendUint32(s.scratch[:0], uint32(len(ethPkt)))
 		s.scratch = append(s.scratch, ethPkt...)
 		s.scratch = append(s.scratch, ethPkt...)
 		if _, err := c.uc.Write(s.scratch); err != nil {
 		if _, err := c.uc.Write(s.scratch); err != nil {
-			log.Printf("Write pkt: %v", err)
+			s.logf("Write pkt: %v", err)
 		}
 		}
 
 
 	case ProtocolUnixDGRAM:
 	case ProtocolUnixDGRAM:
 		if _, err := c.uc.WriteToUnix(ethPkt, c.raddr); err != nil {
 		if _, err := c.uc.WriteToUnix(ethPkt, c.raddr); err != nil {
-			log.Printf("Write pkt : %v", err)
+			s.logf("Write pkt : %v", err)
 			return
 			return
 		}
 		}
 	}
 	}
@@ -821,7 +832,7 @@ func (s *Server) ServeUnixConn(uc *net.UnixConn, proto Protocol) {
 	context.AfterFunc(s.shutdownCtx, func() {
 	context.AfterFunc(s.shutdownCtx, func() {
 		uc.SetDeadline(time.Now())
 		uc.SetDeadline(time.Now())
 	})
 	})
-	log.Printf("Got conn %T %p", uc, uc)
+	s.logf("Got conn %T %p", uc, uc)
 	defer uc.Close()
 	defer uc.Close()
 
 
 	buf := make([]byte, 16<<10)
 	buf := make([]byte, 16<<10)
@@ -835,7 +846,11 @@ func (s *Server) ServeUnixConn(uc *net.UnixConn, proto Protocol) {
 			n, addr, err := uc.ReadFromUnix(buf)
 			n, addr, err := uc.ReadFromUnix(buf)
 			raddr = addr
 			raddr = addr
 			if err != nil {
 			if err != nil {
-				log.Printf("ReadFromUnix: %v", err)
+				if s.shutdownCtx.Err() != nil {
+					// Return without logging.
+					return
+				}
+				s.logf("ReadFromUnix: %#v", err)
 				continue
 				continue
 			}
 			}
 			packetRaw = buf[:n]
 			packetRaw = buf[:n]
@@ -845,7 +860,7 @@ func (s *Server) ServeUnixConn(uc *net.UnixConn, proto Protocol) {
 					// Return without logging.
 					// Return without logging.
 					return
 					return
 				}
 				}
-				log.Printf("ReadFull header: %v", err)
+				s.logf("ReadFull header: %v", err)
 				return
 				return
 			}
 			}
 			n := binary.BigEndian.Uint32(buf[:4])
 			n := binary.BigEndian.Uint32(buf[:4])
@@ -855,7 +870,7 @@ func (s *Server) ServeUnixConn(uc *net.UnixConn, proto Protocol) {
 					// Return without logging.
 					// Return without logging.
 					return
 					return
 				}
 				}
-				log.Printf("ReadFull pkt: %v", err)
+				s.logf("ReadFull pkt: %v", err)
 				return
 				return
 			}
 			}
 			packetRaw = buf[4 : 4+n] // raw ethernet frame
 			packetRaw = buf[4 : 4+n] // raw ethernet frame
@@ -869,12 +884,12 @@ func (s *Server) ServeUnixConn(uc *net.UnixConn, proto Protocol) {
 		srcMAC := MAC(packetRaw[6:12])
 		srcMAC := MAC(packetRaw[6:12])
 		srcNode, ok := s.nodeByMAC[srcMAC]
 		srcNode, ok := s.nodeByMAC[srcMAC]
 		if !ok {
 		if !ok {
-			log.Printf("[conn %p] got frame from unknown MAC %v", c.uc, srcMAC)
+			s.logf("[conn %p] got frame from unknown MAC %v", c.uc, srcMAC)
 			continue
 			continue
 		}
 		}
 		if !didReg[srcMAC] {
 		if !didReg[srcMAC] {
 			didReg[srcMAC] = true
 			didReg[srcMAC] = true
-			log.Printf("[conn %p] Registering writer for MAC %v, node %v", c.uc, srcMAC, srcNode.lanIP)
+			s.logf("[conn %p] Registering writer for MAC %v, node %v", c.uc, srcMAC, srcNode.lanIP)
 			srcNode.net.registerWriter(srcMAC, c)
 			srcNode.net.registerWriter(srcMAC, c)
 			defer srcNode.net.unregisterWriter(srcMAC)
 			defer srcNode.net.unregisterWriter(srcMAC)
 		}
 		}

+ 166 - 7
tstest/natlab/vnet/vnet_test.go

@@ -10,11 +10,15 @@ import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"net/netip"
 	"net/netip"
+	"path/filepath"
+	"runtime"
 	"strings"
 	"strings"
 	"testing"
 	"testing"
+	"time"
 
 
 	"github.com/google/gopacket"
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/layers"
 	"github.com/google/gopacket/layers"
+	"tailscale.com/util/must"
 )
 )
 
 
 // TestPacketSideEffects tests that upon receiving certain
 // TestPacketSideEffects tests that upon receiving certain
@@ -32,13 +36,7 @@ func TestPacketSideEffects(t *testing.T) {
 	}{
 	}{
 		{
 		{
 			netName: "basic",
 			netName: "basic",
-			setup: func() (*Server, error) {
-				var c Config
-				nw := c.AddNetwork("192.168.0.1/24")
-				c.AddNode(nw)
-				c.AddNode(nw)
-				return New(&c)
-			},
+			setup:   newTwoNodesSameNetworkServer,
 			tests: []netTest{
 			tests: []netTest{
 				{
 				{
 					name: "drop-rando-ethertype",
 					name: "drop-rando-ethertype",
@@ -129,6 +127,14 @@ func mkEth(dst, src MAC, ethType layers.EthernetType, payload []byte) []byte {
 	return append(ret, payload...)
 	return append(ret, payload...)
 }
 }
 
 
+// mkLenPrefixed prepends a uint32 length to the given packet.
+func mkLenPrefixed(pkt []byte) []byte {
+	ret := make([]byte, 4+len(pkt))
+	binary.BigEndian.PutUint32(ret, uint32(len(pkt)))
+	copy(ret[4:], pkt)
+	return ret
+}
+
 // mkIPv6RouterSolicit makes a IPv6 router solicitation packet
 // mkIPv6RouterSolicit makes a IPv6 router solicitation packet
 // ethernet frame.
 // ethernet frame.
 func mkIPv6RouterSolicit(srcMAC MAC, srcIP netip.Addr) []byte {
 func mkIPv6RouterSolicit(srcMAC MAC, srcIP netip.Addr) []byte {
@@ -230,3 +236,156 @@ func numPkts(want int) func(*sideEffects) error {
 		return fmt.Errorf("got %d packets, want %d. packets were:\n%s", len(se.got), want, pkts.Bytes())
 		return fmt.Errorf("got %d packets, want %d. packets were:\n%s", len(se.got), want, pkts.Bytes())
 	}
 	}
 }
 }
+
+func newTwoNodesSameNetworkServer() (*Server, error) {
+	var c Config
+	nw := c.AddNetwork("192.168.0.1/24")
+	c.AddNode(nw)
+	c.AddNode(nw)
+	return New(&c)
+}
+
+// TestProtocolQEMU tests the protocol that qemu uses to connect to natlab's
+// vnet. (uint32-length prefixed ethernet frames over a unix stream socket)
+//
+// This test makes two clients (as qemu would act) and has one send an ethernet
+// packet to the other virtual LAN segment.
+func TestProtocolQEMU(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skipf("skipping on %s", runtime.GOOS)
+	}
+	s := must.Get(newTwoNodesSameNetworkServer())
+	defer s.Close()
+	s.SetLoggerForTest(t.Logf)
+
+	td := t.TempDir()
+	serverSock := filepath.Join(td, "vnet.sock")
+
+	ln, err := net.Listen("unix", serverSock)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ln.Close()
+
+	var clientc [2]*net.UnixConn
+	for i := range clientc {
+		c, err := net.Dial("unix", serverSock)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer c.Close()
+		clientc[i] = c.(*net.UnixConn)
+	}
+
+	for range clientc {
+		conn, err := ln.Accept()
+		if err != nil {
+			t.Fatal(err)
+		}
+		go s.ServeUnixConn(conn.(*net.UnixConn), ProtocolQEMU)
+	}
+
+	sendBetweenClients(t, clientc, s, mkLenPrefixed)
+}
+
+// TestProtocolUnixDgram tests the protocol that macOS Virtualization.framework
+// uses to connect to vnet. (unix datagram sockets)
+//
+// It is similar to TestProtocolQEMU but uses unix datagram sockets instead of
+// streams.
+func TestProtocolUnixDgram(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skipf("skipping on %s", runtime.GOOS)
+	}
+	s := must.Get(newTwoNodesSameNetworkServer())
+	defer s.Close()
+	s.SetLoggerForTest(t.Logf)
+
+	td := t.TempDir()
+	serverSock := filepath.Join(td, "vnet.sock")
+	serverAddr := must.Get(net.ResolveUnixAddr("unixgram", serverSock))
+
+	var clientSock [2]string
+	for i := range clientSock {
+		clientSock[i] = filepath.Join(td, fmt.Sprintf("c%d.sock", i))
+	}
+
+	uc, err := net.ListenUnixgram("unixgram", serverAddr)
+	if err != nil {
+		t.Fatal(err)
+	}
+	go s.ServeUnixConn(uc, ProtocolUnixDGRAM)
+
+	var clientc [2]*net.UnixConn
+	for i := range clientc {
+		c, err := net.DialUnix("unixgram",
+			must.Get(net.ResolveUnixAddr("unixgram", clientSock[i])),
+			serverAddr)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer c.Close()
+		clientc[i] = c
+	}
+
+	sendBetweenClients(t, clientc, s, nil)
+}
+
+// sendBetweenClients is a test helper that tries to send an ethernet frame from
+// one client to another.
+//
+// It first makes the two clients send a packet to a fictitious node 3, which
+// forces their src MACs to be registered with a networkWriter internally so
+// they can receive traffic.
+//
+// Normally a node starts up spamming DHCP + NDP but we don't get that as a side
+// effect here, so this does it manually.
+//
+// It also then waits for them to be registered.
+//
+// wrap is an optional function that wraps the packet before sending it.
+func sendBetweenClients(t testing.TB, clientc [2]*net.UnixConn, s *Server, wrap func([]byte) []byte) {
+	t.Helper()
+	if wrap == nil {
+		wrap = func(b []byte) []byte { return b }
+	}
+	for i, c := range clientc {
+		must.Get(c.Write(wrap(mkEth(nodeMac(3), nodeMac(i+1), testingEthertype, []byte("hello")))))
+	}
+	awaitCond(t, 5*time.Second, func() error {
+		if n := s.RegisteredWritersForTest(); n != 2 {
+			return fmt.Errorf("got %d registered writers, want 2", n)
+		}
+		return nil
+	})
+
+	// Now see if node1 can write to node2 and node2 receives it.
+	pkt := wrap(mkEth(nodeMac(2), nodeMac(1), testingEthertype, []byte("test-msg")))
+	t.Logf("writing % 02x", pkt)
+	must.Get(clientc[0].Write(pkt))
+
+	buf := make([]byte, len(pkt))
+	clientc[1].SetReadDeadline(time.Now().Add(5 * time.Second))
+	n, err := clientc[1].Read(buf)
+	if err != nil {
+		t.Fatal(err)
+	}
+	got := buf[:n]
+	if !bytes.Equal(got, pkt) {
+		t.Errorf("bad packet\n got: % 02x\nwant: % 02x", got, pkt)
+	}
+}
+
+func awaitCond(t testing.TB, timeout time.Duration, cond func() error) {
+	t.Helper()
+	t0 := time.Now()
+	for {
+		if err := cond(); err == nil {
+			return
+		}
+		if time.Since(t0) > timeout {
+			t.Fatalf("timed out after %v", timeout)
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+}