Browse Source

cmd/tailscaled,net/tstun: fix data race on start-up in TUN mode

Fixes #7894

Change-Id: Ice3f8019405714dd69d02bc07694f3872bb598b8

Co-authored-by: Brad Fitzpatrick <[email protected]>
Signed-off-by: Maisem Ali <[email protected]>
Maisem Ali 2 years ago
parent
commit
5297bd2cff

+ 4 - 0
cmd/tailscaled/tailscaled.go

@@ -540,6 +540,10 @@ func getLocalBackend(ctx context.Context, logf logger.Logf, logID logid.PublicID
 	}
 	}
 	sys.Set(store)
 	sys.Set(store)
 
 
+	if w, ok := sys.Tun.GetOK(); ok {
+		w.Start()
+	}
+
 	lb, err := ipnlocal.NewLocalBackend(logf, logID, sys, opts.LoginFlags)
 	lb, err := ipnlocal.NewLocalBackend(logf, logID, sys, opts.LoginFlags)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("ipnlocal.NewLocalBackend: %w", err)
 		return nil, fmt.Errorf("ipnlocal.NewLocalBackend: %w", err)

+ 1 - 0
net/tstun/fake.go

@@ -55,3 +55,4 @@ func (t *fakeTUN) MTU() (int, error)        { return 1500, nil }
 func (t *fakeTUN) Name() (string, error)    { return FakeTUNName, nil }
 func (t *fakeTUN) Name() (string, error)    { return FakeTUNName, nil }
 func (t *fakeTUN) Events() <-chan tun.Event { return t.evchan }
 func (t *fakeTUN) Events() <-chan tun.Event { return t.evchan }
 func (t *fakeTUN) BatchSize() int           { return 1 }
 func (t *fakeTUN) BatchSize() int           { return 1 }
+func (t *fakeTUN) IsFakeTun() bool          { return true }

+ 23 - 0
net/tstun/wrap.go

@@ -78,6 +78,9 @@ var parsedPacketPool = sync.Pool{New: func() any { return new(packet.Parsed) }}
 type FilterFunc func(*packet.Parsed, *Wrapper) filter.Response
 type FilterFunc func(*packet.Parsed, *Wrapper) filter.Response
 
 
 // Wrapper augments a tun.Device with packet filtering and injection.
 // Wrapper augments a tun.Device with packet filtering and injection.
+//
+// A Wrapper starts in a "corked" mode where Read calls are blocked
+// until the Wrapper's Start method is called or the Wrapper is closed.
 type Wrapper struct {
 type Wrapper struct {
 	logf        logger.Logf
 	logf        logger.Logf
 	limitedLogf logger.Logf // aggressively rate-limited logf used for potentially high volume errors
 	limitedLogf logger.Logf // aggressively rate-limited logf used for potentially high volume errors
@@ -85,6 +88,9 @@ type Wrapper struct {
 	tdev  tun.Device
 	tdev  tun.Device
 	isTAP bool // whether tdev is a TAP device
 	isTAP bool // whether tdev is a TAP device
 
 
+	started atomic.Bool   // whether Start (or Close) has been called
+	startCh chan struct{} // closed in Start, or in Close if Start was never called
+
 	closeOnce sync.Once
 	closeOnce sync.Once
 
 
 	// lastActivityAtomic is read/written atomically.
 	// lastActivityAtomic is read/written atomically.
@@ -219,6 +225,16 @@ type setWrapperer interface {
 	setWrapper(*Wrapper)
 	setWrapper(*Wrapper)
 }
 }
 
 
+// Start unblocks any Wrapper.Read calls that have already started and
+// makes the Wrapper functional. Call it once the Tailscale subsystems
+// are wired up to each other; repeated calls, or calls after Close,
+// are no-ops rather than a double-close panic on startCh.
+func (w *Wrapper) Start() {
+	if w.started.CompareAndSwap(false, true) {
+		close(w.startCh)
+	}
+}
+
 func WrapTAP(logf logger.Logf, tdev tun.Device) *Wrapper {
 func WrapTAP(logf logger.Logf, tdev tun.Device) *Wrapper {
 	return wrap(logf, tdev, true)
 	return wrap(logf, tdev, true)
 }
 }
@@ -244,6 +260,7 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool) *Wrapper {
 		eventsOther:    make(chan tun.Event),
 		eventsOther:    make(chan tun.Event),
 		// TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets.
 		// TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets.
 		filterFlags: filter.LogAccepts | filter.LogDrops,
 		filterFlags: filter.LogAccepts | filter.LogDrops,
+		startCh:     make(chan struct{}),
 	}
 	}
 
 
 	w.vectorBuffer = make([][]byte, tdev.BatchSize())
 	w.vectorBuffer = make([][]byte, tdev.BatchSize())
@@ -309,6 +326,9 @@ func (t *Wrapper) isSelfDisco(p *packet.Parsed) bool {
 func (t *Wrapper) Close() error {
 func (t *Wrapper) Close() error {
 	var err error
 	var err error
 	t.closeOnce.Do(func() {
 	t.closeOnce.Do(func() {
+		if t.started.CompareAndSwap(false, true) {
+			close(t.startCh)
+		}
 		close(t.closed)
 		close(t.closed)
 		t.bufferConsumedMu.Lock()
 		t.bufferConsumedMu.Lock()
 		t.bufferConsumedClosed = true
 		t.bufferConsumedClosed = true
@@ -836,6 +856,9 @@ func (t *Wrapper) IdleDuration() time.Duration {
 }
 }
 
 
 func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
 func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
+	if !t.started.Load() {
+		<-t.startCh
+	}
 	// packet from OS read and sent to WG
 	// packet from OS read and sent to WG
 	res, ok := <-t.vectorOutbound
 	res, ok := <-t.vectorOutbound
 	if !ok {
 	if !ok {

+ 1 - 0
net/tstun/wrap_test.go

@@ -178,6 +178,7 @@ func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *Wrapper
 	} else {
 	} else {
 		tun.disableFilter = true
 		tun.disableFilter = true
 	}
 	}
+	tun.Start()
 	return chtun, tun
 	return chtun, tun
 }
 }
 
 

+ 11 - 2
tsd/tsd.go

@@ -47,6 +47,10 @@ type System struct {
 	StateStore     SubSystem[ipn.StateStore]
 	StateStore     SubSystem[ipn.StateStore]
 	Netstack       SubSystem[NetstackImpl] // actually a *netstack.Impl
 	Netstack       SubSystem[NetstackImpl] // actually a *netstack.Impl
 
 
+	// onlyNetstack reports whether the Tun value wraps a fake TUN
+	// device, meaning we're using netstack for everything.
+	onlyNetstack bool
+
 	controlKnobs controlknobs.Knobs
 	controlKnobs controlknobs.Knobs
 	proxyMap     proxymap.Mapper
 	proxyMap     proxymap.Mapper
 }
 }
@@ -74,6 +78,12 @@ func (s *System) Set(v any) {
 	case router.Router:
 	case router.Router:
 		s.Router.Set(v)
 		s.Router.Set(v)
 	case *tstun.Wrapper:
 	case *tstun.Wrapper:
+		type ft interface {
+			IsFakeTun() bool
+		}
+		if _, ok := v.Unwrap().(ft); ok {
+			s.onlyNetstack = true
+		}
 		s.Tun.Set(v)
 		s.Tun.Set(v)
 	case *magicsock.Conn:
 	case *magicsock.Conn:
 		s.MagicSock.Set(v)
 		s.MagicSock.Set(v)
@@ -97,8 +107,7 @@ func (s *System) IsNetstackRouter() bool {
 
 
 // IsNetstack reports whether Tailscale is running as a netstack-based TUN-free engine.
 // IsNetstack reports whether Tailscale is running as a netstack-based TUN-free engine.
 func (s *System) IsNetstack() bool {
 func (s *System) IsNetstack() bool {
-	name, _ := s.Tun.Get().Name()
-	return name == tstun.FakeTUNName
+	return s.onlyNetstack
 }
 }
 
 
 // ControlKnobs returns the control knobs for this node.
 // ControlKnobs returns the control knobs for this node.

+ 1 - 0
tsnet/tsnet.go

@@ -530,6 +530,7 @@ func (s *Server) start() (reterr error) {
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("netstack.Create: %w", err)
 		return fmt.Errorf("netstack.Create: %w", err)
 	}
 	}
+	sys.Tun.Get().Start()
 	sys.Set(ns)
 	sys.Set(ns)
 	ns.ProcessLocalIPs = true
 	ns.ProcessLocalIPs = true
 	ns.GetTCPHandlerForFlow = s.getTCPHandlerForFlow
 	ns.GetTCPHandlerForFlow = s.getTCPHandlerForFlow

+ 1 - 0
wgengine/magicsock/magicsock_test.go

@@ -184,6 +184,7 @@ func newMagicStackWithKey(t testing.TB, logf logger.Logf, l nettype.PacketListen
 	tun := tuntest.NewChannelTUN()
 	tun := tuntest.NewChannelTUN()
 	tsTun := tstun.Wrap(logf, tun.TUN())
 	tsTun := tstun.Wrap(logf, tun.TUN())
 	tsTun.SetFilter(filter.NewAllowAllForTest(logf))
 	tsTun.SetFilter(filter.NewAllowAllForTest(logf))
+	tsTun.Start()
 
 
 	wgLogger := wglog.NewLogger(logf)
 	wgLogger := wglog.NewLogger(logf)
 	dev := wgcfg.NewDevice(tsTun, conn.Bind(), wgLogger.DeviceLogger)
 	dev := wgcfg.NewDevice(tsTun, conn.Bind(), wgLogger.DeviceLogger)