2
0
Эх сурвалжийг харах

wgengine/magicsock: make test pass on Windows and without firewall dialog box

Updates #50
Brad Fitzpatrick 5 жил өмнө
parent
commit
fd2a30cd32

+ 13 - 4
net/netcheck/netcheck.go

@@ -156,6 +156,11 @@ type Client struct {
 	// GetSTUNConn6 is like GetSTUNConn4, but for IPv6.
 	// GetSTUNConn6 is like GetSTUNConn4, but for IPv6.
 	GetSTUNConn6 func() STUNConn
 	GetSTUNConn6 func() STUNConn
 
 
+	// SkipExternalNetwork controls whether the client should not try
+	// to reach things other than localhost. This is set to true
+	// in tests to avoid probing the local LAN's router, etc.
+	SkipExternalNetwork bool
+
 	mu       sync.Mutex            // guards following
 	mu       sync.Mutex            // guards following
 	nextFull bool                  // do a full region scan, even if last != nil
 	nextFull bool                  // do a full region scan, even if last != nil
 	prev     map[time.Time]*Report // some previous reports
 	prev     map[time.Time]*Report // some previous reports
@@ -831,8 +836,10 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (*Report, e
 	}
 	}
 	defer rs.pc4Hair.Close()
 	defer rs.pc4Hair.Close()
 
 
-	rs.waitPortMap.Add(1)
-	go rs.probePortMapServices()
+	if !c.SkipExternalNetwork {
+		rs.waitPortMap.Add(1)
+		go rs.probePortMapServices()
+	}
 
 
 	// At least the Apple Airport Extreme doesn't allow hairpin
 	// At least the Apple Airport Extreme doesn't allow hairpin
 	// sends from a private socket until it's seen traffic from
 	// sends from a private socket until it's seen traffic from
@@ -902,8 +909,10 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (*Report, e
 
 
 	rs.waitHairCheck(ctx)
 	rs.waitHairCheck(ctx)
 	c.vlogf("hairCheck done")
 	c.vlogf("hairCheck done")
-	rs.waitPortMap.Wait()
-	c.vlogf("portMap done")
+	if !c.SkipExternalNetwork {
+		rs.waitPortMap.Wait()
+		c.vlogf("portMap done")
+	}
 	rs.stopTimers()
 	rs.stopTimers()
 
 
 	// Try HTTPS latency check if all STUN probes failed due to UDP presumably being blocked.
 	// Try HTTPS latency check if all STUN probes failed due to UDP presumably being blocked.

+ 12 - 1
tstest/log.go

@@ -5,6 +5,8 @@
 package tstest
 package tstest
 
 
 import (
 import (
+	"bytes"
+	"fmt"
 	"log"
 	"log"
 	"os"
 	"os"
 	"sync"
 	"sync"
@@ -35,7 +37,16 @@ func UnfixLogs(t *testing.T) {
 type panicLogWriter struct{}
 type panicLogWriter struct{}
 
 
 func (panicLogWriter) Write(b []byte) (int, error) {
 func (panicLogWriter) Write(b []byte) (int, error) {
-	panic("please use tailscale.com/logger.Logf instead of the log package")
+	// Allow certain phrases for now, in the interest of getting
+	// CI working on Windows and not having to refactor all the
+	// interfaces.GetState & tshttpproxy code to allow pushing
+	// down a Logger yet. TODO(bradfitz): do that refactoring once
+	// 1.2.0 is out.
+	if bytes.Contains(b, []byte("tshttpproxy: ")) {
+		os.Stderr.Write(b)
+		return len(b), nil
+	}
+	panic(fmt.Sprintf("please use tailscale.com/logger.Logf instead of the log package (tried to log: %q)", b))
 }
 }
 
 
 // PanicOnLog modifies the standard library log package's default output to
 // PanicOnLog modifies the standard library log package's default output to

+ 14 - 4
wgengine/magicsock/magicsock.go

@@ -120,6 +120,7 @@ type Conn struct {
 	netChecker       *netcheck.Client
 	netChecker       *netcheck.Client
 	idleFunc         func() time.Duration   // nil means unknown
 	idleFunc         func() time.Duration   // nil means unknown
 	noteRecvActivity func(tailcfg.DiscoKey) // or nil, see Options.NoteRecvActivity
 	noteRecvActivity func(tailcfg.DiscoKey) // or nil, see Options.NoteRecvActivity
+	simulatedNetwork bool
 
 
 	// bufferedIPv4From and bufferedIPv4Packet are owned by
 	// bufferedIPv4From and bufferedIPv4Packet are owned by
 	// ReceiveIPv4, and used when both a DERP and IPv4 packet arrive
 	// ReceiveIPv4, and used when both a DERP and IPv4 packet arrive
@@ -312,6 +313,13 @@ type Options struct {
 	// Conn.CreateEndpoint, which acquires Conn.mu. As such, you
 	// Conn.CreateEndpoint, which acquires Conn.mu. As such, you
 	// should not hold Conn.mu while calling it.
 	// should not hold Conn.mu while calling it.
 	NoteRecvActivity func(tailcfg.DiscoKey)
 	NoteRecvActivity func(tailcfg.DiscoKey)
+
+	// SimulatedNetwork can be set true in tests to signal that
+	// the network is simulated and thus it's okay to bind on the
+	// unspecified address (which we'd normally avoid to avoid
+	// triggering macOS and Windows firwall dialog boxes during
+	// "go test").
+	SimulatedNetwork bool
 }
 }
 
 
 func (o *Options) logf() logger.Logf {
 func (o *Options) logf() logger.Logf {
@@ -369,6 +377,7 @@ func NewConn(opts Options) (*Conn, error) {
 	c.idleFunc = opts.IdleFunc
 	c.idleFunc = opts.IdleFunc
 	c.packetListener = opts.PacketListener
 	c.packetListener = opts.PacketListener
 	c.noteRecvActivity = opts.NoteRecvActivity
 	c.noteRecvActivity = opts.NoteRecvActivity
+	c.simulatedNetwork = opts.SimulatedNetwork
 
 
 	if err := c.initialBind(); err != nil {
 	if err := c.initialBind(); err != nil {
 		return nil, err
 		return nil, err
@@ -376,8 +385,9 @@ func NewConn(opts Options) (*Conn, error) {
 
 
 	c.connCtx, c.connCtxCancel = context.WithCancel(context.Background())
 	c.connCtx, c.connCtxCancel = context.WithCancel(context.Background())
 	c.netChecker = &netcheck.Client{
 	c.netChecker = &netcheck.Client{
-		Logf:         logger.WithPrefix(c.logf, "netcheck: "),
-		GetSTUNConn4: func() netcheck.STUNConn { return c.pconn4 },
+		Logf:                logger.WithPrefix(c.logf, "netcheck: "),
+		GetSTUNConn4:        func() netcheck.STUNConn { return c.pconn4 },
+		SkipExternalNetwork: inTest(),
 	}
 	}
 	if c.pconn6 != nil {
 	if c.pconn6 != nil {
 		c.netChecker.GetSTUNConn6 = func() netcheck.STUNConn { return c.pconn6 }
 		c.netChecker.GetSTUNConn6 = func() netcheck.STUNConn { return c.pconn6 }
@@ -2480,7 +2490,7 @@ func (c *Conn) listenPacket(ctx context.Context, network, addr string) (net.Pack
 
 
 func (c *Conn) bind1(ruc **RebindingUDPConn, which string) error {
 func (c *Conn) bind1(ruc **RebindingUDPConn, which string) error {
 	host := ""
 	host := ""
-	if inTest() {
+	if inTest() && !c.simulatedNetwork {
 		host = "127.0.0.1"
 		host = "127.0.0.1"
 	}
 	}
 	var pc net.PacketConn
 	var pc net.PacketConn
@@ -2510,7 +2520,7 @@ func (c *Conn) bind1(ruc **RebindingUDPConn, which string) error {
 // It should be followed by a call to ReSTUN.
 // It should be followed by a call to ReSTUN.
 func (c *Conn) Rebind() {
 func (c *Conn) Rebind() {
 	host := ""
 	host := ""
-	if inTest() {
+	if inTest() && !c.simulatedNetwork {
 		host = "127.0.0.1"
 		host = "127.0.0.1"
 	}
 	}
 	listenCtx := context.Background() // unused without DNS name to resolve
 	listenCtx := context.Background() // unused without DNS name to resolve

+ 6 - 1
wgengine/magicsock/magicsock_test.go

@@ -46,6 +46,10 @@ import (
 	"tailscale.com/wgengine/tstun"
 	"tailscale.com/wgengine/tstun"
 )
 )
 
 
+func init() {
+	os.Setenv("IN_TS_TEST", "1")
+}
+
 // WaitReady waits until the magicsock is entirely initialized and connected
 // WaitReady waits until the magicsock is entirely initialized and connected
 // to its home DERP server. This is normally not necessary, since magicsock
 // to its home DERP server. This is normally not necessary, since magicsock
 // is intended to be entirely asynchronous, but it helps eliminate race
 // is intended to be entirely asynchronous, but it helps eliminate race
@@ -141,6 +145,7 @@ func newMagicStack(t *testing.T, logf logger.Logf, l nettype.PacketListener, der
 		EndpointsFunc: func(eps []string) {
 		EndpointsFunc: func(eps []string) {
 			epCh <- eps
 			epCh <- eps
 		},
 		},
+		SimulatedNetwork: l != nettype.Std{},
 	})
 	})
 	if err != nil {
 	if err != nil {
 		t.Fatalf("constructing magicsock: %v", err)
 		t.Fatalf("constructing magicsock: %v", err)
@@ -374,7 +379,7 @@ collectEndpoints:
 
 
 func pickPort(t *testing.T) uint16 {
 func pickPort(t *testing.T) uint16 {
 	t.Helper()
 	t.Helper()
-	conn, err := net.ListenPacket("udp4", ":0")
+	conn, err := net.ListenPacket("udp4", "127.0.0.1:0")
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}