Browse Source

net/netstat: start of new netstat package, with Windows for now

This will be used in a future change to do localhost connection
authentication. This lets us quickly map a localhost TCP connection to
a PID. (A future change will then map a pid to a user)

TODO: pull portlist's netstat code into this package. Then portlist
will be fast on Windows without requiring shelling out to netstat.exe.
Brad Fitzpatrick 5 years ago
parent
commit
f65eb4e5c1
4 changed files with 247 additions and 0 deletions
  1. 36 0
      net/netstat/netstat.go
  2. 11 0
      net/netstat/netstat_noimpl.go
  3. 22 0
      net/netstat/netstat_test.go
  4. 178 0
      net/netstat/netstat_windows.go

+ 36 - 0
net/netstat/netstat.go

@@ -0,0 +1,36 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package netstat returns the local machine's network connection table.
+package netstat
+
+import (
+	"errors"
+	"runtime"
+
+	"inet.af/netaddr"
+)
+
+var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS)
+
+type Entry struct {
+	Local, Remote netaddr.IPPort
+	Pid           int
+	State         string // TODO: type?
+}
+
+// Table contains local machine's TCP connection entries.
+//
+// Currently only TCP (IPv4 and IPv6) are included.
+type Table struct {
+	Entries []Entry
+}
+
+// Get returns the connection table.
+//
+// It returns ErrNotImplemented if the table is not available for the
+// current operating system.
+func Get() (*Table, error) {
+	return get()
+}

+ 11 - 0
net/netstat/netstat_noimpl.go

@@ -0,0 +1,11 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !windows
+
+package netstat
+
+func get() (*Table, error) {
+	return nil, ErrNotImplemented
+}

+ 22 - 0
net/netstat/netstat_test.go

@@ -0,0 +1,22 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package netstat
+
+import (
+	"testing"
+)
+
+func TestGet(t *testing.T) {
+	nt, err := Get()
+	if err == ErrNotImplemented {
+		t.Skip("TODO: not implemented")
+	}
+	if err != nil {
+		t.Fatal(err)
+	}
+	for _, e := range nt.Entries {
+		t.Logf("Entry: %+v", e)
+	}
+}

+ 178 - 0
net/netstat/netstat_windows.go

@@ -0,0 +1,178 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package netstat returns the local machine's network connection table.
+package netstat
+
+import (
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"syscall"
+	"unsafe"
+
+	"golang.org/x/sys/windows"
+	"inet.af/netaddr"
+)
+
+// See https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable
+
+// TCP_TABLE_OWNER_PID_ALL means to include the PID info. The table type
+// we get back from Windows depends on AF_INET vs AF_INET6:
+// MIB_TCPTABLE_OWNER_PID for v4 or MIB_TCP6TABLE_OWNER_PID for v6.
+const tcpTableOwnerPidAll = 5
+
+var (
+	iphlpapi    = syscall.NewLazyDLL("iphlpapi.dll")
+	getTCPTable = iphlpapi.NewProc("GetExtendedTcpTable")
+	// TODO: GetExtendedUdpTable also? if/when needed.
+)
+
+type _MIB_TCPROW_OWNER_PID struct {
+	state      uint32
+	localAddr  uint32
+	localPort  uint32
+	remoteAddr uint32
+	remotePort uint32
+	pid        uint32
+}
+
+type _MIB_TCP6ROW_OWNER_PID struct {
+	localAddr   [16]byte
+	localScope  uint32
+	localPort   uint32
+	remoteAddr  [16]byte
+	remoteScope uint32
+	remotePort  uint32
+	state       uint32
+	pid         uint32
+}
+
+func get() (*Table, error) {
+	t := new(Table)
+	if err := t.addEntries(windows.AF_INET); err != nil {
+		return nil, fmt.Errorf("failed to get IPv4 entries: %w", err)
+	}
+	if err := t.addEntries(windows.AF_INET6); err != nil {
+		return nil, fmt.Errorf("failed to get IPv6 entries: %w", err)
+	}
+	return t, nil
+}
+
+func (t *Table) addEntries(fam int) error {
+	var size uint32
+	var addr unsafe.Pointer
+	var buf []byte
+	for {
+		err, _, _ := getTCPTable.Call(
+			uintptr(addr),
+			uintptr(unsafe.Pointer(&size)),
+			1, // sorted
+			uintptr(fam),
+			tcpTableOwnerPidAll,
+			0, // reserved; "must be zero"
+		)
+		if err == 0 {
+			break
+		}
+		if err == uintptr(syscall.ERROR_INSUFFICIENT_BUFFER) {
+			const maxSize = 10 << 20
+			if size > maxSize || size < 4 {
+				return fmt.Errorf("unreasonable kernel-reported size %d", size)
+			}
+			buf = make([]byte, size)
+			addr = unsafe.Pointer(&buf[0])
+			continue
+		}
+		return syscall.Errno(err)
+	}
+	if len(buf) < int(size) {
+		return errors.New("unexpected size growth from system call")
+	}
+	buf = buf[:size]
+
+	numEntries := *(*uint32)(unsafe.Pointer(&buf[0]))
+	buf = buf[4:]
+
+	var recSize int
+	switch fam {
+	case windows.AF_INET:
+		recSize = 6 * 4
+	case windows.AF_INET6:
+		recSize = 6*4 + 16*2
+	}
+	dataLen := numEntries * uint32(recSize)
+	if uint32(len(buf)) > dataLen {
+		buf = buf[:dataLen]
+	}
+	for len(buf) >= recSize {
+		switch fam {
+		case windows.AF_INET:
+			row := (*_MIB_TCPROW_OWNER_PID)(unsafe.Pointer(&buf[0]))
+			t.Entries = append(t.Entries, Entry{
+				Local:  ipport4(row.localAddr, port(&row.localPort)),
+				Remote: ipport4(row.remoteAddr, port(&row.remotePort)),
+				Pid:    int(row.pid),
+				State:  state(row.state),
+			})
+		case windows.AF_INET6:
+			row := (*_MIB_TCP6ROW_OWNER_PID)(unsafe.Pointer(&buf[0]))
+			t.Entries = append(t.Entries, Entry{
+				Local:  ipport6(row.localAddr, row.localScope, port(&row.localPort)),
+				Remote: ipport6(row.remoteAddr, row.remoteScope, port(&row.remotePort)),
+				Pid:    int(row.pid),
+				State:  state(row.state),
+			})
+		}
+		buf = buf[recSize:]
+	}
+	return nil
+}
+
+var states = []string{
+	"",
+	"CLOSED",
+	"LISTEN",
+	"SYN-SENT",
+	"SYN-RECEIVED",
+	"ESTABLISHED",
+	"FIN-WAIT-1",
+	"FIN-WAIT-2",
+	"CLOSE-WAIT",
+	"CLOSING",
+	"LAST-ACK",
+	"DELETE-TCB",
+}
+
+func state(v uint32) string {
+	if v < uint32(len(states)) {
+		return states[v]
+	}
+	return fmt.Sprintf("unknown-state-%d", v)
+}
+
+func ipport4(addr uint32, port uint16) netaddr.IPPort {
+	a4 := (*[4]byte)(unsafe.Pointer(&addr))
+	return netaddr.IPPort{
+		IP:   netaddr.IPv4(a4[0], a4[1], a4[2], a4[3]),
+		Port: port,
+	}
+}
+
+func ipport6(addr [16]byte, scope uint32, port uint16) netaddr.IPPort {
+	ip := netaddr.IPFrom16(addr)
+	if scope != 0 {
+		// TODO: something better here?
+		ip = ip.WithZone(fmt.Sprint(scope))
+	}
+	return netaddr.IPPort{
+		IP:   ip,
+		Port: port,
+	}
+}
+
+func port(v *uint32) uint16 {
+	p := (*[4]byte)(unsafe.Pointer(v))
+	return binary.BigEndian.Uint16(p[:2])
+}