Kaynağa Gözat

derp: reduce DERP memory use; don't require callers to pass in memory to use

The magicsock derpReader was holding onto 65KB for each DERP
connection forever, just in case.

Make the derp{,http}.Client be in charge of memory instead. It can
reuse its bufio.Reader buffer space.
Brad Fitzpatrick 5 yıl önce
ebeveyn
işleme
abd79ea368

+ 1 - 2
cmd/derper/mesh.go

@@ -118,8 +118,7 @@ func runMeshClient(s *derp.Server, host string, c *derphttp.Client, logf logger.
 			return
 		}
 		for {
-			var buf [64 << 10]byte
-			m, connGen, err := c.RecvDetail(buf[:])
+			m, connGen, err := c.RecvDetail()
 			if err != nil {
 				clear()
 				logf("Recv: %v", err)

+ 47 - 8
derp/derp_client.go

@@ -30,8 +30,11 @@ type Client struct {
 	br           *bufio.Reader
 	meshKey      string
 
-	wmu     sync.Mutex // hold while writing to bw
-	bw      *bufio.Writer
+	wmu sync.Mutex // hold while writing to bw
+	bw  *bufio.Writer
+
+	// Owned by Recv:
+	peeked  int   // bytes to discard on next Recv
 	readErr error // sticky read error
 }
 
@@ -308,14 +311,16 @@ type PeerPresentMessage key.Public
 func (PeerPresentMessage) msg() {}
 
 // Recv reads a message from the DERP server.
-// The provided buffer must be large enough to receive a complete packet,
-// which in practice are are 1.5-4 KB, but can be up to 64 KB.
+//
+// The returned message may alias memory owned by the Client; it
+// should only be accessed until the next call to Client.
+//
 // Once Recv returns an error, the Client is dead forever.
-func (c *Client) Recv(b []byte) (m ReceivedMessage, err error) {
-	return c.recvTimeout(b, 120*time.Second)
+func (c *Client) Recv() (m ReceivedMessage, err error) {
+	return c.recvTimeout(120 * time.Second)
 }
 
-func (c *Client) recvTimeout(b []byte, timeout time.Duration) (m ReceivedMessage, err error) {
+func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err error) {
 	if c.readErr != nil {
 		return nil, c.readErr
 	}
@@ -328,10 +333,44 @@ func (c *Client) recvTimeout(b []byte, timeout time.Duration) (m ReceivedMessage
 
 	for {
 		c.nc.SetReadDeadline(time.Now().Add(timeout))
-		t, n, err := readFrame(c.br, 1<<20, b)
+
+		// Discard any peeked bytes from a previous Recv call.
+		if c.peeked != 0 {
+			if n, err := c.br.Discard(c.peeked); err != nil || n != c.peeked {
+				// Documented to never fail, but might as well check.
+				return nil, fmt.Errorf("Discard(%d bytes): got %v, %v", c.peeked, n, err)
+			}
+			c.peeked = 0
+		}
+
+		t, n, err := readFrameHeader(c.br)
 		if err != nil {
 			return nil, err
 		}
+		if n > 1<<20 {
+			return nil, fmt.Errorf("unexpectedly large frame of %d bytes returned", n)
+		}
+
+		var b []byte // frame payload (past the 5 byte header)
+
+		// If the frame fits in our bufio.Reader buffer, just use it.
+		// In practice it's 4KB (from derphttp.Client's bufio.NewReader(httpConn)) and
+		// in practive, WireGuard packets (and thus DERP frames) are under 1.5KB.
+		// So This is the common path.
+		if int(n) <= c.br.Size() {
+			b, err = c.br.Peek(int(n))
+			c.peeked = int(n)
+		} else {
+			// But if for some reason we read a large DERP message (which isn't necessarily
+			// a Wireguard packet), then just allocate memory for it.
+			// TODO(bradfitz): use a pool if large frames ever happen in practice.
+			b = make([]byte, n)
+			_, err = io.ReadFull(c.br, b)
+		}
+		if err != nil {
+			return nil, err
+		}
+
 		switch t {
 		default:
 			continue

+ 5 - 9
derp/derp_test.go

@@ -90,8 +90,7 @@ func TestSendRecv(t *testing.T) {
 	for i := 0; i < numClients; i++ {
 		go func(i int) {
 			for {
-				b := make([]byte, 1<<16)
-				m, err := clients[i].Recv(b)
+				m, err := clients[i].Recv()
 				if err != nil {
 					errCh <- err
 					return
@@ -106,7 +105,7 @@ func TestSendRecv(t *testing.T) {
 					if m.Source.IsZero() {
 						t.Errorf("zero Source address in ReceivedPacket")
 					}
-					recvChs[i] <- m.Data
+					recvChs[i] <- append([]byte(nil), m.Data...)
 				}
 			}
 		}(i)
@@ -259,8 +258,7 @@ func TestSendFreeze(t *testing.T) {
 	recv := func(name string, client *Client) {
 		ch := chs(name)
 		for {
-			b := make([]byte, 1<<9)
-			m, err := client.Recv(b)
+			m, err := client.Recv()
 			if err != nil {
 				errCh <- fmt.Errorf("%s: %w", name, err)
 				return
@@ -529,9 +527,8 @@ func (tc *testClient) wantPresent(t *testing.T, peers ...key.Public) {
 		want[k] = true
 	}
 
-	var buf [64 << 10]byte
 	for {
-		m, err := tc.c.recvTimeout(buf[:], time.Second)
+		m, err := tc.c.recvTimeout(time.Second)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -557,8 +554,7 @@ func (tc *testClient) wantPresent(t *testing.T, peers ...key.Public) {
 
 func (tc *testClient) wantGone(t *testing.T, peer key.Public) {
 	t.Helper()
-	var buf [64 << 10]byte
-	m, err := tc.c.recvTimeout(buf[:], time.Second)
+	m, err := tc.c.recvTimeout(time.Second)
 	if err != nil {
 		t.Fatal(err)
 	}

+ 6 - 6
derp/derphttp/derphttp_client.go

@@ -530,21 +530,21 @@ func (c *Client) WatchConnectionChanges() error {
 	return err
 }
 
-// Recv reads a message from c. The returned message may alias the provided buffer.
-// b should not be reused until the message is no longer used.
-func (c *Client) Recv(b []byte) (derp.ReceivedMessage, error) {
-	m, _, err := c.RecvDetail(b)
+// Recv reads a message from c. The returned message may alias memory from Client.
+// The message should only be used until the next Client call.
+func (c *Client) Recv() (derp.ReceivedMessage, error) {
+	m, _, err := c.RecvDetail()
 	return m, err
 }
 
 // RecvDetail is like Recv, but additional returns the connection generation on each message.
 // The connGen value is incremented every time the derphttp.Client reconnects to the server.
-func (c *Client) RecvDetail(b []byte) (m derp.ReceivedMessage, connGen int, err error) {
+func (c *Client) RecvDetail() (m derp.ReceivedMessage, connGen int, err error) {
 	client, connGen, err := c.connect(context.TODO(), "derphttp.Client.Recv")
 	if err != nil {
 		return nil, 0, err
 	}
-	m, err = client.Recv(b)
+	m, err = client.Recv()
 	if err != nil {
 		c.closeForReconnect(client)
 	}

+ 2 - 3
derp/derphttp/derphttp_test.go

@@ -93,8 +93,7 @@ func TestSendRecv(t *testing.T) {
 					return
 				default:
 				}
-				b := make([]byte, 1<<16)
-				m, err := c.Recv(b)
+				m, err := c.Recv()
 				if err != nil {
 					t.Logf("client%d: %v", i, err)
 					break
@@ -106,7 +105,7 @@ func TestSendRecv(t *testing.T) {
 				case derp.PeerGoneMessage:
 					// Ignore.
 				case derp.ReceivedPacket:
-					recvChs[i] <- m.Data
+					recvChs[i] <- append([]byte(nil), m.Data...)
 				}
 			}
 		}(i)

+ 1 - 2
wgengine/magicsock/magicsock.go

@@ -1000,7 +1000,6 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
 	}
 
 	didCopy := make(chan struct{}, 1)
-	var buf [derp.MaxPacketSize]byte
 
 	res := derpReadResult{derpAddr: derpFakeAddr}
 	var pkt derp.ReceivedPacket
@@ -1015,7 +1014,7 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
 	peerPresent := map[key.Public]bool{}
 
 	for {
-		msg, err := dc.Recv(buf[:])
+		msg, err := dc.Recv()
 		if err == derphttp.ErrClientClosed {
 			return
 		}