Browse Source

all: use zstdframe where sensible (#11491)

Use the zstdframe package where sensible instead of plumbing
around our own zstd.Encoder just for stateless operations.

This causes logtail to have a dependency on zstd,
but that's arguably okay since zstd support is implicit
to the protocol between a client and the logging service.
Also, virtually every caller to logger.NewLogger was
manually setting up a zstd.Encoder anyways,
meaning that zstd was functionally always a dependency.

Updates #cleanup
Updates tailscale/corp#18514

Signed-off-by: Joe Tsai <[email protected]>
Joe Tsai 1 year ago
parent
commit
85febda86d

+ 2 - 2
cmd/tailscaled/depaware.txt

@@ -119,7 +119,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
         github.com/klauspost/compress/huff0                          from github.com/klauspost/compress/zstd
         github.com/klauspost/compress/internal/cpuinfo               from github.com/klauspost/compress/huff0+
         github.com/klauspost/compress/internal/snapref               from github.com/klauspost/compress/zstd
-        github.com/klauspost/compress/zstd                           from tailscale.com/smallzstd
+        github.com/klauspost/compress/zstd                           from tailscale.com/util/zstdframe
         github.com/klauspost/compress/zstd/internal/xxhash           from github.com/klauspost/compress/zstd
         github.com/kortschak/wol                                     from tailscale.com/ipn/ipnlocal
   LD    github.com/kr/fs                                             from github.com/pkg/sftp
@@ -315,7 +315,6 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
         tailscale.com/posture                                        from tailscale.com/ipn/ipnlocal
         tailscale.com/proxymap                                       from tailscale.com/tsd+
      💣 tailscale.com/safesocket                                     from tailscale.com/client/tailscale+
-        tailscale.com/smallzstd                                      from tailscale.com/control/controlclient+
   LD 💣 tailscale.com/ssh/tailssh                                    from tailscale.com/cmd/tailscaled
         tailscale.com/syncs                                          from tailscale.com/cmd/tailscaled+
         tailscale.com/tailcfg                                        from tailscale.com/client/tailscale+
@@ -393,6 +392,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
      💣 tailscale.com/util/winutil                                   from tailscale.com/clientupdate+
    W 💣 tailscale.com/util/winutil/authenticode                      from tailscale.com/clientupdate+
    W    tailscale.com/util/winutil/policy                            from tailscale.com/ipn/ipnlocal
+        tailscale.com/util/zstdframe                                 from tailscale.com/control/controlclient+
         tailscale.com/version                                        from tailscale.com/client/web+
         tailscale.com/version/distro                                 from tailscale.com/client/web+
    W    tailscale.com/wf                                             from tailscale.com/cmd/tailscaled

+ 2 - 12
control/controlclient/direct.go

@@ -42,7 +42,6 @@ import (
 	"tailscale.com/net/tlsdial"
 	"tailscale.com/net/tsdial"
 	"tailscale.com/net/tshttpproxy"
-	"tailscale.com/smallzstd"
 	"tailscale.com/tailcfg"
 	"tailscale.com/tka"
 	"tailscale.com/tstime"
@@ -57,6 +56,7 @@ import (
 	"tailscale.com/util/singleflight"
 	"tailscale.com/util/syspolicy"
 	"tailscale.com/util/systemd"
+	"tailscale.com/util/zstdframe"
 )
 
 // Direct is the client that connects to a tailcontrol server for a node.
@@ -180,11 +180,6 @@ type Pinger interface {
 	Ping(ctx context.Context, ip netip.Addr, pingType tailcfg.PingType, size int) (*ipnstate.PingResult, error)
 }
 
-type Decompressor interface {
-	DecodeAll(input, dst []byte) ([]byte, error)
-	Close()
-}
-
 // NetmapUpdater is the interface needed by the controlclient to enact change in
 // the world as a function of updates received from the network.
 type NetmapUpdater interface {
@@ -1208,12 +1203,7 @@ func (c *Direct) decodeMsg(msg []byte, v any, mkey key.MachinePrivate) error {
 	} else {
 		decrypted = msg
 	}
-	decoder, err := smallzstd.NewDecoder(nil)
-	if err != nil {
-		return err
-	}
-	defer decoder.Close()
-	b, err := decoder.DecodeAll(decrypted, nil)
+	b, err := zstdframe.AppendDecode(nil, decrypted)
 	if err != nil {
 		return err
 	}

+ 5 - 12
log/sockstatlog/logger.go

@@ -22,7 +22,6 @@ import (
 	"tailscale.com/logtail/filch"
 	"tailscale.com/net/netmon"
 	"tailscale.com/net/sockstats"
-	"tailscale.com/smallzstd"
 	"tailscale.com/types/logger"
 	"tailscale.com/types/logid"
 	"tailscale.com/util/mak"
@@ -117,17 +116,11 @@ func NewLogger(logdir string, logf logger.Logf, logID logid.PublicID, netMon *ne
 		tr:    logpolicy.NewLogtailTransport(logtail.DefaultHost, netMon, logf),
 	}
 	logger.logger = logtail.NewLogger(logtail.Config{
-		BaseURL:    logpolicy.LogURL(),
-		PrivateID:  SockstatLogID(logID),
-		Collection: "sockstats.log.tailscale.io",
-		Buffer:     filch,
-		NewZstdEncoder: func() logtail.Encoder {
-			w, err := smallzstd.NewEncoder(nil)
-			if err != nil {
-				panic(err)
-			}
-			return w
-		},
+		BaseURL:      logpolicy.LogURL(),
+		PrivateID:    SockstatLogID(logID),
+		Collection:   "sockstats.log.tailscale.io",
+		Buffer:       filch,
+		CompressLogs: true,
 		FlushDelayFn: func() time.Duration {
 			// set flush delay to 100 years so it never flushes automatically
 			return 100 * 365 * 24 * time.Hour

+ 5 - 12
logpolicy/logpolicy.go

@@ -42,7 +42,6 @@ import (
 	"tailscale.com/net/tshttpproxy"
 	"tailscale.com/paths"
 	"tailscale.com/safesocket"
-	"tailscale.com/smallzstd"
 	"tailscale.com/types/logger"
 	"tailscale.com/types/logid"
 	"tailscale.com/util/clientmetric"
@@ -551,17 +550,11 @@ func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor,
 	}
 
 	conf := logtail.Config{
-		Collection: newc.Collection,
-		PrivateID:  newc.PrivateID,
-		Stderr:     logWriter{console},
-		NewZstdEncoder: func() logtail.Encoder {
-			w, err := smallzstd.NewEncoder(nil)
-			if err != nil {
-				panic(err)
-			}
-			return w
-		},
-		HTTPC: &http.Client{Transport: NewLogtailTransport(logtail.DefaultHost, netMon, logf)},
+		Collection:   newc.Collection,
+		PrivateID:    newc.PrivateID,
+		Stderr:       logWriter{console},
+		CompressLogs: true,
+		HTTPC:        &http.Client{Transport: NewLogtailTransport(logtail.DefaultHost, netMon, logf)},
 	}
 	if collection == logtail.CollectionNode {
 		conf.MetricsDelta = clientmetric.EncodeLogTailMetricsDelta

+ 19 - 3
logtail/logtail.go

@@ -31,6 +31,7 @@ import (
 	tslogger "tailscale.com/types/logger"
 	"tailscale.com/types/logid"
 	"tailscale.com/util/set"
+	"tailscale.com/util/zstdframe"
 )
 
 // DefaultHost is the default host name to upload logs to when
@@ -62,7 +63,10 @@ type Config struct {
 	Stderr         io.Writer       // if set, logs are sent here instead of os.Stderr
 	StderrLevel    int             // max verbosity level to write to stderr; 0 means the non-verbose messages only
 	Buffer         Buffer          // temp storage, if nil a MemoryBuffer
-	NewZstdEncoder func() Encoder  // if set, used to compress logs for transmission
+	CompressLogs   bool            // whether to compress the log uploads
+
+	// Deprecated: Use CompressUploads instead.
+	NewZstdEncoder func() Encoder // if set, used to compress logs for transmission
 
 	// MetricsDelta, if non-nil, is a func that returns an encoding
 	// delta in clientmetrics to upload alongside existing logs.
@@ -156,6 +160,7 @@ func NewLogger(cfg Config, logf tslogger.Logf) *Logger {
 		shutdownDone:  make(chan struct{}),
 	}
 	l.SetSockstatsLabel(sockstats.LabelLogtailLogger)
+	l.compressLogs = cfg.CompressLogs
 	if cfg.NewZstdEncoder != nil {
 		l.zstdEncoder = cfg.NewZstdEncoder()
 	}
@@ -184,6 +189,7 @@ type Logger struct {
 	flushPending   atomic.Bool
 	sentinel       chan int32
 	clock          tstime.Clock
+	compressLogs   bool
 	zstdEncoder    Encoder
 	uploadCancel   func()
 	explainedRaw   bool
@@ -364,8 +370,18 @@ func (l *Logger) uploading(ctx context.Context) {
 		body := l.drainPending(scratch)
 		origlen := -1 // sentinel value: uncompressed
 		// Don't attempt to compress tiny bodies; not worth the CPU cycles.
-		if l.zstdEncoder != nil && len(body) > 256 {
-			zbody := l.zstdEncoder.EncodeAll(body, nil)
+		if (l.compressLogs || l.zstdEncoder != nil) && len(body) > 256 {
+			var zbody []byte
+			switch {
+			case l.zstdEncoder != nil:
+				zbody = l.zstdEncoder.EncodeAll(body, nil)
+			case l.lowMem:
+				zbody = zstdframe.AppendEncode(nil, body,
+					zstdframe.FastestCompression, zstdframe.LowMemory(true))
+			default:
+				zbody = zstdframe.AppendEncode(nil, body)
+			}
+
 			// Only send it compressed if the bandwidth savings are sufficient.
 			// Just the extra headers associated with enabling compression
 			// are 50 bytes by themselves.

+ 5 - 12
tsnet/tsnet.go

@@ -46,7 +46,6 @@ import (
 	"tailscale.com/net/proxymux"
 	"tailscale.com/net/socks5"
 	"tailscale.com/net/tsdial"
-	"tailscale.com/smallzstd"
 	"tailscale.com/tsd"
 	"tailscale.com/types/logger"
 	"tailscale.com/types/logid"
@@ -651,17 +650,11 @@ func (s *Server) startLogger(closePool *closeOnErrorPool) error {
 	}
 	closePool.add(s.logbuffer)
 	c := logtail.Config{
-		Collection: lpc.Collection,
-		PrivateID:  lpc.PrivateID,
-		Stderr:     io.Discard, // log everything to Buffer
-		Buffer:     s.logbuffer,
-		NewZstdEncoder: func() logtail.Encoder {
-			w, err := smallzstd.NewEncoder(nil)
-			if err != nil {
-				panic(err)
-			}
-			return w
-		},
+		Collection:   lpc.Collection,
+		PrivateID:    lpc.PrivateID,
+		Stderr:       io.Discard, // log everything to Buffer
+		Buffer:       s.logbuffer,
+		CompressLogs: true,
 		HTTPC:        &http.Client{Transport: logpolicy.NewLogtailTransport(logtail.DefaultHost, s.netMon, s.logf)},
 		MetricsDelta: clientmetric.EncodeLogTailMetricsDelta,
 	}

+ 8 - 10
tstest/integration/integration.go

@@ -27,17 +27,16 @@ import (
 	"testing"
 	"time"
 
-	"github.com/klauspost/compress/zstd"
 	"go4.org/mem"
 	"tailscale.com/derp"
 	"tailscale.com/derp/derphttp"
 	"tailscale.com/net/stun/stuntest"
-	"tailscale.com/smallzstd"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
 	"tailscale.com/types/logger"
 	"tailscale.com/types/logid"
 	"tailscale.com/types/nettype"
+	"tailscale.com/util/zstdframe"
 	"tailscale.com/version"
 )
 
@@ -302,20 +301,19 @@ func (lc *LogCatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		log.Printf("bad log ID: %q: %v", r.URL.Path, err)
 	}
 
-	var body io.Reader = r.Body
+	bodyBytes, err := io.ReadAll(r.Body)
+	if err != nil {
+		log.Printf("http.Request.Body.Read: %v", err)
+		return
+	}
 	if r.Header.Get("Content-Encoding") == "zstd" {
-		var err error
-		var dec *zstd.Decoder
-		dec, err = smallzstd.NewDecoder(body)
+		bodyBytes, err = zstdframe.AppendDecode(nil, bodyBytes)
 		if err != nil {
-			log.Printf("bad caught zstd: %v", err)
+			log.Printf("zstdframe.AppendDecode: %v", err)
 			http.Error(w, err.Error(), 400)
 			return
 		}
-		defer dec.Close()
-		body = dec
 	}
-	bodyBytes, _ := io.ReadAll(body)
 
 	type Entry struct {
 		Logtail struct {

+ 2 - 16
tstest/integration/testcontrol/testcontrol.go

@@ -24,11 +24,9 @@ import (
 	"sync"
 	"time"
 
-	"github.com/klauspost/compress/zstd"
 	"go4.org/mem"
 	"tailscale.com/net/netaddr"
 	"tailscale.com/net/tsaddr"
-	"tailscale.com/smallzstd"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
 	"tailscale.com/types/logger"
@@ -37,6 +35,7 @@ import (
 	"tailscale.com/util/must"
 	"tailscale.com/util/rands"
 	"tailscale.com/util/set"
+	"tailscale.com/util/zstdframe"
 )
 
 const msgLimit = 1 << 20 // encrypted message length limit
@@ -1047,16 +1046,6 @@ func (s *Server) decode(mkey key.MachinePublic, msg []byte, v any) error {
 	return json.Unmarshal(decrypted, v)
 }
 
-var zstdEncoderPool = &sync.Pool{
-	New: func() any {
-		encoder, err := smallzstd.NewEncoder(nil, zstd.WithEncoderLevel(zstd.SpeedFastest))
-		if err != nil {
-			panic(err)
-		}
-		return encoder
-	},
-}
-
 func (s *Server) encode(mkey key.MachinePublic, compress bool, v any) (b []byte, err error) {
 	var isBytes bool
 	if b, isBytes = v.([]byte); !isBytes {
@@ -1066,10 +1055,7 @@ func (s *Server) encode(mkey key.MachinePublic, compress bool, v any) (b []byte,
 		}
 	}
 	if compress {
-		encoder := zstdEncoderPool.Get().(*zstd.Encoder)
-		b = encoder.EncodeAll(b, nil)
-		encoder.Close()
-		zstdEncoderPool.Put(encoder)
+		b = zstdframe.AppendEncode(nil, b, zstdframe.FastestCompression)
 	}
 	return s.privateKey().SealTo(mkey, b), nil
 }

+ 2 - 9
wgengine/netlog/logger.go

@@ -22,7 +22,6 @@ import (
 	"tailscale.com/net/netmon"
 	"tailscale.com/net/sockstats"
 	"tailscale.com/net/tsaddr"
-	"tailscale.com/smallzstd"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/logid"
 	"tailscale.com/types/netlogtype"
@@ -111,15 +110,9 @@ func (nl *Logger) Startup(nodeID tailcfg.StableNodeID, nodeLogID, domainLogID lo
 		PrivateID:     nodeLogID,
 		CopyPrivateID: domainLogID,
 		Stderr:        io.Discard,
+		CompressLogs:  true,
+		HTTPC:         httpc,
 		// TODO(joetsai): Set Buffer? Use an in-memory buffer for now.
-		NewZstdEncoder: func() logtail.Encoder {
-			w, err := smallzstd.NewEncoder(nil)
-			if err != nil {
-				panic(err)
-			}
-			return w
-		},
-		HTTPC: httpc,
 
 		// Include process sequence numbers to identify missing samples.
 		IncludeProcID:       true,