|
|
@@ -4,6 +4,7 @@
|
|
|
package controlclient
|
|
|
|
|
|
import (
|
|
|
+ "bytes"
|
|
|
"context"
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
@@ -33,6 +34,7 @@ import (
|
|
|
"tailscale.com/util/eventbus/eventbustest"
|
|
|
"tailscale.com/util/mak"
|
|
|
"tailscale.com/util/must"
|
|
|
+ "tailscale.com/util/zstdframe"
|
|
|
)
|
|
|
|
|
|
func eps(s ...string) []netip.AddrPort {
|
|
|
@@ -1445,3 +1447,39 @@ func TestNetmapForMapResponseForDebug(t *testing.T) {
|
|
|
t.Errorf("mismatch\nnm1: %s\nnm2: %s\n", logger.AsJSON(nm1), logger.AsJSON(nm2))
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func TestLearnZstdOfKeepAlive(t *testing.T) {
|
|
|
+ keepAliveMsgZstd := (func() []byte {
|
|
|
+ msg := must.Get(json.Marshal(tailcfg.MapResponse{
|
|
|
+ KeepAlive: true,
|
|
|
+ }))
|
|
|
+ return zstdframe.AppendEncode(nil, msg, zstdframe.FastestCompression)
|
|
|
+ })()
|
|
|
+
|
|
|
+ sess := newTestMapSession(t, nil)
|
|
|
+
|
|
|
+ // The first time we see a zstd keep-alive message, we learn how
|
|
|
+ // the server encodes that.
|
|
|
+ var mr tailcfg.MapResponse
|
|
|
+ must.Do(sess.decodeMsg(keepAliveMsgZstd, &mr))
|
|
|
+ if !mr.KeepAlive {
|
|
|
+ t.Fatal("mr.KeepAlive false; want true")
|
|
|
+ }
|
|
|
+ if !bytes.Equal(sess.keepAliveZ, keepAliveMsgZstd) {
|
|
|
+ t.Fatalf("sess.keepAlive = %q; want %q", sess.keepAliveZ, keepAliveMsgZstd)
|
|
|
+ }
|
|
|
+ if got, want := sess.ztdDecodesForTest, 1; got != want {
|
|
|
+ t.Fatalf("got %d zstd decodes; want %d", got, want)
|
|
|
+ }
|
|
|
+
|
|
|
+ // The second time on the session where we see that message, we
|
|
|
+ // decode it without needing to decompress.
|
|
|
+ var mr2 tailcfg.MapResponse
|
|
|
+ must.Do(sess.decodeMsg(keepAliveMsgZstd, &mr2))
|
|
|
+ if !mr2.KeepAlive {
|
|
|
+ t.Fatal("mr2.KeepAlive false; want true")
|
|
|
+ }
|
|
|
+ if got, want := sess.ztdDecodesForTest, 1; got != want {
|
|
|
+ t.Fatalf("got %d zstd decodes; want %d", got, want)
|
|
|
+ }
|
|
|
+}
|