Browse Source

wgengine/magicsock: factor out receiveIPv4 & receiveIPv6 common code

Updates #2331

Change-Id: I801df38b217f5d17203e8dc3b8654f44747e0f4b
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 2 years ago
parent
commit
6866aaeab3
2 changed files with 55 additions and 73 deletions
  1. 49 70
      wgengine/magicsock/magicsock.go
  2. 6 3
      wgengine/magicsock/magicsock_test.go

+ 49 - 70
wgengine/magicsock/magicsock.go

@@ -322,11 +322,6 @@ type Conn struct {
 	// bind is the wireguard-go conn.Bind for Conn.
 	bind *connBind
 
-	// ippEndpoint4 and ippEndpoint6 are owned by receiveIPv4 and
-	// receiveIPv6, respectively, to cache an IPPort->endpoint for
-	// hot flows.
-	ippEndpoint4, ippEndpoint6 ippEndpointCache
-
 	// ============================================================
 	// Fields that must be accessed via atomic load/stores.
 
@@ -1851,80 +1846,64 @@ func (c *Conn) putReceiveBatch(batch *receiveBatch) {
 	c.receiveBatchPool.Put(batch)
 }
 
-func (c *Conn) receiveIPv6(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
-	health.ReceiveIPv6.Enter()
-	defer health.ReceiveIPv6.Exit()
+// receiveIPv4 creates an IPv4 ReceiveFunc reading from c.pconn4.
+func (c *Conn) receiveIPv4() conn.ReceiveFunc {
+	return c.mkReceiveFunc(&c.pconn4, &health.ReceiveIPv4, metricRecvDataIPv4)
+}
 
-	batch := c.getReceiveBatchForBuffs(buffs)
-	defer c.putReceiveBatch(batch)
-	for {
-		numMsgs, err := c.pconn6.ReadBatch(batch.msgs[:len(buffs)], 0)
-		if err != nil {
-			if neterror.PacketWasTruncated(err) {
-				// TODO(raggi): discuss whether to log?
-				continue
-			}
-			return 0, err
-		}
+// receiveIPv6 creates an IPv6 ReceiveFunc reading from c.pconn6.
+func (c *Conn) receiveIPv6() conn.ReceiveFunc {
+	return c.mkReceiveFunc(&c.pconn6, &health.ReceiveIPv6, metricRecvDataIPv6)
+}
 
-		reportToCaller := false
-		for i, msg := range batch.msgs[:numMsgs] {
-			if msg.N == 0 {
-				sizes[i] = 0
-				continue
-			}
-			ipp := msg.Addr.(*net.UDPAddr).AddrPort()
-			if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint6); ok {
-				metricRecvDataIPv6.Add(1)
-				eps[i] = ep
-				sizes[i] = msg.N
-				reportToCaller = true
-			} else {
-				sizes[i] = 0
-			}
-		}
+// mkReceiveFunc creates a ReceiveFunc reading from ruc.
+// The provided healthItem and metric are updated if non-nil.
+func (c *Conn) mkReceiveFunc(ruc *RebindingUDPConn, healthItem *health.ReceiveFuncStats, metric *clientmetric.Metric) conn.ReceiveFunc {
+	// epCache caches an IPPort->endpoint for hot flows.
+	var epCache ippEndpointCache
 
-		if reportToCaller {
-			return numMsgs, nil
+	return func(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
+		if healthItem != nil {
+			healthItem.Enter()
+			defer healthItem.Exit()
+		}
+		if ruc == nil {
+			panic("nil RebindingUDPConn")
 		}
-	}
-}
-
-func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
-	health.ReceiveIPv4.Enter()
-	defer health.ReceiveIPv4.Exit()
 
-	batch := c.getReceiveBatchForBuffs(buffs)
-	defer c.putReceiveBatch(batch)
-	for {
-		numMsgs, err := c.pconn4.ReadBatch(batch.msgs[:len(buffs)], 0)
-		if err != nil {
-			if neterror.PacketWasTruncated(err) {
-				// TODO(raggi): discuss whether to log?
-				continue
+		batch := c.getReceiveBatchForBuffs(buffs)
+		defer c.putReceiveBatch(batch)
+		for {
+			numMsgs, err := ruc.ReadBatch(batch.msgs[:len(buffs)], 0)
+			if err != nil {
+				if neterror.PacketWasTruncated(err) {
+					continue
+				}
+				return 0, err
 			}
-			return 0, err
-		}
 
-		reportToCaller := false
-		for i, msg := range batch.msgs[:numMsgs] {
-			if msg.N == 0 {
-				sizes[i] = 0
-				continue
+			reportToCaller := false
+			for i, msg := range batch.msgs[:numMsgs] {
+				if msg.N == 0 {
+					sizes[i] = 0
+					continue
+				}
+				ipp := msg.Addr.(*net.UDPAddr).AddrPort()
+				if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &epCache); ok {
+					if metric != nil {
+						metric.Add(1)
+					}
+					eps[i] = ep
+					sizes[i] = msg.N
+					reportToCaller = true
+				} else {
+					sizes[i] = 0
+				}
 			}
-			ipp := msg.Addr.(*net.UDPAddr).AddrPort()
-			if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint4); ok {
-				metricRecvDataIPv4.Add(1)
-				eps[i] = ep
-				sizes[i] = msg.N
-				reportToCaller = true
-			} else {
-				sizes[i] = 0
+			if reportToCaller {
+				return numMsgs, nil
 			}
 		}
-		if reportToCaller {
-			return numMsgs, nil
-		}
 	}
 }
 
@@ -3044,7 +3023,7 @@ func (c *connBind) Open(ignoredPort uint16) ([]conn.ReceiveFunc, uint16, error)
 		return nil, 0, errors.New("magicsock: connBind already open")
 	}
 	c.closed = false
-	fns := []conn.ReceiveFunc{c.receiveIPv4, c.receiveIPv6, c.receiveDERP}
+	fns := []conn.ReceiveFunc{c.receiveIPv4(), c.receiveIPv6(), c.receiveDERP}
 	if runtime.GOOS == "js" {
 		fns = []conn.ReceiveFunc{c.receiveDERP}
 	}

+ 6 - 3
wgengine/magicsock/magicsock_test.go

@@ -374,8 +374,9 @@ func TestNewConn(t *testing.T) {
 		sizes := make([]int, 1)
 		eps := make([]wgconn.Endpoint, 1)
 		pkts[0] = make([]byte, 64<<10)
+		receiveIPv4 := conn.receiveIPv4()
 		for {
-			_, err := conn.receiveIPv4(pkts, sizes, eps)
+			_, err := receiveIPv4(pkts, sizes, eps)
 			if err != nil {
 				return
 			}
@@ -1284,11 +1285,12 @@ func setUpReceiveFrom(tb testing.TB) (roundTrip func()) {
 	buffs[0] = make([]byte, 2<<10)
 	sizes := make([]int, 1)
 	eps := make([]wgconn.Endpoint, 1)
+	receiveIPv4 := conn.receiveIPv4()
 	return func() {
 		if _, err := sendConn.WriteTo(sendBuf, dstAddr); err != nil {
 			tb.Fatalf("WriteTo: %v", err)
 		}
-		n, err := conn.receiveIPv4(buffs, sizes, eps)
+		n, err := receiveIPv4(buffs, sizes, eps)
 		if err != nil {
 			tb.Fatal(err)
 		}
@@ -1513,8 +1515,9 @@ func TestRebindStress(t *testing.T) {
 		sizes := make([]int, 1)
 		eps := make([]wgconn.Endpoint, 1)
 		buffs[0] = make([]byte, 1500)
+		receiveIPv4 := conn.receiveIPv4()
 		for {
-			_, err := conn.receiveIPv4(buffs, sizes, eps)
+			_, err := receiveIPv4(buffs, sizes, eps)
 			if ctx.Err() != nil {
 				errc <- nil
 				return