Explorar o código

portlist: fix data race

Maisem spotted the bug. The initial getList call in NewPoller wasn't
making a clone (only the Run loop's getList calls).

Fixes #6314

Change-Id: I8ab8799fcccea8e799140340d0ff88a825bb6ff0
Co-authored-by: Maisem Ali <[email protected]>
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick %!s(int64=3) %!d(string=hai) anos
pai
achega
f81351fdef
Modificáronse 2 ficheiros con 85 adicións e 10 borrados
  1. 18 10
      portlist/poller.go
  2. 67 0
      portlist/portlist_test.go

+ 18 - 10
portlist/poller.go

@@ -14,6 +14,7 @@ import (
 	"sync"
 	"time"
 
+	"golang.org/x/exp/slices"
 	"tailscale.com/envknob"
 )
 
@@ -84,14 +85,20 @@ func NewPoller() (*Poller, error) {
 
 	// Do one initial poll synchronously so we can return an error
 	// early.
-	var err error
-	p.prev, err = p.getList()
-	if err != nil {
+	if pl, err := p.getList(); err != nil {
 		return nil, err
+	} else {
+		p.setPrev(pl)
 	}
 	return p, nil
 }
 
+func (p *Poller) setPrev(pl List) {
+	// Make a copy, as the pass in pl slice aliases pl.scratch and we don't want
+	// that to except to the caller.
+	p.prev = slices.Clone(pl)
+}
+
 func (p *Poller) initOSField() {
 	if newOSImpl != nil {
 		p.os = newOSImpl()
@@ -131,11 +138,14 @@ func (p *Poller) send(ctx context.Context, pl List) (sent bool, err error) {
 //
 // Run may only be called once.
 func (p *Poller) Run(ctx context.Context) error {
-	defer close(p.runDone)
-	defer close(p.c)
-
 	tick := time.NewTicker(pollInterval)
 	defer tick.Stop()
+	return p.runWithTickChan(ctx, tick.C)
+}
+
+func (p *Poller) runWithTickChan(ctx context.Context, tickChan <-chan time.Time) error {
+	defer close(p.runDone)
+	defer close(p.c)
 
 	// Send out the pre-generated initial value.
 	if sent, err := p.send(ctx, p.prev); !sent {
@@ -144,7 +154,7 @@ func (p *Poller) Run(ctx context.Context) error {
 
 	for {
 		select {
-		case <-tick.C:
+		case <-tickChan:
 			pl, err := p.getList()
 			if err != nil {
 				return err
@@ -152,9 +162,7 @@ func (p *Poller) Run(ctx context.Context) error {
 			if pl.equal(p.prev) {
 				continue
 			}
-			// New value. Make a copy, as pl might alias pl.scratch
-			// and prev must not.
-			p.prev = append([]Port(nil), pl...)
+			p.setPrev(pl)
 			if sent, err := p.send(ctx, p.prev); !sent {
 				return err
 			}

+ 67 - 0
portlist/portlist_test.go

@@ -5,10 +5,13 @@
 package portlist
 
 import (
+	"context"
 	"flag"
 	"net"
 	"runtime"
+	"sync"
 	"testing"
+	"time"
 
 	"tailscale.com/tstest"
 )
@@ -182,6 +185,70 @@ func TestEqualLessThan(t *testing.T) {
 	}
 }
 
+func TestPoller(t *testing.T) {
+	p, err := NewPoller()
+	if err != nil {
+		t.Skipf("not running test: %v", err)
+	}
+	defer p.Close()
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+
+	gotUpdate := make(chan bool, 16)
+
+	go func() {
+		defer wg.Done()
+		for pl := range p.Updates() {
+			// Look at all the pl slice memory to maximize
+			// chance of race detector seeing violations.
+			for _, v := range pl {
+				if v == (Port{}) {
+					// Force use
+					panic("empty port")
+				}
+			}
+			select {
+			case gotUpdate <- true:
+			default:
+			}
+		}
+	}()
+
+	tick := make(chan time.Time, 16)
+	go func() {
+		defer wg.Done()
+		if err := p.runWithTickChan(context.Background(), tick); err != nil {
+			t.Error("runWithTickChan:", err)
+		}
+	}()
+	for i := 0; i < 10; i++ {
+		ln, err := net.Listen("tcp", ":0")
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer ln.Close()
+		tick <- time.Time{}
+
+		select {
+		case <-gotUpdate:
+		case <-time.After(5 * time.Second):
+			t.Fatal("timed out waiting for update")
+		}
+	}
+
+	// And a bunch of ticks without waiting for updates,
+	// to make race tests more likely to fail, if any present.
+	for i := 0; i < 10; i++ {
+		tick <- time.Time{}
+	}
+
+	if err := p.Close(); err != nil {
+		t.Fatal(err)
+	}
+	wg.Wait()
+}
+
 func BenchmarkGetList(b *testing.B) {
 	benchmarkGetList(b, false)
 }