ソースを参照

portlist: remove async functionality

This PR removes all async functionality from the portlist package
which may be a breaking change for non-tailscale importers. The only
importer within this codebase (LocalBackend) is already using the synchronous
API so no further action needed.

Fixes #8171

Signed-off-by: Marwan Sulaiman <[email protected]>
Marwan Sulaiman 2 年 前
コミット
f8f0b981ac
2 ファイル変更20 行追加194 行削除
  1. 9 117
      portlist/poller.go
  2. 11 77
      portlist/portlist_test.go

+ 9 - 117
portlist/poller.go

@@ -7,7 +7,6 @@
 package portlist
 
 import (
-	"context"
 	"errors"
 	"fmt"
 	"runtime"
@@ -19,7 +18,8 @@ import (
 )
 
 var (
-	pollInterval         = 5 * time.Second // default; changed by some OS-specific init funcs
+	newOSImpl            func(includeLocalhost bool) osImpl // if non-nil, constructs a new osImpl.
+	pollInterval         = 5 * time.Second                  // default; changed by some OS-specific init funcs
 	debugDisablePortlist = envknob.RegisterBool("TS_DEBUG_DISABLE_PORTLIST")
 )
 
@@ -37,8 +37,6 @@ type Poller struct {
 	// This field should only be changed before calling Run.
 	IncludeLocalhost bool
 
-	c chan List // unbuffered
-
 	// os, if non-nil, is an OS-specific implementation of the portlist getting
 	// code. When non-nil, it's responsible for getting the complete list of
 	// cached ports complete with the process name. That is, when set,
@@ -49,12 +47,6 @@ type Poller struct {
 	initOnce sync.Once // guards init of os
 	initErr  error
 
-	// closeCtx is the context that's canceled on Close.
-	closeCtx       context.Context
-	closeCtxCancel context.CancelFunc
-
-	runDone chan struct{} // closed when Run completes
-
 	// scatch is memory for Poller.getList to reuse between calls.
 	scratch []Port
 
@@ -75,36 +67,6 @@ type osImpl interface {
 	AppendListeningPorts(base []Port) ([]Port, error)
 }
 
-// newOSImpl, if non-nil, constructs a new osImpl.
-var newOSImpl func(includeLocalhost bool) osImpl
-
-var (
-	errUnimplemented = errors.New("portlist poller not implemented on " + runtime.GOOS)
-	errDisabled      = errors.New("portlist disabled by envknob")
-)
-
-// NewPoller returns a new portlist Poller. It returns an error
-// if the portlist couldn't be obtained.
-func NewPoller() (*Poller, error) {
-	p := &Poller{
-		c:       make(chan List),
-		runDone: make(chan struct{}),
-	}
-	p.initOnce.Do(p.init)
-	if p.initErr != nil {
-		return nil, p.initErr
-	}
-	p.closeCtx, p.closeCtxCancel = context.WithCancel(context.Background())
-	// Do one initial poll synchronously so we can return an error
-	// early.
-	if pl, err := p.getList(); err != nil {
-		return nil, err
-	} else {
-		p.setPrev(pl)
-	}
-	return p, nil
-}
-
 func (p *Poller) setPrev(pl List) {
 	// Make a copy, as the pass in pl slice aliases pl.scratch and we don't want
 	// that to except to the caller.
@@ -114,22 +76,16 @@ func (p *Poller) setPrev(pl List) {
 // init initializes the Poller by ensuring it has an underlying
 // OS implementation and is not turned off by envknob.
 func (p *Poller) init() {
-	if debugDisablePortlist() {
-		p.initErr = errDisabled
-		return
-	}
-	if newOSImpl == nil {
-		p.initErr = errUnimplemented
-		return
+	switch {
+	case debugDisablePortlist():
+		p.initErr = errors.New("portlist disabled by envknob")
+	case newOSImpl == nil:
+		p.initErr = errors.New("portlist poller not implemented on " + runtime.GOOS)
+	default:
+		p.os = newOSImpl(p.IncludeLocalhost)
 	}
-	p.os = newOSImpl(p.IncludeLocalhost)
 }
 
-// Updates return the channel that receives port list updates.
-//
-// The channel is closed when the Poller is closed.
-func (p *Poller) Updates() <-chan List { return p.c }
-
 // Close closes the Poller.
 func (p *Poller) Close() error {
 	if p.initErr != nil {
@@ -138,25 +94,9 @@ func (p *Poller) Close() error {
 	if p.os == nil {
 		return nil
 	}
-	if p.closeCtxCancel != nil {
-		p.closeCtxCancel()
-		<-p.runDone
-	}
 	return p.os.Close()
 }
 
-// send sends pl to p.c and returns whether it was successfully sent.
-func (p *Poller) send(ctx context.Context, pl List) (sent bool, err error) {
-	select {
-	case p.c <- pl:
-		return true, nil
-	case <-ctx.Done():
-		return false, ctx.Err()
-	case <-p.closeCtx.Done():
-		return false, nil
-	}
-}
-
 // Poll returns the list of listening ports, if changed from
 // a previous call as indicated by the changed result.
 func (p *Poller) Poll() (ports []Port, changed bool, err error) {
@@ -175,55 +115,7 @@ func (p *Poller) Poll() (ports []Port, changed bool, err error) {
 	return p.prev, true, nil
 }
 
-// Run runs the Poller periodically until either the context
-// is done, or the Close is called.
-//
-// Run may only be called once.
-func (p *Poller) Run(ctx context.Context) error {
-	tick := time.NewTicker(pollInterval)
-	defer tick.Stop()
-	return p.runWithTickChan(ctx, tick.C)
-}
-
-func (p *Poller) runWithTickChan(ctx context.Context, tickChan <-chan time.Time) error {
-	defer close(p.runDone)
-	defer close(p.c)
-
-	// Send out the pre-generated initial value.
-	if sent, err := p.send(ctx, p.prev); !sent {
-		return err
-	}
-
-	for {
-		select {
-		case <-tickChan:
-			pl, err := p.getList()
-			if err != nil {
-				return err
-			}
-			if pl.equal(p.prev) {
-				continue
-			}
-			p.setPrev(pl)
-			if sent, err := p.send(ctx, p.prev); !sent {
-				return err
-			}
-		case <-ctx.Done():
-			return ctx.Err()
-		case <-p.closeCtx.Done():
-			return nil
-		}
-	}
-}
-
 func (p *Poller) getList() (List, error) {
-	// TODO(marwan): this method does not
-	// need to do any init logic. Update tests
-	// once async API is removed.
-	p.initOnce.Do(p.init)
-	if p.initErr == errDisabled {
-		return nil, nil
-	}
 	var err error
 	p.scratch, err = p.os.AppendListeningPorts(p.scratch[:0])
 	return p.scratch, err

+ 11 - 77
portlist/portlist_test.go

@@ -4,11 +4,8 @@
 package portlist
 
 import (
-	"context"
 	"net"
-	"sync"
 	"testing"
-	"time"
 
 	"tailscale.com/tstest"
 )
@@ -17,14 +14,14 @@ func TestGetList(t *testing.T) {
 	tstest.ResourceCheck(t)
 
 	var p Poller
-	pl, err := p.getList()
+	pl, _, err := p.Poll()
 	if err != nil {
 		t.Fatal(err)
 	}
 	for i, p := range pl {
 		t.Logf("[%d] %+v", i, p)
 	}
-	t.Logf("As String: %v", pl.String())
+	t.Logf("As String: %s", List(pl))
 }
 
 func TestIgnoreLocallyBoundPorts(t *testing.T) {
@@ -38,7 +35,7 @@ func TestIgnoreLocallyBoundPorts(t *testing.T) {
 	ta := ln.Addr().(*net.TCPAddr)
 	port := ta.Port
 	var p Poller
-	pl, err := p.getList()
+	pl, _, err := p.Poll()
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -49,16 +46,16 @@ func TestIgnoreLocallyBoundPorts(t *testing.T) {
 	}
 }
 
-func TestChangesOverTime(t *testing.T) {
+func TestPoller(t *testing.T) {
 	var p Poller
 	p.IncludeLocalhost = true
 	get := func(t *testing.T) []Port {
 		t.Helper()
-		s, err := p.getList()
+		s, _, err := p.Poll()
 		if err != nil {
 			t.Fatal(err)
 		}
-		return append([]Port(nil), s...)
+		return s
 	}
 
 	p1 := get(t)
@@ -192,74 +189,6 @@ func TestClose(t *testing.T) {
 	}
 }
 
-func TestPoller(t *testing.T) {
-	p, err := NewPoller()
-	if err != nil {
-		t.Skipf("not running test: %v", err)
-	}
-	t.Cleanup(func() {
-		if err := p.Close(); err != nil {
-			t.Errorf("error closing poller in test: %v", err)
-		}
-	})
-
-	var wg sync.WaitGroup
-	wg.Add(2)
-
-	gotUpdate := make(chan bool, 16)
-
-	go func() {
-		defer wg.Done()
-		for pl := range p.Updates() {
-			// Look at all the pl slice memory to maximize
-			// chance of race detector seeing violations.
-			for _, v := range pl {
-				if v == (Port{}) {
-					// Force use
-					panic("empty port")
-				}
-			}
-			select {
-			case gotUpdate <- true:
-			default:
-			}
-		}
-	}()
-
-	tick := make(chan time.Time, 16)
-	go func() {
-		defer wg.Done()
-		if err := p.runWithTickChan(context.Background(), tick); err != nil {
-			t.Error("runWithTickChan:", err)
-		}
-	}()
-	for i := 0; i < 10; i++ {
-		ln, err := net.Listen("tcp", ":0")
-		if err != nil {
-			t.Fatal(err)
-		}
-		defer ln.Close()
-		tick <- time.Time{}
-
-		select {
-		case <-gotUpdate:
-		case <-time.After(5 * time.Second):
-			t.Fatal("timed out waiting for update")
-		}
-	}
-
-	// And a bunch of ticks without waiting for updates,
-	// to make race tests more likely to fail, if any present.
-	for i := 0; i < 10; i++ {
-		tick <- time.Time{}
-	}
-
-	if err := p.Close(); err != nil {
-		t.Fatal(err)
-	}
-	wg.Wait()
-}
-
 func BenchmarkGetList(b *testing.B) {
 	benchmarkGetList(b, false)
 }
@@ -271,6 +200,11 @@ func BenchmarkGetListIncremental(b *testing.B) {
 func benchmarkGetList(b *testing.B, incremental bool) {
 	b.ReportAllocs()
 	var p Poller
+	p.init()
+	if p.initErr != nil {
+		b.Skip(p.initErr)
+	}
+	b.Cleanup(func() { p.Close() })
 	for i := 0; i < b.N; i++ {
 		pl, err := p.getList()
 		if err != nil {