Parcourir la source

util/eventbus/eventbustest: add support for synctest instead of timers (#17522)

Before synctest, timers was needed to allow the events to flow into the
test bus. There is still a timer, but this one is not derived from the
test deadline and it is mostly arbitrary as synctest will render it
practically non-existent.

With this approach, tests that do not need to test for the absence of
events do not rely on synctest.

Updates #15160

Signed-off-by: Claus Lensbøl <[email protected]>
Claus Lensbøl il y a 4 mois
Parent
commit
005e264b54

+ 66 - 51
health/health_test.go

@@ -5,12 +5,14 @@ package health
 
 import (
 	"errors"
+	"flag"
 	"fmt"
 	"maps"
 	"reflect"
 	"slices"
 	"strconv"
 	"testing"
+	"testing/synctest"
 	"time"
 
 	"github.com/google/go-cmp/cmp"
@@ -26,6 +28,8 @@ import (
 	"tailscale.com/version"
 )
 
+var doDebug = flag.Bool("debug", false, "Enable debug logging")
+
 func wantChange(c Change) func(c Change) (bool, error) {
 	return func(cEv Change) (bool, error) {
 		if cEv.ControlHealthChanged != c.ControlHealthChanged {
@@ -724,72 +728,83 @@ func TestControlHealthNotifies(t *testing.T) {
 	}
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			bus := eventbustest.NewBus(t)
-			tw := eventbustest.NewWatcher(t, bus)
-			tw.TimeOut = time.Second
-
-			ht := NewTracker(bus)
-			ht.SetIPNState("NeedsLogin", true)
-			ht.GotStreamedMapResponse()
-
-			// Expect events at starup, before doing anything else
-			if err := eventbustest.ExpectExactly(tw,
-				eventbustest.Type[Change](), // warming-up
-				eventbustest.Type[Change](), // is-using-unstable-version
-				eventbustest.Type[Change](), // not-in-map-poll
-			); err != nil {
-				t.Errorf("startup error: %v", err)
-			}
+			synctest.Test(t, func(t *testing.T) {
+				bus := eventbustest.NewBus(t)
+				if *doDebug {
+					eventbustest.LogAllEvents(t, bus)
+				}
+				tw := eventbustest.NewWatcher(t, bus)
+
+				ht := NewTracker(bus)
+				ht.SetIPNState("NeedsLogin", true)
+				ht.GotStreamedMapResponse()
+
+				// Expect events at starup, before doing anything else
+				synctest.Wait()
+				if err := eventbustest.ExpectExactly(tw,
+					eventbustest.Type[Change](), // warming-up
+					eventbustest.Type[Change](), // is-using-unstable-version
+					eventbustest.Type[Change](), // not-in-map-poll
+				); err != nil {
+					t.Errorf("startup error: %v", err)
+				}
 
-			// Only set initial state if we need to
-			if len(test.initialState) != 0 {
-				ht.SetControlHealth(test.initialState)
-				if err := eventbustest.ExpectExactly(tw, eventbustest.Type[Change]()); err != nil {
-					t.Errorf("initial state error: %v", err)
+				// Only set initial state if we need to
+				if len(test.initialState) != 0 {
+					ht.SetControlHealth(test.initialState)
+					synctest.Wait()
+					if err := eventbustest.ExpectExactly(tw, eventbustest.Type[Change]()); err != nil {
+						t.Errorf("initial state error: %v", err)
+					}
 				}
-			}
 
-			ht.SetControlHealth(test.newState)
+				ht.SetControlHealth(test.newState)
+				// Close the bus early to avoid timers triggering more events.
+				bus.Close()
 
-			if err := eventbustest.ExpectExactly(tw, test.wantEvents...); err != nil {
-				t.Errorf("event error: %v", err)
-			}
+				synctest.Wait()
+				if err := eventbustest.ExpectExactly(tw, test.wantEvents...); err != nil {
+					t.Errorf("event error: %v", err)
+				}
+			})
 		})
 	}
 }
 
 func TestControlHealthIgnoredOutsideMapPoll(t *testing.T) {
-	bus := eventbustest.NewBus(t)
-	tw := eventbustest.NewWatcher(t, bus)
-	tw.TimeOut = 100 * time.Millisecond
-	ht := NewTracker(bus)
-	ht.SetIPNState("NeedsLogin", true)
+	synctest.Test(t, func(t *testing.T) {
+		bus := eventbustest.NewBus(t)
+		tw := eventbustest.NewWatcher(t, bus)
+		ht := NewTracker(bus)
+		ht.SetIPNState("NeedsLogin", true)
 
-	ht.SetControlHealth(map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{
-		"control-health": {},
-	})
+		ht.SetControlHealth(map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{
+			"control-health": {},
+		})
 
-	state := ht.CurrentState()
-	_, ok := state.Warnings["control-health"]
+		state := ht.CurrentState()
+		_, ok := state.Warnings["control-health"]
 
-	if ok {
-		t.Error("got a warning with code 'control-health', want none")
-	}
+		if ok {
+			t.Error("got a warning with code 'control-health', want none")
+		}
 
-	// An event is emitted when SetIPNState is run above,
-	// so only fail on the second event.
-	eventCounter := 0
-	expectOne := func(c *Change) error {
-		eventCounter++
-		if eventCounter == 1 {
-			return nil
+		// An event is emitted when SetIPNState is run above,
+		// so only fail on the second event.
+		eventCounter := 0
+		expectOne := func(c *Change) error {
+			eventCounter++
+			if eventCounter == 1 {
+				return nil
+			}
+			return errors.New("saw more than 1 event")
 		}
-		return errors.New("saw more than 1 event")
-	}
 
-	if err := eventbustest.Expect(tw, expectOne); err == nil {
-		t.Error("event got emitted, want it to not be called")
-	}
+		synctest.Wait()
+		if err := eventbustest.Expect(tw, expectOne); err == nil {
+			t.Error("event got emitted, want it to not be called")
+		}
+	})
 }
 
 // TestCurrentStateETagControlHealth tests that the ETag on an [UnhealthyState]

+ 1 - 1
net/netmon/netmon_test.go

@@ -144,7 +144,7 @@ func TestMonitorMode(t *testing.T) {
 		<-done
 		t.Logf("%v callbacks", n)
 	case "eventbus":
-		tw.TimeOut = *monitorDuration
+		time.AfterFunc(*monitorDuration, bus.Close)
 		n := 0
 		mon.Start()
 		eventbustest.Expect(tw, func(event *ChangeDelta) (bool, error) {

+ 14 - 0
util/eventbus/eventbustest/doc.go

@@ -39,6 +39,20 @@
 // checks that the stream contains exactly the given events in the given order,
 // and no others.
 //
+// To test for the absence of events, use [ExpectExactly] without any
+// expected events, along side [testing/synctest] to avoid waiting for timers
+// to ensure that no events are produced. This will look like:
+//
+//	synctest.Test(t, func(t *testing.T) {
+//		bus := eventbustest.NewBus(t)
+//		tw := eventbustest.NewWatcher(t, bus)
+//		somethingThatShouldNotEmitsSomeEvent()
+//		synctest.Wait()
+//		if err := eventbustest.ExpectExactly(tw); err != nil {
+//			t.Errorf("Expected no events or errors, got %v", err)
+//		}
+//	})
+//
 // See the [usage examples].
 //
 // [usage examples]: https://github.com/tailscale/tailscale/blob/main/util/eventbus/eventbustest/examples_test.go

+ 19 - 16
util/eventbus/eventbustest/eventbustest.go

@@ -27,13 +27,9 @@ func NewBus(t testing.TB) *eventbus.Bus {
 // [Expect] and [ExpectExactly], to verify that the desired events were captured.
 func NewWatcher(t *testing.T, bus *eventbus.Bus) *Watcher {
 	tw := &Watcher{
-		mon:     bus.Debugger().WatchBus(),
-		TimeOut: 5 * time.Second,
-		chDone:  make(chan bool, 1),
-		events:  make(chan any, 100),
-	}
-	if deadline, ok := t.Deadline(); ok {
-		tw.TimeOut = deadline.Sub(time.Now())
+		mon:    bus.Debugger().WatchBus(),
+		chDone: make(chan bool, 1),
+		events: make(chan any, 100),
 	}
 	t.Cleanup(tw.done)
 	go tw.watch()
@@ -41,16 +37,15 @@ func NewWatcher(t *testing.T, bus *eventbus.Bus) *Watcher {
 }
 
 // Watcher monitors and holds events for test expectations.
+// The Watcher works with [synctest], and some scenarios does require the use of
+// [synctest]. This is amongst others true if you are testing for the absence of
+// events.
+//
+// For usage examples, see the documentation in the top of the package.
 type Watcher struct {
 	mon    *eventbus.Subscriber[eventbus.RoutedEvent]
 	events chan any
 	chDone chan bool
-	// TimeOut defines when the Expect* functions should stop looking for events
-	// coming from the Watcher. The value is set by [NewWatcher] and defaults to
-	// the deadline passed in by [testing.T]. If looking to verify the absence
-	// of an event, the TimeOut can be set to a lower value after creating the
-	// Watcher.
-	TimeOut time.Duration
 }
 
 // Type is a helper representing the expectation to see an event of type T, without
@@ -103,7 +98,8 @@ func Expect(tw *Watcher, filters ...any) error {
 			} else if ok {
 				head++
 			}
-		case <-time.After(tw.TimeOut):
+		// Use synctest when you want an error here.
+		case <-time.After(100 * time.Second): // "indefinitely", to advance a synctest clock
 			return fmt.Errorf(
 				"timed out waiting for event, saw %d events, %d was expected",
 				eventCount, len(filters))
@@ -118,12 +114,16 @@ func Expect(tw *Watcher, filters ...any) error {
 // in a given order, returning an error if the events does not match the given list
 // exactly. The given events are represented by a function as described in
 // [Expect]. Use [Expect] if other events are allowed.
+//
+// If you are expecting ExpectExactly to fail because of a missing event, or if
+// you are testing for the absence of events, call [synctest.Wait] after
+// actions that would publish an event, but before calling ExpectExactly.
 func ExpectExactly(tw *Watcher, filters ...any) error {
 	if len(filters) == 0 {
 		select {
 		case event := <-tw.events:
 			return fmt.Errorf("saw event type %s, expected none", reflect.TypeOf(event))
-		case <-time.After(tw.TimeOut):
+		case <-time.After(100 * time.Second): // "indefinitely", to advance a synctest clock
 			return nil
 		}
 	}
@@ -146,7 +146,7 @@ func ExpectExactly(tw *Watcher, filters ...any) error {
 				return fmt.Errorf(
 					"expected test ok for type %s, at index %d", argType, pos)
 			}
-		case <-time.After(tw.TimeOut):
+		case <-time.After(100 * time.Second): // "indefinitely", to advance a synctest clock
 			return fmt.Errorf(
 				"timed out waiting for event, saw %d events, %d was expected",
 				eventCount, len(filters))
@@ -162,6 +162,9 @@ func (tw *Watcher) watch() {
 		select {
 		case event := <-tw.mon.Events():
 			tw.events <- event.Event
+		case <-tw.mon.Done():
+			tw.done()
+			return
 		case <-tw.chDone:
 			tw.mon.Close()
 			return

+ 72 - 78
util/eventbus/eventbustest/eventbustest_test.go

@@ -8,7 +8,7 @@ import (
 	"fmt"
 	"strings"
 	"testing"
-	"time"
+	"testing/synctest"
 
 	"tailscale.com/util/eventbus"
 	"tailscale.com/util/eventbus/eventbustest"
@@ -110,37 +110,35 @@ func TestExpectFilter(t *testing.T) {
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			bus := eventbustest.NewBus(t)
-			t.Cleanup(bus.Close)
+			synctest.Test(t, func(t *testing.T) {
+				bus := eventbustest.NewBus(t)
 
-			if *doDebug {
-				eventbustest.LogAllEvents(t, bus)
-			}
-			tw := eventbustest.NewWatcher(t, bus)
+				if *doDebug {
+					eventbustest.LogAllEvents(t, bus)
+				}
+				tw := eventbustest.NewWatcher(t, bus)
 
-			// TODO(cmol): When synctest is out of experimental, use that instead:
-			// https://go.dev/blog/synctest
-			tw.TimeOut = 10 * time.Millisecond
+				client := bus.Client("testClient")
+				updater := eventbus.Publish[EventFoo](client)
 
-			client := bus.Client("testClient")
-			defer client.Close()
-			updater := eventbus.Publish[EventFoo](client)
+				for _, i := range tt.events {
+					updater.Publish(EventFoo{i})
+				}
 
-			for _, i := range tt.events {
-				updater.Publish(EventFoo{i})
-			}
+				synctest.Wait()
 
-			if err := eventbustest.Expect(tw, tt.expectFunc); err != nil {
-				if tt.wantErr == "" {
-					t.Errorf("Expect[EventFoo]: unexpected error: %v", err)
-				} else if !strings.Contains(err.Error(), tt.wantErr) {
-					t.Errorf("Expect[EventFoo]: err = %v, want %q", err, tt.wantErr)
-				} else {
-					t.Logf("Got expected error: %v (OK)", err)
+				if err := eventbustest.Expect(tw, tt.expectFunc); err != nil {
+					if tt.wantErr == "" {
+						t.Errorf("Expect[EventFoo]: unexpected error: %v", err)
+					} else if !strings.Contains(err.Error(), tt.wantErr) {
+						t.Errorf("Expect[EventFoo]: err = %v, want %q", err, tt.wantErr)
+					} else {
+						t.Logf("Got expected error: %v (OK)", err)
+					}
+				} else if tt.wantErr != "" {
+					t.Errorf("Expect[EventFoo]: unexpectedly succeeded, want error %q", tt.wantErr)
 				}
-			} else if tt.wantErr != "" {
-				t.Errorf("Expect[EventFoo]: unexpectedly succeeded, want error %q", tt.wantErr)
-			}
+			})
 		})
 	}
 }
@@ -244,37 +242,35 @@ func TestExpectEvents(t *testing.T) {
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			bus := eventbustest.NewBus(t)
-			t.Cleanup(bus.Close)
+			synctest.Test(t, func(t *testing.T) {
+				bus := eventbustest.NewBus(t)
 
-			tw := eventbustest.NewWatcher(t, bus)
-			// TODO(cmol): When synctest is out of experimental, use that instead:
-			// https://go.dev/blog/synctest
-			tw.TimeOut = 100 * time.Millisecond
+				tw := eventbustest.NewWatcher(t, bus)
 
-			client := bus.Client("testClient")
-			defer client.Close()
-			updaterFoo := eventbus.Publish[EventFoo](client)
-			updaterBar := eventbus.Publish[EventBar](client)
-			updaterBaz := eventbus.Publish[EventBaz](client)
+				client := bus.Client("testClient")
+				updaterFoo := eventbus.Publish[EventFoo](client)
+				updaterBar := eventbus.Publish[EventBar](client)
+				updaterBaz := eventbus.Publish[EventBaz](client)
 
-			for _, ev := range tt.events {
-				switch ev.(type) {
-				case EventFoo:
-					evCast := ev.(EventFoo)
-					updaterFoo.Publish(evCast)
-				case EventBar:
-					evCast := ev.(EventBar)
-					updaterBar.Publish(evCast)
-				case EventBaz:
-					evCast := ev.(EventBaz)
-					updaterBaz.Publish(evCast)
+				for _, ev := range tt.events {
+					switch ev := ev.(type) {
+					case EventFoo:
+						evCast := ev
+						updaterFoo.Publish(evCast)
+					case EventBar:
+						evCast := ev
+						updaterBar.Publish(evCast)
+					case EventBaz:
+						evCast := ev
+						updaterBaz.Publish(evCast)
+					}
 				}
-			}
 
-			if err := eventbustest.Expect(tw, tt.expectEvents...); (err != nil) != tt.wantErr {
-				t.Errorf("ExpectEvents: error = %v, wantErr %v", err, tt.wantErr)
-			}
+				synctest.Wait()
+				if err := eventbustest.Expect(tw, tt.expectEvents...); (err != nil) != tt.wantErr {
+					t.Errorf("ExpectEvents: error = %v, wantErr %v", err, tt.wantErr)
+				}
+			})
 		})
 	}
 }
@@ -378,37 +374,35 @@ func TestExpectExactlyEventsFilter(t *testing.T) {
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			bus := eventbustest.NewBus(t)
-			t.Cleanup(bus.Close)
+			synctest.Test(t, func(t *testing.T) {
+				bus := eventbustest.NewBus(t)
 
-			tw := eventbustest.NewWatcher(t, bus)
-			// TODO(cmol): When synctest is out of experimental, use that instead:
-			// https://go.dev/blog/synctest
-			tw.TimeOut = 10 * time.Millisecond
+				tw := eventbustest.NewWatcher(t, bus)
 
-			client := bus.Client("testClient")
-			defer client.Close()
-			updaterFoo := eventbus.Publish[EventFoo](client)
-			updaterBar := eventbus.Publish[EventBar](client)
-			updaterBaz := eventbus.Publish[EventBaz](client)
+				client := bus.Client("testClient")
+				updaterFoo := eventbus.Publish[EventFoo](client)
+				updaterBar := eventbus.Publish[EventBar](client)
+				updaterBaz := eventbus.Publish[EventBaz](client)
 
-			for _, ev := range tt.events {
-				switch ev.(type) {
-				case EventFoo:
-					evCast := ev.(EventFoo)
-					updaterFoo.Publish(evCast)
-				case EventBar:
-					evCast := ev.(EventBar)
-					updaterBar.Publish(evCast)
-				case EventBaz:
-					evCast := ev.(EventBaz)
-					updaterBaz.Publish(evCast)
+				for _, ev := range tt.events {
+					switch ev := ev.(type) {
+					case EventFoo:
+						evCast := ev
+						updaterFoo.Publish(evCast)
+					case EventBar:
+						evCast := ev
+						updaterBar.Publish(evCast)
+					case EventBaz:
+						evCast := ev
+						updaterBaz.Publish(evCast)
+					}
 				}
-			}
 
-			if err := eventbustest.ExpectExactly(tw, tt.expectEvents...); (err != nil) != tt.wantErr {
-				t.Errorf("ExpectEvents: error = %v, wantErr %v", err, tt.wantErr)
-			}
+				synctest.Wait()
+				if err := eventbustest.ExpectExactly(tw, tt.expectEvents...); (err != nil) != tt.wantErr {
+					t.Errorf("ExpectEvents: error = %v, wantErr %v", err, tt.wantErr)
+				}
+			})
 		})
 	}
 }

+ 59 - 0
util/eventbus/eventbustest/examples_test.go

@@ -5,6 +5,8 @@ package eventbustest_test
 
 import (
 	"testing"
+	"testing/synctest"
+	"time"
 
 	"tailscale.com/util/eventbus"
 	"tailscale.com/util/eventbus/eventbustest"
@@ -199,3 +201,60 @@ func TestExample_ExpectExactly_WithMultipleFunctions(t *testing.T) {
 	// Output:
 	// expected event type eventbustest.eventOfCuriosity, saw eventbustest.eventOfNoConcern, at index 1
 }
+
+func TestExample_ExpectExactly_NoEvents(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		bus := eventbustest.NewBus(t)
+		tw := eventbustest.NewWatcher(t, bus)
+
+		go func() {
+			// Do some work that does not produce an event
+			time.Sleep(10 * time.Second)
+			t.Log("Not producing events")
+		}()
+
+		// Wait for all other routines to be stale before continuing to ensure that
+		// there is nothing running that would produce an event at a later time.
+		synctest.Wait()
+
+		if err := eventbustest.ExpectExactly(tw); err != nil {
+			t.Error(err.Error())
+		} else {
+			t.Log("OK")
+		}
+		// Output:
+		// OK
+	})
+}
+
+func TestExample_ExpectExactly_OneEventExpectingTwo(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		type eventOfInterest struct{}
+
+		bus := eventbustest.NewBus(t)
+		tw := eventbustest.NewWatcher(t, bus)
+		client := bus.Client("testClient")
+		updater := eventbus.Publish[eventOfInterest](client)
+
+		go func() {
+			// Do some work that does not produce an event
+			time.Sleep(10 * time.Second)
+			updater.Publish(eventOfInterest{})
+		}()
+
+		// Wait for all other routines to be stale before continuing to ensure that
+		// there is nothing running that would produce an event at a later time.
+		synctest.Wait()
+
+		if err := eventbustest.ExpectExactly(tw,
+			eventbustest.Type[eventOfInterest](),
+			eventbustest.Type[eventOfInterest](),
+		); err != nil {
+			t.Log(err.Error())
+		} else {
+			t.Log("OK")
+		}
+		// Output:
+		// timed out waiting for event, saw 1 events, 2 was expected
+	})
+}