Browse Source

control/controlclient: optimize zstd decode of KeepAlive messages

Maybe it matters? At least globally across all nodes?

Fixes #17343

Change-Id: I3f61758ea37de527e16602ec1a6e453d913b3195
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 5 months ago
parent
commit
2cd518a8b6
3 changed files with 57 additions and 2 deletions
  1. 16 2
      control/controlclient/direct.go
  2. 3 0
      control/controlclient/map.go
  3. 38 0
      control/controlclient/map_test.go

+ 16 - 2
control/controlclient/direct.go

@@ -1091,7 +1091,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap
 		vlogf("netmap: read body after %v", time.Since(t0).Round(time.Millisecond))
 
 		var resp tailcfg.MapResponse
-		if err := c.decodeMsg(msg, &resp); err != nil {
+		if err := sess.decodeMsg(msg, &resp); err != nil {
 			vlogf("netmap: decode error: %v", err)
 			return err
 		}
@@ -1240,12 +1240,23 @@ func decode(res *http.Response, v any) error {
 
 var jsonEscapedZero = []byte(`\u0000`)
 
+const justKeepAliveStr = `{"KeepAlive":true}`
+
 // decodeMsg is responsible for uncompressing msg and unmarshaling into v.
-func (c *Direct) decodeMsg(compressedMsg []byte, v any) error {
+func (sess *mapSession) decodeMsg(compressedMsg []byte, v *tailcfg.MapResponse) error {
+	// Fast path for common case of keep-alive message.
+	// See tailscale/tailscale#17343.
+	if sess.keepAliveZ != nil && bytes.Equal(compressedMsg, sess.keepAliveZ) {
+		v.KeepAlive = true
+		return nil
+	}
+
 	b, err := zstdframe.AppendDecode(nil, compressedMsg)
 	if err != nil {
 		return err
 	}
+	sess.ztdDecodesForTest++
+
 	if DevKnob.DumpNetMaps() {
 		var buf bytes.Buffer
 		json.Indent(&buf, b, "", "    ")
@@ -1258,6 +1269,9 @@ func (c *Direct) decodeMsg(compressedMsg []byte, v any) error {
 	if err := json.Unmarshal(b, v); err != nil {
 		return fmt.Errorf("response: %v", err)
 	}
+	if v.KeepAlive && string(b) == justKeepAliveStr {
+		sess.keepAliveZ = compressedMsg
+	}
 	return nil
 }
 

+ 3 - 0
control/controlclient/map.go

@@ -57,6 +57,9 @@ type mapSession struct {
 	altClock       tstime.Clock       // if nil, regular time is used
 	cancel         context.CancelFunc // always non-nil, shuts down caller's base long poll context
 
+	keepAliveZ        []byte // if non-nil, the learned zstd encoding of the just-KeepAlive message for this session
+	ztdDecodesForTest int    // for testing
+
 	// sessionAliveCtx is a Background-based context that's alive for the
 	// duration of the mapSession that we own the lifetime of. It's closed by
 	// sessionAliveCtxClose.

+ 38 - 0
control/controlclient/map_test.go

@@ -4,6 +4,7 @@
 package controlclient
 
 import (
+	"bytes"
 	"context"
 	"encoding/json"
 	"fmt"
@@ -33,6 +34,7 @@ import (
 	"tailscale.com/util/eventbus/eventbustest"
 	"tailscale.com/util/mak"
 	"tailscale.com/util/must"
+	"tailscale.com/util/zstdframe"
 )
 
 func eps(s ...string) []netip.AddrPort {
@@ -1445,3 +1447,39 @@ func TestNetmapForMapResponseForDebug(t *testing.T) {
 		t.Errorf("mismatch\nnm1: %s\nnm2: %s\n", logger.AsJSON(nm1), logger.AsJSON(nm2))
 	}
 }
+
+func TestLearnZstdOfKeepAlive(t *testing.T) {
+	keepAliveMsgZstd := (func() []byte {
+		msg := must.Get(json.Marshal(tailcfg.MapResponse{
+			KeepAlive: true,
+		}))
+		return zstdframe.AppendEncode(nil, msg, zstdframe.FastestCompression)
+	})()
+
+	sess := newTestMapSession(t, nil)
+
+	// The first time we see a zstd keep-alive message, we learn how
+	// the server encodes that.
+	var mr tailcfg.MapResponse
+	must.Do(sess.decodeMsg(keepAliveMsgZstd, &mr))
+	if !mr.KeepAlive {
+		t.Fatal("mr.KeepAlive false; want true")
+	}
+	if !bytes.Equal(sess.keepAliveZ, keepAliveMsgZstd) {
+		t.Fatalf("sess.keepAlive = %q; want %q", sess.keepAliveZ, keepAliveMsgZstd)
+	}
+	if got, want := sess.ztdDecodesForTest, 1; got != want {
+		t.Fatalf("got %d zstd decodes; want %d", got, want)
+	}
+
+	// The second time on the session where we see that message, we
+	// decode it without needing to decompress.
+	var mr2 tailcfg.MapResponse
+	must.Do(sess.decodeMsg(keepAliveMsgZstd, &mr2))
+	if !mr2.KeepAlive {
+		t.Fatal("mr2.KeepAlive false; want true")
+	}
+	if got, want := sess.ztdDecodesForTest, 1; got != want {
+		t.Fatalf("got %d zstd decodes; want %d", got, want)
+	}
+}