Browse Source

sockstats: add validation for TCP socket stats

We can use the TCP_CONNECTION_INFO getsockopt() on Darwin to get
OS-collected tx/rx bytes for TCP sockets. Since this API is not available
for UDP sockets (or on Linux/Android), we can't rely on it for actual
stats gathering.

However, we can use it to validate the stats that we collect ourselves
using read/write hooks, so that we can be more confident in them. We
do need additional hooks from the Go standard library (added in
tailscale/go#59) to be able to collect them.

Updates tailscale/corp#9230
Updates #3363

Signed-off-by: Mihai Parparita <[email protected]>
Mihai Parparita 3 years ago
parent
commit
f4f8ed98d9

+ 1 - 1
go.toolchain.rev

@@ -1 +1 @@
-fb11c0df588717a3ee13b09dacae1e7093279d67
+db4dc9046c93dde2c0e534ca7d529bd690ad09c9

+ 17 - 4
ipn/ipnlocal/peerapi.go

@@ -876,6 +876,7 @@ func (h *peerAPIHandler) handleServeSockStats(w http.ResponseWriter, r *http.Req
 		fmt.Fprintf(w, "<th>Tx (%s)</th>", html.EscapeString(iface))
 		fmt.Fprintf(w, "<th>Rx (%s)</th>", html.EscapeString(iface))
 	}
+	fmt.Fprintln(w, "<th>Validation</th>")
 	fmt.Fprintln(w, "</thead>")
 
 	fmt.Fprintln(w, "<tbody>")
@@ -887,10 +888,10 @@ func (h *peerAPIHandler) handleServeSockStats(w http.ResponseWriter, r *http.Req
 		return a.String() < b.String()
 	})
 
-	txTotal := int64(0)
-	rxTotal := int64(0)
-	txTotalByInterface := map[string]int64{}
-	rxTotalByInterface := map[string]int64{}
+	txTotal := uint64(0)
+	rxTotal := uint64(0)
+	txTotalByInterface := map[string]uint64{}
+	rxTotalByInterface := map[string]uint64{}
 
 	for _, label := range labels {
 		stat := stats.Stats[label]
@@ -908,6 +909,17 @@ func (h *peerAPIHandler) handleServeSockStats(w http.ResponseWriter, r *http.Req
 			txTotalByInterface[iface] += stat.TxBytesByInterface[iface]
 			rxTotalByInterface[iface] += stat.RxBytesByInterface[iface]
 		}
+
+		if stat.ValidationRxBytes > 0 || stat.ValidationTxBytes > 0 {
+			fmt.Fprintf(w, "<td>Tx=%d (%+d) Rx=%d (%+d)</td>",
+				stat.ValidationTxBytes,
+				int64(stat.ValidationTxBytes)-int64(stat.TxBytes),
+				stat.ValidationRxBytes,
+				int64(stat.ValidationRxBytes)-int64(stat.RxBytes))
+		} else {
+			fmt.Fprintln(w, "<td></td>")
+		}
+
 		fmt.Fprintln(w, "</tr>")
 	}
 	fmt.Fprintln(w, "</tbody>")
@@ -920,6 +932,7 @@ func (h *peerAPIHandler) handleServeSockStats(w http.ResponseWriter, r *http.Req
 		fmt.Fprintf(w, "<th>%d</th>", txTotalByInterface[iface])
 		fmt.Fprintf(w, "<th>%d</th>", rxTotalByInterface[iface])
 	}
+	fmt.Fprintln(w, "<th></th>")
 	fmt.Fprintln(w, "</tfoot>")
 
 	fmt.Fprintln(w, "</table>")

+ 8 - 4
net/sockstats/sockstats.go

@@ -42,10 +42,14 @@ const (
 )
 
 type SockStat struct {
-	TxBytes            int64
-	RxBytes            int64
-	TxBytesByInterface map[string]int64
-	RxBytesByInterface map[string]int64
+	TxBytes            uint64
+	RxBytes            uint64
+	TxBytesByInterface map[string]uint64
+	RxBytesByInterface map[string]uint64
+
+	// NOCOMMIT
+	ValidationTxBytes uint64
+	ValidationRxBytes uint64
 }
 
 func WithSockStats(ctx context.Context, label Label) context.Context {

+ 50 - 10
net/sockstats/sockstats_tsgo.go

@@ -11,6 +11,7 @@ import (
 	"net"
 	"sync"
 	"sync/atomic"
+	"syscall"
 
 	"tailscale.com/net/interfaces"
 )
@@ -18,6 +19,12 @@ import (
 type sockStatCounters struct {
 	txBytes, rxBytes                       atomic.Uint64
 	rxBytesByInterface, txBytesByInterface map[int]*atomic.Uint64
+
+	// Validate counts for TCP sockets by using the TCP_CONNECTION_INFO
+	// getsockopt. We get current counts, as well as save final values when
+	// sockets are closed.
+	validationConn                       atomic.Pointer[syscall.RawConn]
+	validationTxBytes, validationRxBytes atomic.Uint64
 }
 
 var sockStats = struct {
@@ -53,6 +60,23 @@ func withSockStats(ctx context.Context, label Label) context.Context {
 		sockStats.countersByLabel[label] = counters
 	}
 
+	didCreateTCPConn := func(c syscall.RawConn) {
+		counters.validationConn.Store(&c)
+	}
+
+	willCloseTCPConn := func(c syscall.RawConn) {
+		tx, rx := tcpConnStats(c)
+		counters.validationTxBytes.Add(tx)
+		counters.validationRxBytes.Add(rx)
+	}
+
+	// Don't bother adding these hooks if we can't get stats that they end up
+	// collecting.
+	if tcpConnStats == nil {
+		willCloseTCPConn = nil
+		didCreateTCPConn = nil
+	}
+
 	didRead := func(n int) {
 		counters.rxBytes.Add(uint64(n))
 		if currentInterface := int(sockStats.currentInterface.Load()); currentInterface != 0 {
@@ -74,12 +98,19 @@ func withSockStats(ctx context.Context, label Label) context.Context {
 	}
 
 	return net.WithSockTrace(ctx, &net.SockTrace{
-		DidRead:       didRead,
-		DidWrite:      didWrite,
-		WillOverwrite: willOverwrite,
+		DidCreateTCPConn: didCreateTCPConn,
+		DidRead:          didRead,
+		DidWrite:         didWrite,
+		WillOverwrite:    willOverwrite,
+		WillCloseTCPConn: willCloseTCPConn,
 	})
 }
 
+// tcpConnStats returns the number of bytes sent and received on the
+// given TCP socket. Its implementation is platform-dependent (or it may not
+// be available at all).
+var tcpConnStats func(c syscall.RawConn) (tx, rx uint64)
+
 func get() *SockStats {
 	sockStats.mu.Lock()
 	defer sockStats.mu.Unlock()
@@ -93,20 +124,29 @@ func get() *SockStats {
 	}
 
 	for label, counters := range sockStats.countersByLabel {
-		r.Stats[label] = SockStat{
-			TxBytes:            int64(counters.txBytes.Load()),
-			RxBytes:            int64(counters.rxBytes.Load()),
-			TxBytesByInterface: make(map[string]int64),
-			RxBytesByInterface: make(map[string]int64),
+		s := SockStat{
+			TxBytes:            counters.txBytes.Load(),
+			RxBytes:            counters.rxBytes.Load(),
+			TxBytesByInterface: make(map[string]uint64),
+			RxBytesByInterface: make(map[string]uint64),
+
+			ValidationTxBytes: counters.validationTxBytes.Load(),
+			ValidationRxBytes: counters.validationRxBytes.Load(),
+		}
+		if c := counters.validationConn.Load(); c != nil && tcpConnStats != nil {
+			tx, rx := tcpConnStats(*c)
+			s.ValidationTxBytes += tx
+			s.ValidationRxBytes += rx
 		}
 		for iface, a := range counters.rxBytesByInterface {
 			ifName := sockStats.knownInterfaces[iface]
-			r.Stats[label].RxBytesByInterface[ifName] = int64(a.Load())
+			s.RxBytesByInterface[ifName] = a.Load()
 		}
 		for iface, a := range counters.txBytesByInterface {
 			ifName := sockStats.knownInterfaces[iface]
-			r.Stats[label].TxBytesByInterface[ifName] = int64(a.Load())
+			s.TxBytesByInterface[ifName] = a.Load()
 		}
+		r.Stats[label] = s
 	}
 
 	return r

+ 30 - 0
net/sockstats/sockstats_tsgo_darwin.go

@@ -0,0 +1,30 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build tailscale_go && (darwin || ios)
+
+package sockstats
+
+import (
+	"syscall"
+
+	"golang.org/x/sys/unix"
+)
+
+func init() {
+	tcpConnStats = darwinTcpConnStats
+}
+
+func darwinTcpConnStats(c syscall.RawConn) (tx, rx uint64) {
+	c.Control(func(fd uintptr) {
+		if rawInfo, err := unix.GetsockoptTCPConnectionInfo(
+			int(fd),
+			unix.IPPROTO_TCP,
+			unix.TCP_CONNECTION_INFO,
+		); err == nil {
+			tx = uint64(rawInfo.Txbytes)
+			rx = uint64(rawInfo.Rxbytes)
+		}
+	})
+	return
+}