Browse Source

net/connstats: enforce maximum number of connections (#6760)

The Tailscale logging service has a hard limit on the maximum
log message size that can be accepted.
We want to ensure that netlog messages never exceed
this limit otherwise a client cannot transmit logs.

Move the goroutine for periodically dumping netlog messages
from wgengine/netlog to net/connstats.
This allows net/connstats to manage when it dumps messages,
either based on time or by size.

Updates tailscale/corp#8427

Signed-off-by: Joe Tsai <[email protected]>
Joe Tsai 3 năm trước cách đây
mục cha
commit
d9df023e6f

+ 114 - 15
net/connstats/stats.go

@@ -7,9 +7,12 @@
 package connstats
 
 import (
+	"context"
 	"net/netip"
 	"sync"
+	"time"
 
+	"golang.org/x/sync/errgroup"
 	"tailscale.com/net/packet"
 	"tailscale.com/types/netlogtype"
 )
@@ -18,11 +21,64 @@ import (
 // All methods are safe for concurrent use.
 // The zero value is ready for use.
 type Statistics struct {
-	mu       sync.Mutex
+	maxConns int // immutable once set
+
+	mu sync.Mutex
+	connCnts
+
+	connCntsCh  chan connCnts
+	shutdownCtx context.Context
+	shutdown    context.CancelFunc
+	group       errgroup.Group
+}
+
+type connCnts struct {
+	start    time.Time
+	end      time.Time
 	virtual  map[netlogtype.Connection]netlogtype.Counts
 	physical map[netlogtype.Connection]netlogtype.Counts
 }
 
+// NewStatistics creates a data structure for tracking connection statistics
+// that periodically dumps the virtual and physical connection counts
+// depending on whether the maxPeriod or maxConns is exceeded.
+// The dump function is called from a single goroutine.
+// Shutdown must be called to cleanup resources.
+func NewStatistics(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)) *Statistics {
+	s := &Statistics{maxConns: maxConns}
+	s.connCntsCh = make(chan connCnts, 256)
+	s.shutdownCtx, s.shutdown = context.WithCancel(context.Background())
+	s.group.Go(func() error {
+		// TODO(joetsai): Using a ticker is problematic on mobile platforms
+		// where waking up a process every maxPeriod when there is no activity
+		// is a drain on battery life. Switch this instead to instead use
+		// a time.Timer that is triggered upon network activity.
+		ticker := new(time.Ticker)
+		if maxPeriod > 0 {
+			ticker := time.NewTicker(maxPeriod)
+			defer ticker.Stop()
+		}
+
+		for {
+			var cc connCnts
+			select {
+			case cc = <-s.connCntsCh:
+			case <-ticker.C:
+				cc = s.extract()
+			case <-s.shutdownCtx.Done():
+				cc = s.extract()
+			}
+			if len(cc.virtual)+len(cc.physical) > 0 && dump != nil {
+				dump(cc.start, cc.end, cc.virtual, cc.physical)
+			}
+			if s.shutdownCtx.Err() != nil {
+				return nil
+			}
+		}
+	})
+	return s
+}
+
 // UpdateTxVirtual updates the counters for a transmitted IP packet
 // The source and destination of the packet directly correspond with
 // the source and destination in netlogtype.Connection.
@@ -47,10 +103,10 @@ func (s *Statistics) updateVirtual(b []byte, receive bool) {
 
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	if s.virtual == nil {
-		s.virtual = make(map[netlogtype.Connection]netlogtype.Counts)
+	cnts, found := s.virtual[conn]
+	if !found && !s.preInsertConn() {
+		return
 	}
-	cnts := s.virtual[conn]
 	if receive {
 		cnts.RxPackets++
 		cnts.RxBytes += uint64(len(b))
@@ -82,10 +138,10 @@ func (s *Statistics) updatePhysical(src netip.Addr, dst netip.AddrPort, n int, r
 
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	if s.physical == nil {
-		s.physical = make(map[netlogtype.Connection]netlogtype.Counts)
+	cnts, found := s.physical[conn]
+	if !found && !s.preInsertConn() {
+		return
 	}
-	cnts := s.physical[conn]
 	if receive {
 		cnts.RxPackets++
 		cnts.RxBytes += uint64(n)
@@ -96,14 +152,57 @@ func (s *Statistics) updatePhysical(src netip.Addr, dst netip.AddrPort, n int, r
 	s.physical[conn] = cnts
 }
 
-// Extract extracts and resets the counters for all active connections.
-// It must be called periodically otherwise the memory used is unbounded.
-func (s *Statistics) Extract() (virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
+// preInsertConn updates the maps to handle insertion of a new connection.
+// It reports false if insertion is not allowed (i.e., after shutdown).
+func (s *Statistics) preInsertConn() bool {
+	// Check whether insertion of a new connection will exceed maxConns.
+	if len(s.virtual)+len(s.physical) == s.maxConns && s.maxConns > 0 {
+		// Extract the current statistics and send it to the serializer.
+		// Avoid blocking the network packet handling path.
+		select {
+		case s.connCntsCh <- s.extractLocked():
+		default:
+			// TODO(joetsai): Log that we are dropping an entire connCounts.
+		}
+	}
+
+	// Initialize the maps if nil.
+	if s.virtual == nil && s.physical == nil {
+		s.start = time.Now().UTC()
+		s.virtual = make(map[netlogtype.Connection]netlogtype.Counts)
+		s.physical = make(map[netlogtype.Connection]netlogtype.Counts)
+	}
+
+	return s.shutdownCtx.Err() == nil
+}
+
+func (s *Statistics) extract() connCnts {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	virtual = s.virtual
-	s.virtual = make(map[netlogtype.Connection]netlogtype.Counts)
-	physical = s.physical
-	s.physical = make(map[netlogtype.Connection]netlogtype.Counts)
-	return virtual, physical
+	return s.extractLocked()
+}
+
+func (s *Statistics) extractLocked() connCnts {
+	if len(s.virtual)+len(s.physical) == 0 {
+		return connCnts{}
+	}
+	s.end = time.Now().UTC()
+	cc := s.connCnts
+	s.connCnts = connCnts{}
+	return cc
+}
+
+// TestExtract synchronously extracts the current network statistics map
+// and resets the counters. This should only be used for testing purposes.
+func (s *Statistics) TestExtract() (virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
+	cc := s.extract()
+	return cc.virtual, cc.physical
+}
+
+// Shutdown performs a final flush of statistics.
+// Statistics for any subsequent calls to Update will be dropped.
+// It is safe to call Shutdown concurrently and repeatedly.
+func (s *Statistics) Shutdown(context.Context) error {
+	s.shutdown()
+	return s.group.Wait()
 }

+ 21 - 12
net/connstats/stats_test.go

@@ -5,6 +5,7 @@
 package connstats
 
 import (
+	"context"
 	"encoding/binary"
 	"fmt"
 	"math/rand"
@@ -47,7 +48,20 @@ func testPacketV4(proto ipproto.Proto, srcAddr, dstAddr [4]byte, srcPort, dstPor
 func TestConcurrent(t *testing.T) {
 	c := qt.New(t)
 
-	var stats Statistics
+	const maxPeriod = 10 * time.Millisecond
+	const maxConns = 10
+	virtualAggregate := make(map[netlogtype.Connection]netlogtype.Counts)
+	stats := NewStatistics(maxPeriod, maxConns, func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
+		c.Assert(start.IsZero(), qt.IsFalse)
+		c.Assert(end.IsZero(), qt.IsFalse)
+		c.Assert(end.Before(start), qt.IsFalse)
+		c.Assert(len(virtual) > 0 && len(virtual) <= maxConns, qt.IsTrue)
+		c.Assert(len(physical) == 0, qt.IsTrue)
+		for conn, cnts := range virtual {
+			virtualAggregate[conn] = virtualAggregate[conn].Add(cnts)
+		}
+	})
+	defer stats.Shutdown(context.Background())
 	var wants []map[netlogtype.Connection]netlogtype.Counts
 	gots := make([]map[netlogtype.Connection]netlogtype.Counts, runtime.NumCPU())
 	var group sync.WaitGroup
@@ -95,14 +109,9 @@ func TestConcurrent(t *testing.T) {
 			}
 		}(i)
 	}
-	for range gots {
-		virtual, _ := stats.Extract()
-		wants = append(wants, virtual)
-		time.Sleep(time.Millisecond)
-	}
 	group.Wait()
-	virtual, _ := stats.Extract()
-	wants = append(wants, virtual)
+	c.Assert(stats.Shutdown(context.Background()), qt.IsNil)
+	wants = append(wants, virtualAggregate)
 
 	got := make(map[netlogtype.Connection]netlogtype.Counts)
 	want := make(map[netlogtype.Connection]netlogtype.Counts)
@@ -126,7 +135,7 @@ func Benchmark(b *testing.B) {
 		b.ResetTimer()
 		b.ReportAllocs()
 		for i := 0; i < b.N; i++ {
-			var s Statistics
+			s := NewStatistics(0, 0, nil)
 			for j := 0; j < 1e3; j++ {
 				s.UpdateTxVirtual(p)
 			}
@@ -137,7 +146,7 @@ func Benchmark(b *testing.B) {
 		b.ResetTimer()
 		b.ReportAllocs()
 		for i := 0; i < b.N; i++ {
-			var s Statistics
+			s := NewStatistics(0, 0, nil)
 			for j := 0; j < 1e3; j++ {
 				binary.BigEndian.PutUint32(p[20:], uint32(j)) // unique port combination
 				s.UpdateTxVirtual(p)
@@ -149,7 +158,7 @@ func Benchmark(b *testing.B) {
 		b.ResetTimer()
 		b.ReportAllocs()
 		for i := 0; i < b.N; i++ {
-			var s Statistics
+			s := NewStatistics(0, 0, nil)
 			var group sync.WaitGroup
 			for j := 0; j < runtime.NumCPU(); j++ {
 				group.Add(1)
@@ -171,7 +180,7 @@ func Benchmark(b *testing.B) {
 		b.ResetTimer()
 		b.ReportAllocs()
 		for i := 0; i < b.N; i++ {
-			var s Statistics
+			s := NewStatistics(0, 0, nil)
 			var group sync.WaitGroup
 			for j := 0; j < runtime.NumCPU(); j++ {
 				group.Add(1)

+ 9 - 6
net/tstun/wrap_test.go

@@ -6,15 +6,17 @@ package tstun
 
 import (
 	"bytes"
+	"context"
 	"encoding/binary"
 	"fmt"
 	"net/netip"
-	"reflect"
 	"strconv"
 	"strings"
 	"testing"
 	"unsafe"
 
+	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
 	"github.com/tailscale/wireguard-go/tun/tuntest"
 	"go4.org/mem"
 	"go4.org/netipx"
@@ -337,7 +339,8 @@ func TestFilter(t *testing.T) {
 	}()
 
 	var buf [MaxPacketSize]byte
-	stats := new(connstats.Statistics)
+	stats := connstats.NewStatistics(0, 0, nil)
+	defer stats.Shutdown(context.Background())
 	tun.SetStatistics(stats)
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
@@ -346,7 +349,7 @@ func TestFilter(t *testing.T) {
 			var filtered bool
 			sizes := make([]int, 1)
 
-			tunStats, _ := stats.Extract()
+			tunStats, _ := stats.TestExtract()
 			if len(tunStats) > 0 {
 				t.Errorf("connstats.Statistics.Extract = %v, want {}", stats)
 			}
@@ -381,7 +384,7 @@ func TestFilter(t *testing.T) {
 				}
 			}
 
-			got, _ := stats.Extract()
+			got, _ := stats.TestExtract()
 			want := map[netlogtype.Connection]netlogtype.Counts{}
 			if !tt.drop {
 				var p packet.Parsed
@@ -395,8 +398,8 @@ func TestFilter(t *testing.T) {
 					want[conn] = netlogtype.Counts{TxPackets: 1, TxBytes: uint64(len(tt.data))}
 				}
 			}
-			if !reflect.DeepEqual(got, want) {
-				t.Errorf("tun.ExtractStatistics = %v, want %v", got, want)
+			if diff := cmp.Diff(got, want, cmpopts.EquateEmpty()); diff != "" {
+				t.Errorf("stats.TestExtract (-got +want):\n%s", diff)
 			}
 		})
 	}

+ 9 - 0
types/netlogtype/netlogtype.go

@@ -30,6 +30,15 @@ type Message struct {
 }
 
 const (
+	messageJSON      = `{"nodeId":"n0123456789abcdefCNTRL",` + maxJSONTimeRange + `,` + minJSONTraffic + `}`
+	maxJSONTimeRange = `"start":` + maxJSONRFC3339 + `,"end":` + maxJSONRFC3339
+	maxJSONRFC3339   = `"0001-01-01T00:00:00.000000000Z"`
+	minJSONTraffic   = `"virtualTraffic":{},"subnetTraffic":{},"exitTraffic":{},"physicalTraffic":{}`
+
+	// MaxMessageJSONSize is the overhead size of Message when it is
+	// serialized as JSON assuming that each traffic map is populated.
+	MaxMessageJSONSize = len(messageJSON)
+
 	maxJSONConnCounts = `{` + maxJSONConn + `,` + maxJSONCounts + `}`
 	maxJSONConn       = `"proto":` + maxJSONProto + `,"src":` + maxJSONAddrPort + `,"dst":` + maxJSONAddrPort
 	maxJSONProto      = `255`

+ 8 - 4
wgengine/magicsock/magicsock_test.go

@@ -135,7 +135,7 @@ func runDERPAndStun(t *testing.T, logf logger.Logf, l nettype.PacketListener, st
 type magicStack struct {
 	privateKey key.NodePrivate
 	epCh       chan []tailcfg.Endpoint // endpoint updates produced by this peer
-	stats      connstats.Statistics    // per-connection statistics
+	stats      *connstats.Statistics   // per-connection statistics
 	conn       *Conn                   // the magicsock itself
 	tun        *tuntest.ChannelTUN     // TUN device to send/receive packets
 	tsTun      *tstun.Wrapper          // wrapped tun that implements filtering and wgengine hooks
@@ -1053,11 +1053,15 @@ func testTwoDevicePing(t *testing.T, d *devices) {
 		}
 	}
 
-	m1.conn.SetStatistics(&m1.stats)
-	m2.conn.SetStatistics(&m2.stats)
+	m1.stats = connstats.NewStatistics(0, 0, nil)
+	defer m1.stats.Shutdown(context.Background())
+	m1.conn.SetStatistics(m1.stats)
+	m2.stats = connstats.NewStatistics(0, 0, nil)
+	defer m2.stats.Shutdown(context.Background())
+	m2.conn.SetStatistics(m2.stats)
 
 	checkStats := func(t *testing.T, m *magicStack, wantConns []netlogtype.Connection) {
-		_, stats := m.stats.Extract()
+		_, stats := m.stats.TestExtract()
 		for _, conn := range wantConns {
 			if _, ok := stats[conn]; ok {
 				return

+ 45 - 69
wgengine/netlog/logger.go

@@ -17,7 +17,6 @@ import (
 	"sync"
 	"time"
 
-	"golang.org/x/sync/errgroup"
 	"tailscale.com/logpolicy"
 	"tailscale.com/logtail"
 	"tailscale.com/net/connstats"
@@ -25,6 +24,7 @@ import (
 	"tailscale.com/smallzstd"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/netlogtype"
+	"tailscale.com/util/multierr"
 	"tailscale.com/wgengine/router"
 )
 
@@ -32,8 +32,7 @@ import (
 const pollPeriod = 5 * time.Second
 
 // Device is an abstraction over a tunnel device or a magic socket.
-// *tstun.Wrapper implements this interface.
-// *magicsock.Conn implements this interface.
+// Both *tstun.Wrapper and *magicsock.Conn implement this interface.
 type Device interface {
 	SetStatistics(*connstats.Statistics)
 }
@@ -47,15 +46,15 @@ func (noopDevice) SetStatistics(*connstats.Statistics) {}
 // Exit node traffic is not logged for privacy reasons.
 // The zero value is ready for use.
 type Logger struct {
-	mu sync.Mutex
+	mu sync.Mutex // protects all fields below
 
 	logger *logtail.Logger
+	stats  *connstats.Statistics
+	tun    Device
+	sock   Device
 
 	addrs    map[netip.Addr]bool
 	prefixes map[netip.Prefix]bool
-
-	group  errgroup.Group
-	cancel context.CancelFunc
 }
 
 // Running reports whether the logger is running.
@@ -97,18 +96,13 @@ func (nl *Logger) Startup(nodeID tailcfg.StableNodeID, nodeLogID, domainLogID lo
 	if nl.logger != nil {
 		return fmt.Errorf("network logger already running for %v", nl.logger.PrivateID().Public())
 	}
-	if tun == nil {
-		tun = noopDevice{}
-	}
-	if sock == nil {
-		sock = noopDevice{}
-	}
 
+	// Startup a log stream to Tailscale's logging service.
 	httpc := &http.Client{Transport: logpolicy.NewLogtailTransport(logtail.DefaultHost)}
 	if testClient != nil {
 		httpc = testClient
 	}
-	logger := logtail.NewLogger(logtail.Config{
+	nl.logger = logtail.NewLogger(logtail.Config{
 		Collection:    "tailtraffic.log.tailscale.io",
 		PrivateID:     nodeLogID,
 		CopyPrivateID: domainLogID,
@@ -127,47 +121,34 @@ func (nl *Logger) Startup(nodeID tailcfg.StableNodeID, nodeLogID, domainLogID lo
 		IncludeProcID:       true,
 		IncludeProcSequence: true,
 	}, log.Printf)
-	nl.logger = logger
-
-	stats := new(connstats.Statistics)
-	ctx, cancel := context.WithCancel(context.Background())
-	nl.cancel = cancel
-	nl.group.Go(func() error {
-		tun.SetStatistics(stats)
-		defer tun.SetStatistics(nil)
 
-		sock.SetStatistics(stats)
-		defer sock.SetStatistics(nil)
+	// Startup a data structure to track per-connection statistics.
+	// There is a maximum size for individual log messages that logtail
+	// can upload to the Tailscale log service, so stay below this limit.
+	const maxLogSize = 256 << 10
+	const maxConns = (maxLogSize - netlogtype.MaxMessageJSONSize) / netlogtype.MaxConnectionCountsJSONSize
+	nl.stats = connstats.NewStatistics(pollPeriod, maxConns, func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
+		nl.mu.Lock()
+		addrs := nl.addrs
+		prefixes := nl.prefixes
+		nl.mu.Unlock()
+		recordStatistics(nl.logger, nodeID, start, end, virtual, physical, addrs, prefixes)
+	})
 
-		start := time.Now()
-		ticker := time.NewTicker(pollPeriod)
-		for {
-			var end time.Time
-			select {
-			case <-ctx.Done():
-				end = time.Now()
-			case end = <-ticker.C:
-			}
+	// Register the connection tracker into the TUN device.
+	if tun == nil {
+		tun = noopDevice{}
+	}
+	nl.tun = tun
+	nl.tun.SetStatistics(nl.stats)
 
-			// NOTE: connstats and sockStats will always be slightly out-of-sync.
-			// It is impossible to have an atomic snapshot of statistics
-			// at both layers without a global mutex that spans all layers.
-			connstats, sockStats := stats.Extract()
-			if len(connstats)+len(sockStats) > 0 {
-				nl.mu.Lock()
-				addrs := nl.addrs
-				prefixes := nl.prefixes
-				nl.mu.Unlock()
-				recordStatistics(logger, nodeID, start, end, connstats, sockStats, addrs, prefixes)
-			}
+	// Register the connection tracker into magicsock.
+	if sock == nil {
+		sock = noopDevice{}
+	}
+	nl.sock = sock
+	nl.sock.SetStatistics(nl.stats)
 
-			if ctx.Err() != nil {
-				break
-			}
-			start = end.Add(time.Nanosecond)
-		}
-		return nil
-	})
 	return nil
 }
 
@@ -222,21 +203,8 @@ func recordStatistics(logger *logtail.Logger, nodeID tailcfg.StableNodeID, start
 	}
 
 	if len(m.VirtualTraffic)+len(m.SubnetTraffic)+len(m.ExitTraffic)+len(m.PhysicalTraffic) > 0 {
-		// TODO(joetsai): Place a hard limit on the size of a network log message.
-		// The log server rejects any payloads above a certain size, so logging
-		// a message that large would cause logtail to be stuck forever trying
-		// and failing to upload the same excessively large payload.
-		//
-		// We should figure out the behavior for handling this. We could split
-		// the message apart so that there are multiple chunks with the same window,
-		// We could also consider reducing the granularity of the data
-		// by dropping port numbers.
-		const maxSize = 256 << 10
 		if b, err := json.Marshal(m); err != nil {
 			logger.Logf("json.Marshal error: %v", err)
-		} else if len(b) > maxSize {
-			logger.Logf("JSON body too large: %dB (virtual:%d subnet:%d exit:%d physical:%d)",
-				len(b), len(m.VirtualTraffic), len(m.SubnetTraffic), len(m.ExitTraffic), len(m.PhysicalTraffic))
 		} else {
 			logger.Logf("%s", b)
 		}
@@ -285,15 +253,23 @@ func (nl *Logger) Shutdown(ctx context.Context) error {
 	if nl.logger == nil {
 		return nil
 	}
-	nl.cancel()
+
+	// Shutdown in reverse order of Startup.
+	// Do not hold lock while shutting down since this may flush one last time.
 	nl.mu.Unlock()
-	nl.group.Wait() // do not hold lock while waiting
+	nl.sock.SetStatistics(nil)
+	nl.tun.SetStatistics(nil)
+	err1 := nl.stats.Shutdown(ctx)
+	err2 := nl.logger.Shutdown(ctx)
 	nl.mu.Lock()
-	err := nl.logger.Shutdown(ctx)
 
+	// Purge state.
 	nl.logger = nil
+	nl.stats = nil
+	nl.tun = nil
+	nl.sock = nil
 	nl.addrs = nil
 	nl.prefixes = nil
-	nl.cancel = nil
-	return err
+
+	return multierr.New(err1, err2)
 }