浏览代码

cmd/tailscaled, wgengine: remove --fake, replace with netstack

And add a --socks5-server flag.

And fix a race in SOCKS5 replies where the response header was written
concurrently with the copy from the backend.

Co-authored with Naman Sood.

Updates #707
Updates #504

Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 5 年之前
父节点
当前提交
38dc6fe758

+ 1 - 1
cmd/tailscaled/depaware.txt

@@ -96,7 +96,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
      💣 tailscale.com/net/netstat                                    from tailscale.com/ipn/ipnserver
         tailscale.com/net/packet                                     from tailscale.com/wgengine+
         tailscale.com/net/portmapper                                 from tailscale.com/net/netcheck+
-        tailscale.com/net/socks5                                     from tailscale.com/wgengine/netstack
+        tailscale.com/net/socks5                                     from tailscale.com/cmd/tailscaled
         tailscale.com/net/stun                                       from tailscale.com/net/netcheck+
         tailscale.com/net/tlsdial                                    from tailscale.com/control/controlclient+
         tailscale.com/net/tsaddr                                     from tailscale.com/ipn/ipnlocal+

+ 68 - 13
cmd/tailscaled/tailscaled.go

@@ -14,6 +14,7 @@ import (
 	"flag"
 	"fmt"
 	"log"
+	"net"
 	"net/http"
 	"net/http/pprof"
 	"os"
@@ -21,19 +22,23 @@ import (
 	"runtime"
 	"runtime/debug"
 	"strconv"
+	"sync"
 	"syscall"
 	"time"
 
 	"tailscale.com/ipn/ipnserver"
 	"tailscale.com/logpolicy"
+	"tailscale.com/net/socks5"
 	"tailscale.com/paths"
 	"tailscale.com/types/flagtype"
 	"tailscale.com/types/logger"
+	"tailscale.com/types/netmap"
 	"tailscale.com/version"
 	"tailscale.com/wgengine"
 	"tailscale.com/wgengine/magicsock"
 	"tailscale.com/wgengine/netstack"
 	"tailscale.com/wgengine/router"
+	"tailscale.com/wgengine/tstun"
 )
 
 // globalStateKey is the ipn.StateKey that tailscaled loads on
@@ -62,13 +67,13 @@ func defaultTunName() string {
 
 var args struct {
 	cleanup    bool
-	fake       bool
 	debug      string
 	tunname    string
 	port       uint16
 	statepath  string
 	socketpath string
 	verbose    int
+	socksAddr  string // listen address for SOCKS5 server
 }
 
 var (
@@ -94,9 +99,9 @@ func main() {
 	printVersion := false
 	flag.IntVar(&args.verbose, "verbose", 0, "log verbosity level; 0 is default, 1 or higher are increasingly verbose")
 	flag.BoolVar(&args.cleanup, "cleanup", false, "clean up system state and exit")
-	flag.BoolVar(&args.fake, "fake", false, "use userspace fake tunnel+routing instead of kernel TUN interface")
 	flag.StringVar(&args.debug, "debug", "", "listen address ([ip]:port) of optional debug server")
-	flag.StringVar(&args.tunname, "tun", defaultTunName(), "tunnel interface name")
+	flag.StringVar(&args.socksAddr, "socks5-server", "", `optional [ip]:port to run a SOCK5 server (e.g. "localhost:1080")`)
+	flag.StringVar(&args.tunname, "tun", defaultTunName(), `tunnel interface name; use "userspace-networking" (beta) to not use TUN`)
 	flag.Var(flagtype.PortValue(&args.port, magicsock.DefaultPort), "port", "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select")
 	flag.StringVar(&args.statepath, "state", paths.DefaultTailscaledStateFile(), "path of state file")
 	flag.StringVar(&args.socketpath, "socket", paths.DefaultTailscaledSocket(), "path of the service unix socket")
@@ -190,23 +195,73 @@ func run() error {
 		go runDebugServer(debugMux, args.debug)
 	}
 
-	var e wgengine.Engine
-	if args.fake {
-		var impl wgengine.FakeImplFactory
-		if args.tunname == "userspace-networking" {
-			impl = netstack.Create
+	var socksListener net.Listener
+	if args.socksAddr != "" {
+		var err error
+		socksListener, err = net.Listen("tcp", args.socksAddr)
+		if err != nil {
+			log.Fatalf("SOCKS5 listener: %v", err)
 		}
-		e, err = wgengine.NewFakeUserspaceEngine(logf, 0, impl)
+	}
+
+	conf := wgengine.Config{
+		ListenPort: args.port,
+	}
+	if args.tunname == "userspace-networking" {
+		conf.TUN = tstun.NewFakeTUN()
+		conf.RouterGen = router.NewFake
 	} else {
-		e, err = wgengine.NewUserspaceEngine(logf, wgengine.Config{
-			TUNName:    args.tunname,
-			ListenPort: args.port,
-		})
+		conf.TUNName = args.tunname
 	}
+
+	e, err := wgengine.NewUserspaceEngine(logf, conf)
 	if err != nil {
 		logf("wgengine.New: %v", err)
 		return err
 	}
+
+	var ns *netstack.Impl
+	if args.tunname == "userspace-networking" {
+		tunDev, magicConn := e.(wgengine.InternalsGetter).GetInternals()
+		ns, err = netstack.Create(logf, tunDev, e, magicConn)
+		if err != nil {
+			log.Fatalf("netstack.Create: %v", err)
+		}
+		if err := ns.Start(); err != nil {
+			log.Fatalf("failed to start netstack: %v", err)
+		}
+	}
+
+	if socksListener != nil {
+		srv := &socks5.Server{
+			Logf: logger.WithPrefix(logf, "socks5: "),
+		}
+		if args.tunname == "userspace-networking" {
+			srv.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
+				return ns.DialContextTCP(ctx, addr)
+			}
+		} else {
+			var mu sync.Mutex
+			var dns netstack.DNSMap
+			e.AddNetworkMapCallback(func(nm *netmap.NetworkMap) {
+				mu.Lock()
+				defer mu.Unlock()
+				dns = netstack.DNSMapFromNetworkMap(nm)
+			})
+			srv.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
+				ipp, err := dns.Resolve(ctx, addr)
+				if err != nil {
+					return nil, err
+				}
+				var d net.Dialer
+				return d.DialContext(ctx, network, ipp.String())
+			}
+		}
+		go func() {
+			log.Fatalf("SOCKS5 server exited: %v", srv.Serve(socksListener))
+		}()
+	}
+
 	e = wgengine.NewWatchdog(e)
 
 	ctx, cancel := context.WithCancel(context.Background())

+ 1 - 1
ipn/ipnlocal/loglines_test.go

@@ -41,7 +41,7 @@ func TestLocalLogLines(t *testing.T) {
 
 	// set up a LocalBackend, super bare bones. No functional data.
 	store := &ipn.MemoryStore{}
-	e, err := wgengine.NewFakeUserspaceEngine(logListen.Logf, 0, nil)
+	e, err := wgengine.NewFakeUserspaceEngine(logListen.Logf, 0)
 	if err != nil {
 		t.Fatal(err)
 	}

+ 1 - 1
ipn/ipnserver/server_test.go

@@ -56,7 +56,7 @@ func TestRunMultipleAccepts(t *testing.T) {
 		}
 	}
 
-	eng, err := wgengine.NewFakeUserspaceEngine(logf, 0, nil)
+	eng, err := wgengine.NewFakeUserspaceEngine(logf, 0)
 	if err != nil {
 		t.Fatal(err)
 	}

+ 20 - 12
net/socks5/socks5.go

@@ -108,7 +108,7 @@ func (s *Server) Serve(l net.Listener) error {
 			conn := &Conn{clientConn: c, srv: s}
 			err := conn.Run()
 			if err != nil {
-				s.logf("socks5: client connection failed: %v", err)
+				s.logf("client connection failed: %v", err)
 				conn.clientConn.Close()
 			}
 		}()
@@ -123,7 +123,6 @@ type Conn struct {
 
 	srv        *Server
 	clientConn net.Conn
-	serverConn net.Conn
 	request    *request
 }
 
@@ -153,11 +152,7 @@ func (c *Conn) handleRequest() error {
 		return fmt.Errorf("unsupported command %v", req.command)
 	}
 	c.request = req
-	return c.createReply()
-}
 
-func (c *Conn) createReply() error {
-	var err error
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancel()
 	srv, err := c.srv.dial(
@@ -171,14 +166,12 @@ func (c *Conn) createReply() error {
 		c.clientConn.Write(buf)
 		return err
 	}
-	c.serverConn = srv
-	serverAddr, serverPortStr, err := net.SplitHostPort(c.serverConn.LocalAddr().String())
+	defer srv.Close()
+	serverAddr, serverPortStr, err := net.SplitHostPort(srv.LocalAddr().String())
 	if err != nil {
 		return err
 	}
 	serverPort, _ := strconv.Atoi(serverPortStr)
-	go io.Copy(c.clientConn, c.serverConn)
-	go io.Copy(c.serverConn, c.clientConn)
 
 	var bindAddrType addrType
 	if ip := net.ParseIP(serverAddr); ip != nil {
@@ -190,7 +183,6 @@ func (c *Conn) createReply() error {
 	} else {
 		bindAddrType = domainName
 	}
-
 	res := &response{
 		reply:        success,
 		bindAddrType: bindAddrType,
@@ -203,7 +195,23 @@ func (c *Conn) createReply() error {
 		buf, _ = res.marshal()
 	}
 	c.clientConn.Write(buf)
-	return err
+
+	errc := make(chan error, 2)
+	go func() {
+		_, err := io.Copy(c.clientConn, srv)
+		if err != nil {
+			err = fmt.Errorf("from backend to client: %w", err)
+		}
+		errc <- err
+	}()
+	go func() {
+		_, err := io.Copy(srv, c.clientConn)
+		if err != nil {
+			err = fmt.Errorf("from client to backend: %w", err)
+		}
+		errc <- err
+	}()
+	return <-errc
 }
 
 // parseClientGreeting parses a request initiation packet

+ 30 - 36
wgengine/netstack/netstack.go

@@ -33,7 +33,6 @@ import (
 	"gvisor.dev/gvisor/pkg/waiter"
 	"inet.af/netaddr"
 	"tailscale.com/net/packet"
-	"tailscale.com/net/socks5"
 	"tailscale.com/types/logger"
 	"tailscale.com/types/netmap"
 	"tailscale.com/util/dnsname"
@@ -55,13 +54,13 @@ type Impl struct {
 	logf    logger.Logf
 
 	mu  sync.Mutex
-	dns map[string]netaddr.IP // Magic DNS names (both base + FQDN) => first IP
+	dns DNSMap
 }
 
 const nicID = 1
 
 // Create creates and populates a new Impl.
-func Create(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsock.Conn) (wgengine.FakeImpl, error) {
+func Create(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsock.Conn) (*Impl, error) {
 	if mc == nil {
 		return nil, errors.New("nil magicsock.Conn")
 	}
@@ -121,33 +120,40 @@ func (ns *Impl) Start() error {
 	ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket)
 	go ns.injectOutbound()
 	ns.tundev.PostFilterIn = ns.injectInbound
-	go ns.socks5Server()
-
 	return nil
 }
 
-func (ns *Impl) updateDNS(nm *netmap.NetworkMap) {
-	ns.mu.Lock()
-	defer ns.mu.Unlock()
-	ns.dns = make(map[string]netaddr.IP)
+// DNSMap maps MagicDNS names (both base + FQDN) to their first IP.
+// It should not be mutated once created.
+type DNSMap map[string]netaddr.IP
+
+func DNSMapFromNetworkMap(nm *netmap.NetworkMap) DNSMap {
+	ret := make(DNSMap)
 	suffix := nm.MagicDNSSuffix()
 
 	if nm.Name != "" && len(nm.Addresses) > 0 {
 		ip := nm.Addresses[0].IP
-		ns.dns[strings.TrimRight(nm.Name, ".")] = ip
+		ret[strings.TrimRight(nm.Name, ".")] = ip
 		if dnsname.HasSuffix(nm.Name, suffix) {
-			ns.dns[dnsname.TrimSuffix(nm.Name, suffix)] = ip
+			ret[dnsname.TrimSuffix(nm.Name, suffix)] = ip
 		}
 	}
 	for _, p := range nm.Peers {
 		if p.Name != "" && len(p.Addresses) > 0 {
 			ip := p.Addresses[0].IP
-			ns.dns[strings.TrimRight(p.Name, ".")] = ip
+			ret[strings.TrimRight(p.Name, ".")] = ip
 			if dnsname.HasSuffix(p.Name, suffix) {
-				ns.dns[dnsname.TrimSuffix(p.Name, suffix)] = ip
+				ret[dnsname.TrimSuffix(p.Name, suffix)] = ip
 			}
 		}
 	}
+	return ret
+}
+
+func (ns *Impl) updateDNS(nm *netmap.NetworkMap) {
+	ns.mu.Lock()
+	defer ns.mu.Unlock()
+	ns.dns = DNSMapFromNetworkMap(nm)
 }
 
 func (ns *Impl) updateIPs(nm *netmap.NetworkMap) {
@@ -198,8 +204,9 @@ func (ns *Impl) updateIPs(nm *netmap.NetworkMap) {
 	}
 }
 
-// resolve resolves addr into an IP:port.
-func (ns *Impl) resolve(ctx context.Context, addr string) (netaddr.IPPort, error) {
+// Resolve resolves addr into an IP:port using first the MagicDNS contents
+// of m, else using the system resolver.
+func (m DNSMap) Resolve(ctx context.Context, addr string) (netaddr.IPPort, error) {
 	ipp, pippErr := netaddr.ParseIPPort(addr)
 	if pippErr == nil {
 		return ipp, nil
@@ -222,9 +229,7 @@ func (ns *Impl) resolve(ctx context.Context, addr string) (netaddr.IPPort, error
 	// Host is not an IP, so assume it's a DNS name.
 
 	// Try MagicDNS first, else otherwise a real DNS lookup.
-	ns.mu.Lock()
-	ip := ns.dns[host]
-	ns.mu.Unlock()
+	ip := m[host]
 	if !ip.IsZero() {
 		return netaddr.IPPort{IP: ip, Port: uint16(port16)}, nil
 	}
@@ -242,8 +247,12 @@ func (ns *Impl) resolve(ctx context.Context, addr string) (netaddr.IPPort, error
 	return netaddr.IPPort{IP: ip, Port: uint16(port16)}, nil
 }
 
-func (ns *Impl) dialContextTCP(ctx context.Context, addr string) (*gonet.TCPConn, error) {
-	remoteIPPort, err := ns.resolve(ctx, addr)
+func (ns *Impl) DialContextTCP(ctx context.Context, addr string) (*gonet.TCPConn, error) {
+	ns.mu.Lock()
+	dnsMap := ns.dns
+	ns.mu.Unlock()
+
+	remoteIPPort, err := dnsMap.Resolve(ctx, addr)
 	if err != nil {
 		return nil, err
 	}
@@ -342,7 +351,7 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, address stri
 		}
 		cancel()
 	}()
-	server, err := ns.dialContextTCP(ctx, address)
+	server, err := ns.DialContextTCP(ctx, address)
 	if err != nil {
 		ns.logf("netstack: could not connect to server %s: %s", address, err)
 		return
@@ -361,21 +370,6 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, address stri
 	ns.logf("[v2] netstack: forwarder connection to %s closed", address)
 }
 
-func (ns *Impl) socks5Server() {
-	ln, err := net.Listen("tcp", "localhost:1080")
-	if err != nil {
-		ns.logf("could not start SOCKS5 listener: %v", err)
-		return
-	}
-	srv := &socks5.Server{
-		Logf: ns.logf,
-		Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
-			return ns.dialContextTCP(ctx, addr)
-		},
-	}
-	ns.logf("SOCKS5 server exited: %v", srv.Serve(ln))
-}
-
 func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
 	ns.logf("[v2] UDP ForwarderRequest: %v", r)
 	var wq waiter.Queue

+ 17 - 37
wgengine/userspace.go

@@ -131,6 +131,15 @@ type userspaceEngine struct {
 	// Lock ordering: magicsock.Conn.mu, wgLock, then mu.
 }
 
+// InternalsGetter is implemented by Engines that can export their internals.
+type InternalsGetter interface {
+	GetInternals() (*tstun.TUN, *magicsock.Conn)
+}
+
+func (e *userspaceEngine) GetInternals() (*tstun.TUN, *magicsock.Conn) {
+	return e.tundev, e.magicConn
+}
+
 // RouterGen is the signature for a function that creates a
 // router.Router.
 type RouterGen func(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (router.Router, error)
@@ -157,36 +166,18 @@ type Config struct {
 	// If zero, a port is automatically selected.
 	ListenPort uint16
 
-	// Fake determines whether this engine is running in fake mode,
-	// which disables such features as DNS configuration and unrestricted ICMP Echo responses.
+	// Fake determines whether this engine should automatically
+	// reply to ICMP pings.
 	Fake bool
-
-	// FakeImplFactory, if non-nil, creates a FakeImpl to use as a fake engine
-	// implementation. Two values are typical: nil, for a basic ping-only fake
-	// implementation, and netstack.Create, which creates a userspace network
-	// stack using gvisor's netstack. The desire to keep netstack out of some
-	// binaries is why the FakeImpl interface exists, so wgengine need not
-	// depend on gvisor.
-	FakeImplFactory FakeImplFactory
-}
-
-// FakeImpl is a fake or alternate version of Engine that can be started. See
-// Config.FakeImplFactory for details.
-type FakeImpl interface {
-	Start() error
 }
 
-// FakeImplFactory is the type of a function used to create FakeImpls.
-type FakeImplFactory func(logger.Logf, *tstun.TUN, Engine, *magicsock.Conn) (FakeImpl, error)
-
-func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16, impl FakeImplFactory) (Engine, error) {
+func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16) (Engine, error) {
 	logf("Starting userspace wireguard engine (with fake TUN device)")
 	return NewUserspaceEngine(logf, Config{
-		TUN:             tstun.NewFakeTUN(),
-		RouterGen:       router.NewFake,
-		ListenPort:      listenPort,
-		Fake:            true,
-		FakeImplFactory: impl,
+		TUN:        tstun.NewFakeTUN(),
+		RouterGen:  router.NewFake,
+		ListenPort: listenPort,
+		Fake:       true,
 	})
 }
 
@@ -292,18 +283,7 @@ func newUserspaceEngine(logf logger.Logf, rawTUNDev tun.Device, conf Config) (_
 
 	// Respond to all pings only in fake mode.
 	if conf.Fake {
-		if f := conf.FakeImplFactory; f != nil {
-			impl, err := f(logf, e.tundev, e, e.magicConn)
-			if err != nil {
-				return nil, err
-			}
-			if err := impl.Start(); err != nil {
-				return nil, err
-			}
-		} else {
-			// Respond to all pings only in fake mode.
-			e.tundev.PostFilterIn = echoRespondToAll
-		}
+		e.tundev.PostFilterIn = echoRespondToAll
 	}
 	e.tundev.PreFilterOut = e.handleLocalPackets
 

+ 1 - 1
wgengine/userspace_test.go

@@ -84,7 +84,7 @@ func TestNoteReceiveActivity(t *testing.T) {
 }
 
 func TestUserspaceEngineReconfig(t *testing.T) {
-	e, err := NewFakeUserspaceEngine(t.Logf, 0, nil)
+	e, err := NewFakeUserspaceEngine(t.Logf, 0)
 	if err != nil {
 		t.Fatal(err)
 	}

+ 2 - 2
wgengine/watchdog_test.go

@@ -17,7 +17,7 @@ func TestWatchdog(t *testing.T) {
 
 	t.Run("default watchdog does not fire", func(t *testing.T) {
 		t.Parallel()
-		e, err := NewFakeUserspaceEngine(t.Logf, 0, nil)
+		e, err := NewFakeUserspaceEngine(t.Logf, 0)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -35,7 +35,7 @@ func TestWatchdog(t *testing.T) {
 
 	t.Run("watchdog fires on blocked getStatus", func(t *testing.T) {
 		t.Parallel()
-		e, err := NewFakeUserspaceEngine(t.Logf, 0, nil)
+		e, err := NewFakeUserspaceEngine(t.Logf, 0)
 		if err != nil {
 			t.Fatal(err)
 		}