Procházet zdrojové kódy

wgengine/magicsock: make portmapping async

Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick před 4 roky
rodič
revize
92077ae78c

+ 1 - 1
cmd/tailscale/cli/netcheck.go

@@ -50,7 +50,7 @@ var netcheckArgs struct {
 func runNetcheck(ctx context.Context, args []string) error {
 	c := &netcheck.Client{
 		UDPBindAddr: os.Getenv("TS_DEBUG_NETCHECK_UDP_BIND"),
-		PortMapper:  portmapper.NewClient(logger.WithPrefix(log.Printf, "portmap: ")),
+		PortMapper:  portmapper.NewClient(logger.WithPrefix(log.Printf, "portmap: "), nil),
 	}
 	if netcheckArgs.verbose {
 		c.Logf = logger.WithPrefix(log.Printf, "netcheck: ")

+ 76 - 14
net/portmapper/portmapper.go

@@ -44,9 +44,15 @@ const trustServiceStillAvailableDuration = 10 * time.Minute
 type Client struct {
 	logf         logger.Logf
 	ipAndGateway func() (gw, ip netaddr.IP, ok bool)
+	onChange     func() // or nil
 
 	mu sync.Mutex // guards following, and all fields thereof
 
+	// runningCreate is whether we're currently working on creating
+	// a port mapping (whether GetCachedMappingOrStartCreatingOne kicked
+	// off a createMapping goroutine).
+	runningCreate bool
+
 	lastMyIP netaddr.IP
 	lastGW   netaddr.IP
 	closed   bool
@@ -68,18 +74,19 @@ type Client struct {
 func (c *Client) HaveMapping() bool {
 	c.mu.Lock()
 	defer c.mu.Unlock()
-	return c.pmpMapping != nil && c.pmpMapping.useUntil.After(time.Now())
+	return c.pmpMapping != nil && c.pmpMapping.goodUntil.After(time.Now())
 }
 
 // pmpMapping is an already-created PMP mapping.
 //
 // All fields are immutable once created.
 type pmpMapping struct {
-	gw       netaddr.IP
-	external netaddr.IPPort
-	internal netaddr.IPPort
-	useUntil time.Time // the mapping's lifetime minus renewal interval
-	epoch    uint32
+	gw         netaddr.IP
+	external   netaddr.IPPort
+	internal   netaddr.IPPort
+	renewAfter time.Time // the time at which we want to renew the mapping
+	goodUntil  time.Time // the mapping's total lifetime
+	epoch      uint32
 }
 
 // externalValid reports whether m.external is valid, with both its IP and Port populated.
@@ -99,10 +106,15 @@ func (m *pmpMapping) release() {
 }
 
 // NewClient returns a new portmapping client.
-func NewClient(logf logger.Logf) *Client {
+//
+// The optional onChange argument specifies a func to run in a new
+// goroutine whenever the port mapping status has changed. If nil,
+// it doesn't make a callback.
+func NewClient(logf logger.Logf, onChange func()) *Client {
 	return &Client{
 		logf:         logf,
 		ipAndGateway: interfaces.LikelyHomeRouterIP,
+		onChange:     onChange,
 	}
 }
 
@@ -221,8 +233,7 @@ func closeCloserOnContextDone(ctx context.Context, c io.Closer) (stop func()) {
 	return func() { close(stopWaitDone) }
 }
 
-// NoMappingError is returned by CreateOrGetMapping when no NAT
-// mapping could be returned.
+// NoMappingError is returned when no NAT mapping could be done.
 type NoMappingError struct {
 	err error
 }
@@ -241,12 +252,62 @@ var (
 	ErrGatewayNotFound       = errors.New("failed to look up gateway address")
 )
 
-// CreateOrGetMapping either creates a new mapping or returns a cached
+// GetCachedMappingOrStartCreatingOne quickly returns with our current cached portmapping, if any.
+// If there's not one, it starts up a background goroutine to create one.
+// If the background goroutine ends up creating one, the onChange hook registered with the
+// NewClient constructor (if any) will fire.
+func (c *Client) GetCachedMappingOrStartCreatingOne() (external netaddr.IPPort, ok bool) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	// Do we have an existing mapping that's valid?
+	now := time.Now()
+	if m := c.pmpMapping; m != nil {
+		if now.Before(m.goodUntil) {
+			if now.After(m.renewAfter) {
+				c.maybeStartMappingLocked()
+			}
+			return m.external, true
+		}
+	}
+
+	c.maybeStartMappingLocked()
+	return netaddr.IPPort{}, false
+}
+
+// maybeStartMappingLocked starts a createMapping goroutine up, if one isn't already running.
+//
+// c.mu must be held.
+func (c *Client) maybeStartMappingLocked() {
+	if !c.runningCreate {
+		c.runningCreate = true
+		go c.createMapping()
+	}
+}
+
+func (c *Client) createMapping() {
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	defer func() {
+		c.mu.Lock()
+		defer c.mu.Unlock()
+		c.runningCreate = false
+	}()
+
+	if _, err := c.createOrGetMapping(ctx); err == nil && c.onChange != nil {
+		go c.onChange()
+	} else if err != nil && !IsNoMappingError(err) {
+		c.logf("createOrGetMapping: %v", err)
+	}
+}
+
+// createOrGetMapping either creates a new mapping or returns a cached
 // valid one.
 //
 // If no mapping is available, the error will be of type
 // NoMappingError; see IsNoMappingError.
-func (c *Client) CreateOrGetMapping(ctx context.Context) (external netaddr.IPPort, err error) {
+func (c *Client) createOrGetMapping(ctx context.Context) (external netaddr.IPPort, err error) {
 	gw, myIP, ok := c.gatewayAndSelfIP()
 	if !ok {
 		return netaddr.IPPort{}, NoMappingError{ErrGatewayNotFound}
@@ -266,7 +327,7 @@ func (c *Client) CreateOrGetMapping(ctx context.Context) (external netaddr.IPPor
 	// Do we have an existing mapping that's valid?
 	now := time.Now()
 	if m := c.pmpMapping; m != nil {
-		if now.Before(m.useUntil) {
+		if now.Before(m.renewAfter) {
 			defer c.mu.Unlock()
 			return m.external, nil
 		}
@@ -342,8 +403,9 @@ func (c *Client) CreateOrGetMapping(ctx context.Context) (external netaddr.IPPor
 			if pres.OpCode == pmpOpReply|pmpOpMapUDP {
 				m.external = m.external.WithPort(pres.ExternalPort)
 				d := time.Duration(pres.MappingValidSeconds) * time.Second
-				d /= 2 // renew in half the time
-				m.useUntil = time.Now().Add(d)
+				now := time.Now()
+				m.goodUntil = now.Add(d)
+				m.renewAfter = now.Add(d / 2) // renew in half the time
 				m.epoch = pres.SecondsSinceEpoch
 			}
 		}

+ 6 - 6
net/portmapper/portmapper_test.go

@@ -16,13 +16,13 @@ func TestCreateOrGetMapping(t *testing.T) {
 	if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v {
 		t.Skip("skipping test without HIT_NETWORK=1")
 	}
-	c := NewClient(t.Logf)
+	c := NewClient(t.Logf, nil)
 	c.SetLocalPort(1234)
 	for i := 0; i < 2; i++ {
 		if i > 0 {
 			time.Sleep(100 * time.Millisecond)
 		}
-		ext, err := c.CreateOrGetMapping(context.Background())
+		ext, err := c.createOrGetMapping(context.Background())
 		t.Logf("Got: %v, %v", ext, err)
 	}
 }
@@ -31,7 +31,7 @@ func TestClientProbe(t *testing.T) {
 	if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v {
 		t.Skip("skipping test without HIT_NETWORK=1")
 	}
-	c := NewClient(t.Logf)
+	c := NewClient(t.Logf, nil)
 	for i := 0; i < 2; i++ {
 		if i > 0 {
 			time.Sleep(100 * time.Millisecond)
@@ -45,10 +45,10 @@ func TestClientProbeThenMap(t *testing.T) {
 	if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v {
 		t.Skip("skipping test without HIT_NETWORK=1")
 	}
-	c := NewClient(t.Logf)
+	c := NewClient(t.Logf, nil)
 	c.SetLocalPort(1234)
 	res, err := c.Probe(context.Background())
 	t.Logf("Probe: %+v, %v", res, err)
-	ext, err := c.CreateOrGetMapping(context.Background())
-	t.Logf("CreateOrGetMapping: %v, %v", ext, err)
+	ext, err := c.createOrGetMapping(context.Background())
+	t.Logf("createOrGetMapping: %v, %v", ext, err)
 }

+ 11 - 5
wgengine/magicsock/magicsock.go

@@ -486,7 +486,7 @@ func NewConn(opts Options) (*Conn, error) {
 	c.noteRecvActivity = opts.NoteRecvActivity
 	c.simulatedNetwork = opts.SimulatedNetwork
 	c.disableLegacy = opts.DisableLegacyNetworking
-	c.portMapper = portmapper.NewClient(logger.WithPrefix(c.logf, "portmapper: "))
+	c.portMapper = portmapper.NewClient(logger.WithPrefix(c.logf, "portmapper: "), c.onPortMapChanged)
 	if opts.LinkMonitor != nil {
 		c.portMapper.SetGatewayLookupFunc(opts.LinkMonitor.GatewayAndSelfIP)
 	}
@@ -979,6 +979,8 @@ func (c *Conn) goDerpConnect(node int) {
 //
 // c.mu must NOT be held.
 func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, error) {
+	portmapExt, havePortmap := c.portMapper.GetCachedMappingOrStartCreatingOne()
+
 	nr, err := c.updateNetInfo(ctx)
 	if err != nil {
 		c.logf("magicsock.Conn.determineEndpoints: updateNetInfo: %v", err)
@@ -1002,11 +1004,13 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro
 		}
 	}
 
-	if ext, err := c.portMapper.CreateOrGetMapping(ctx); err == nil {
-		addAddr(ext, tailcfg.EndpointPortmapped)
+	// If we didn't have a portmap earlier, maybe it's done by now.
+	if !havePortmap {
+		portmapExt, havePortmap = c.portMapper.GetCachedMappingOrStartCreatingOne()
+	}
+	if havePortmap {
+		addAddr(portmapExt, tailcfg.EndpointPortmapped)
 		c.setNetInfoHavePortMap()
-	} else if !portmapper.IsNoMappingError(err) {
-		c.logf("portmapper: %v", err)
 	}
 
 	if nr.GlobalV4 != "" {
@@ -2563,6 +2567,8 @@ func (c *Conn) shouldDoPeriodicReSTUNLocked() bool {
 	return true
 }
 
+func (c *Conn) onPortMapChanged() { c.ReSTUN("portmap-changed") }
+
 // ReSTUN triggers an address discovery.
 // The provided why string is for debug logging only.
 func (c *Conn) ReSTUN(why string) {