Procházet zdrojové kódy

portlist: add Poller.IncludeLocalhost option

This PR parameterizes receiving loopback updates from the portlist package.
Callers can now include services bound to localhost if they want.
Note that this option is off by default still.

Fixes #8171

Signed-off-by: Marwan Sulaiman <[email protected]>
Marwan Sulaiman před 2 roky
rodič
revize
e32e5c0d0c

+ 3 - 3
portlist/netstat.go

@@ -67,7 +67,7 @@ type nothing struct{}
 // Unfortunately, options to filter by proto or state are non-portable,
 // so we'll filter for ourselves.
 // Nowadays, though, we only use it for macOS as of 2022-11-04.
-func appendParsePortsNetstat(base []Port, br *bufio.Reader) ([]Port, error) {
+func appendParsePortsNetstat(base []Port, br *bufio.Reader, includeLocalhost bool) ([]Port, error) {
 	ret := base
 	var fieldBuf [10]mem.RO
 	for {
@@ -99,7 +99,7 @@ func appendParsePortsNetstat(base []Port, br *bufio.Reader) ([]Port, error) {
 				// not interested in non-listener sockets
 				continue
 			}
-			if isLoopbackAddr(laddr) {
+			if !includeLocalhost && isLoopbackAddr(laddr) {
 				// not interested in loopback-bound listeners
 				continue
 			}
@@ -110,7 +110,7 @@ func appendParsePortsNetstat(base []Port, br *bufio.Reader) ([]Port, error) {
 			proto = "udp"
 			laddr = cols[len(cols)-2]
 			raddr = cols[len(cols)-1]
-			if isLoopbackAddr(laddr) {
+			if !includeLocalhost && isLoopbackAddr(laddr) {
 				// not interested in loopback-bound listeners
 				continue
 			}

+ 36 - 25
portlist/netstat_test.go

@@ -8,6 +8,7 @@ package portlist
 import (
 	"bufio"
 	"encoding/json"
+	"fmt"
 	"strings"
 	"testing"
 
@@ -52,30 +53,40 @@ udp46      0      0  *.146                 *.*
 `
 
 func TestParsePortsNetstat(t *testing.T) {
-	want := List{
-		Port{"tcp", 23, ""},
-		Port{"tcp", 24, ""},
-		Port{"udp", 104, ""},
-		Port{"udp", 106, ""},
-		Port{"udp", 146, ""},
-		Port{"tcp", 8185, ""}, // but not 8186, 8187, 8188 on localhost
-	}
-
-	pl, err := appendParsePortsNetstat(nil, bufio.NewReader(strings.NewReader(netstatOutput)))
-	if err != nil {
-		t.Fatal(err)
-	}
-	pl = sortAndDedup(pl)
-	jgot, _ := json.MarshalIndent(pl, "", "\t")
-	jwant, _ := json.MarshalIndent(want, "", "\t")
-	if len(pl) != len(want) {
-		t.Fatalf("Got:\n%s\n\nWant:\n%s\n", jgot, jwant)
-	}
-	for i := range pl {
-		if pl[i] != want[i] {
-			t.Errorf("row#%d\n got: %+v\n\nwant: %+v\n",
-				i, pl[i], want[i])
-			t.Fatalf("Got:\n%s\n\nWant:\n%s\n", jgot, jwant)
-		}
+	for _, loopBack := range [...]bool{false, true} {
+		t.Run(fmt.Sprintf("loopback_%v", loopBack), func(t *testing.T) {
+			want := List{
+				{"tcp", 23, "", 0},
+				{"tcp", 24, "", 0},
+				{"udp", 104, "", 0},
+				{"udp", 106, "", 0},
+				{"udp", 146, "", 0},
+				{"tcp", 8185, "", 0}, // but not 8186, 8187, 8188 on localhost, when loopback is false
+			}
+			if loopBack {
+				want = append(want,
+					Port{"tcp", 8186, "", 0},
+					Port{"tcp", 8187, "", 0},
+					Port{"tcp", 8188, "", 0},
+				)
+			}
+			pl, err := appendParsePortsNetstat(nil, bufio.NewReader(strings.NewReader(netstatOutput)), loopBack)
+			if err != nil {
+				t.Fatal(err)
+			}
+			pl = sortAndDedup(pl)
+			jgot, _ := json.MarshalIndent(pl, "", "\t")
+			jwant, _ := json.MarshalIndent(want, "", "\t")
+			if len(pl) != len(want) {
+				t.Fatalf("Got:\n%s\n\nWant:\n%s\n", jgot, jwant)
+			}
+			for i := range pl {
+				if pl[i] != want[i] {
+					t.Errorf("row#%d\n got: %+v\n\nwant: %+v\n",
+						i, pl[i], want[i])
+					t.Fatalf("Got:\n%s\n\nWant:\n%s\n", jgot, jwant)
+				}
+			}
+		})
 	}
 }

+ 7 - 2
portlist/poller.go

@@ -24,6 +24,11 @@ var debugDisablePortlist = envknob.RegisterBool("TS_DEBUG_DISABLE_PORTLIST")
 // Poller scans the systems for listening ports periodically and sends
 // the results to C.
 type Poller struct {
+	// IncludeLocalhost controls whether services bound to localhost are included.
+	//
+	// This field should only be changed before calling Run.
+	IncludeLocalhost bool
+
 	c chan List // unbuffered
 
 	// os, if non-nil, is an OS-specific implementation of the portlist getting
@@ -62,7 +67,7 @@ type osImpl interface {
 }
 
 // newOSImpl, if non-nil, constructs a new osImpl.
-var newOSImpl func() osImpl
+var newOSImpl func(includeLocalhost bool) osImpl
 
 var errUnimplemented = errors.New("portlist poller not implemented on " + runtime.GOOS)
 
@@ -100,7 +105,7 @@ func (p *Poller) setPrev(pl List) {
 
 func (p *Poller) initOSField() {
 	if newOSImpl != nil {
-		p.os = newOSImpl()
+		p.os = newOSImpl(p.IncludeLocalhost)
 	}
 }
 

+ 3 - 3
portlist/portlist.go

@@ -18,6 +18,7 @@ type Port struct {
 	Proto   string // "tcp" or "udp"
 	Port    uint16 // port number
 	Process string // optional process name, if found
+	Pid     int    // process id, if known
 }
 
 // List is a list of Ports.
@@ -69,12 +70,11 @@ func sortAndDedup(ps List) List {
 	out := ps[:0]
 	var last Port
 	for _, p := range ps {
-		protoPort := Port{Proto: p.Proto, Port: p.Port}
-		if last == protoPort {
+		if last.Proto == p.Proto && last.Port == p.Port {
 			continue
 		}
 		out = append(out, p)
-		last = protoPort
+		last = p
 	}
 	return out
 }

+ 14 - 8
portlist/portlist_linux.go

@@ -35,25 +35,28 @@ type linuxImpl struct {
 	procNetFiles    []*os.File // seeked to start & reused between calls
 	readlinkPathBuf []byte
 
-	known map[string]*portMeta // inode string => metadata
-	br    *bufio.Reader
+	known            map[string]*portMeta // inode string => metadata
+	br               *bufio.Reader
+	includeLocalhost bool
 }
 
 type portMeta struct {
 	port          Port
+	pid           int
 	keep          bool
 	needsProcName bool
 }
 
-func newLinuxImplBase() *linuxImpl {
+func newLinuxImplBase(includeLocalhost bool) *linuxImpl {
 	return &linuxImpl{
-		br:    bufio.NewReader(eofReader),
-		known: map[string]*portMeta{},
+		br:               bufio.NewReader(eofReader),
+		known:            map[string]*portMeta{},
+		includeLocalhost: includeLocalhost,
 	}
 }
 
-func newLinuxImpl() osImpl {
-	li := newLinuxImplBase()
+func newLinuxImpl(includeLocalhost bool) osImpl {
+	li := newLinuxImplBase(includeLocalhost)
 	for _, name := range []string{
 		"/proc/net/tcp",
 		"/proc/net/tcp6",
@@ -220,7 +223,7 @@ func (li *linuxImpl) parseProcNetFile(r *bufio.Reader, fileBase string) error {
 		// If a port is bound to localhost, ignore it.
 		// TODO: localhost is bigger than 1 IP, we need to ignore
 		// more things.
-		if mem.HasPrefix(local, mem.S(v4Localhost)) || mem.HasPrefix(local, mem.S(v6Localhost)) {
+		if !li.includeLocalhost && (mem.HasPrefix(local, mem.S(v4Localhost)) || mem.HasPrefix(local, mem.S(v6Localhost))) {
 			continue
 		}
 
@@ -315,6 +318,9 @@ func (li *linuxImpl) findProcessNames(need map[string]*portMeta) error {
 			}
 
 			argv := strings.Split(strings.TrimSuffix(string(bs), "\x00"), "\x00")
+			if p, err := mem.ParseInt(pid, 10, 0); err == nil {
+				pe.pid = int(p)
+			}
 			pe.port.Process = argvSubject(argv...)
 			pe.needsProcName = false
 			delete(need, string(targetBuf[:n]))

+ 2 - 2
portlist/portlist_linux_test.go

@@ -89,7 +89,7 @@ func TestParsePorts(t *testing.T) {
 			if tt.file != "" {
 				file = tt.file
 			}
-			li := newLinuxImplBase()
+			li := newLinuxImplBase(false)
 			err := li.parseProcNetFile(r, file)
 			if err != nil {
 				t.Fatal(err)
@@ -118,7 +118,7 @@ func BenchmarkParsePorts(b *testing.B) {
 		contents.WriteString("   3: 69050120005716BC64906EBE009ECD4D:D506 0047062600000000000000006E171268:01BB 01 00000000:00000000 02:0000009E 00000000  1000        0 151042856 2 0000000000000000 21 4 28 10 -1\n")
 	}
 
-	li := newLinuxImplBase()
+	li := newLinuxImplBase(false)
 
 	r := bytes.NewReader(contents.Bytes())
 	br := bufio.NewReader(&contents)

+ 14 - 6
portlist/portlist_macos.go

@@ -29,8 +29,9 @@ type macOSImpl struct {
 	known       map[protoPort]*portMeta // inode string => metadata
 	netstatPath string                  // lazily populated
 
-	br       *bufio.Reader // reused
-	portsBuf []Port
+	br               *bufio.Reader // reused
+	portsBuf         []Port
+	includeLocalhost bool
 }
 
 type protoPort struct {
@@ -43,10 +44,11 @@ type portMeta struct {
 	keep bool
 }
 
-func newMacOSImpl() osImpl {
+func newMacOSImpl(includeLocalhost bool) osImpl {
 	return &macOSImpl{
-		known: map[protoPort]*portMeta{},
-		br:    bufio.NewReader(bytes.NewReader(nil)),
+		known:            map[protoPort]*portMeta{},
+		br:               bufio.NewReader(bytes.NewReader(nil)),
+		includeLocalhost: includeLocalhost,
 	}
 }
 
@@ -119,7 +121,7 @@ func (im *macOSImpl) appendListeningPortsNetstat(base []Port) ([]Port, error) {
 	defer cmd.Process.Wait()
 	defer cmd.Process.Kill()
 
-	return appendParsePortsNetstat(base, im.br)
+	return appendParsePortsNetstat(base, im.br, im.includeLocalhost)
 }
 
 var lsofFailed atomic.Bool
@@ -170,6 +172,7 @@ func (im *macOSImpl) addProcesses() error {
 	im.br.Reset(outPipe)
 
 	var cmd, proto string
+	var pid int
 	for {
 		line, err := im.br.ReadBytes('\n')
 		if err != nil {
@@ -184,6 +187,10 @@ func (im *macOSImpl) addProcesses() error {
 			// starting a new process
 			cmd = ""
 			proto = ""
+			pid = 0
+			if p, err := mem.ParseInt(mem.B(val), 10, 0); err == nil {
+				pid = int(p)
+			}
 		case 'c':
 			cmd = string(val) // TODO(bradfitz): avoid garbage; cache process names between runs?
 		case 'P':
@@ -202,6 +209,7 @@ func (im *macOSImpl) addProcesses() error {
 			switch {
 			case m != nil:
 				m.port.Process = cmd
+				m.port.Pid = pid
 			default:
 				// ignore: processes and ports come and go
 			}

+ 2 - 11
portlist/portlist_test.go

@@ -5,9 +5,7 @@ package portlist
 
 import (
 	"context"
-	"flag"
 	"net"
-	"runtime"
 	"sync"
 	"testing"
 	"time"
@@ -51,16 +49,9 @@ func TestIgnoreLocallyBoundPorts(t *testing.T) {
 	}
 }
 
-var flagRunUnspecTests = flag.Bool("run-unspec-tests",
-	runtime.GOOS == "linux", // other OSes have annoying firewall GUI confirmation dialogs
-	"run tests that require listening on the the unspecified address")
-
 func TestChangesOverTime(t *testing.T) {
-	if !*flagRunUnspecTests {
-		t.Skip("skipping test without --run-unspec-tests")
-	}
-
 	var p Poller
+	p.IncludeLocalhost = true
 	get := func(t *testing.T) []Port {
 		t.Helper()
 		s, err := p.getList()
@@ -71,7 +62,7 @@ func TestChangesOverTime(t *testing.T) {
 	}
 
 	p1 := get(t)
-	ln, err := net.Listen("tcp", ":0")
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
 	if err != nil {
 		t.Skipf("failed to bind: %v", err)
 	}

+ 7 - 4
portlist/portlist_windows.go

@@ -25,7 +25,8 @@ type famPort struct {
 }
 
 type windowsImpl struct {
-	known map[famPort]*portMeta // inode string => metadata
+	known            map[famPort]*portMeta // inode string => metadata
+	includeLocalhost bool
 }
 
 type portMeta struct {
@@ -33,9 +34,10 @@ type portMeta struct {
 	keep bool
 }
 
-func newWindowsImpl() osImpl {
+func newWindowsImpl(includeLocalhost bool) osImpl {
 	return &windowsImpl{
-		known: map[famPort]*portMeta{},
+		known:            map[famPort]*portMeta{},
+		includeLocalhost: includeLocalhost,
 	}
 }
 
@@ -58,7 +60,7 @@ func (im *windowsImpl) AppendListeningPorts(base []Port) ([]Port, error) {
 		if e.State != "LISTEN" {
 			continue
 		}
-		if !e.Local.Addr().IsUnspecified() {
+		if !im.includeLocalhost && !e.Local.Addr().IsUnspecified() {
 			continue
 		}
 		fp := famPort{
@@ -83,6 +85,7 @@ func (im *windowsImpl) AppendListeningPorts(base []Port) ([]Port, error) {
 				Proto:   "tcp",
 				Port:    e.Local.Port(),
 				Process: process,
+				Pid:     e.Pid,
 			},
 		}
 		im.known[fp] = pm