瀏覽代碼

cmd/derper: support forwarding packets amongst set of peer DERP servers

Updates #388

Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 5 年之前
父節點
當前提交
1cb7dab881
共有 7 個文件被更改,包括 549 次插入26 次删除
  1. 3 4
      cmd/derper/derper.go
  2. 147 0
      cmd/derper/mesh.go
  3. 1 0
      derp/derp.go
  4. 38 0
      derp/derp_client.go
  5. 231 22
      derp/derp_server.go
  6. 114 0
      derp/derp_test.go
  7. 15 0
      derp/derphttp/derphttp_client.go

+ 3 - 4
cmd/derper/derper.go

@@ -134,10 +134,9 @@ func main() {
 		s.SetMeshKey(key)
 		log.Printf("DERP mesh key configured")
 	}
-
-	// TODO(bradfitz): parse & use the *meshWith
-	_ = *meshWith
-
+	if err := startMesh(s); err != nil {
+		log.Fatalf("startMesh: %v", err)
+	}
 	expvar.Publish("derp", s.ExpVar())
 
 	// Create our own mux so we don't expose /debug/ stuff to the world.

+ 147 - 0
cmd/derper/mesh.go

@@ -0,0 +1,147 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+	"errors"
+	"fmt"
+	"log"
+	"strings"
+	"sync"
+	"time"
+
+	"tailscale.com/derp"
+	"tailscale.com/derp/derphttp"
+	"tailscale.com/types/key"
+	"tailscale.com/types/logger"
+)
+
+func startMesh(s *derp.Server) error {
+	if *meshWith == "" {
+		return nil
+	}
+	if !s.HasMeshKey() {
+		return errors.New("--mesh-with requires --mesh-psk-file")
+	}
+	for _, host := range strings.Split(*meshWith, ",") {
+		if err := startMeshWithHost(s, host); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func startMeshWithHost(s *derp.Server, host string) error {
+	logf := logger.WithPrefix(log.Printf, fmt.Sprintf("mesh(%q): ", host))
+	c, err := derphttp.NewClient(s.PrivateKey(), "https://"+host+"/derp", logf)
+	if err != nil {
+		return err
+	}
+	c.MeshKey = s.MeshKey()
+	go runMeshClient(s, host, c, logf)
+	return nil
+}
+
+func runMeshClient(s *derp.Server, host string, c *derphttp.Client, logf logger.Logf) {
+	const retryInterval = 5 * time.Second
+	const statusInterval = 10 * time.Second
+	var (
+		mu              sync.Mutex
+		present         = map[key.Public]bool{}
+		loggedConnected = false
+	)
+	clear := func() {
+		mu.Lock()
+		defer mu.Unlock()
+		if len(present) == 0 {
+			return
+		}
+		logf("reconnected; clearing %d forwarding mappings", len(present))
+		for k := range present {
+			s.RemovePacketForwarder(k, c)
+		}
+		present = map[key.Public]bool{}
+	}
+	lastConnGen := 0
+	lastStatus := time.Now()
+	logConnectedLocked := func() {
+		if loggedConnected {
+			return
+		}
+		logf("connected; %d peers", len(present))
+		loggedConnected = true
+	}
+
+	const logConnectedDelay = 200 * time.Millisecond
+	timer := time.AfterFunc(2*time.Second, func() {
+		mu.Lock()
+		defer mu.Unlock()
+		logConnectedLocked()
+	})
+	defer timer.Stop()
+
+	updatePeer := func(k key.Public, isPresent bool) {
+		if isPresent {
+			s.AddPacketForwarder(k, c)
+		} else {
+			s.RemovePacketForwarder(k, c)
+		}
+
+		mu.Lock()
+		defer mu.Unlock()
+		if isPresent {
+			present[k] = true
+			if !loggedConnected {
+				timer.Reset(logConnectedDelay)
+			}
+		} else {
+			// If we got a peerGone message, that means the initial connection's
+			// flood of peerPresent messages is done, so we can log already:
+			logConnectedLocked()
+			delete(present, k)
+		}
+	}
+
+	for {
+		err := c.WatchConnectionChanges()
+		if err != nil {
+			clear()
+			logf("WatchConnectionChanges: %v", err)
+			time.Sleep(retryInterval)
+			continue
+		}
+
+		if c.ServerPublicKey() == s.PublicKey() {
+			logf("detected self-connect; ignoring host")
+			return
+		}
+		for {
+			var buf [64 << 10]byte
+			m, connGen, err := c.RecvDetail(buf[:])
+			if err != nil {
+				clear()
+				logf("Recv: %v", err)
+				time.Sleep(retryInterval)
+				break
+			}
+			if connGen != lastConnGen {
+				lastConnGen = connGen
+				clear()
+			}
+			switch m := m.(type) {
+			case derp.PeerPresentMessage:
+				updatePeer(key.Public(m), true)
+			case derp.PeerGoneMessage:
+				updatePeer(key.Public(m), false)
+			default:
+				continue
+			}
+			if now := time.Now(); now.Sub(lastStatus) > statusInterval {
+				lastStatus = now
+				logf("%d peers", len(present))
+			}
+		}
+	}
+}

+ 1 - 0
derp/derp.go

@@ -72,6 +72,7 @@ const (
 	frameClientInfo    = frameType(0x02) // 32B pub key + 24B nonce + naclbox(json)
 	frameServerInfo    = frameType(0x03) // 24B nonce + naclbox(json)
 	frameSendPacket    = frameType(0x04) // 32B dest pub key + packet bytes
+	frameForwardPacket = frameType(0x0a) // 32B src pub key + 32B dst pub key + packet bytes
 	frameRecvPacket    = frameType(0x05) // v0/1: packet bytes, v2: 32B src pub key + packet bytes
 	frameKeepAlive     = frameType(0x06) // no payload, no-op (to be replaced with ping/pong)
 	frameNotePreferred = frameType(0x07) // 1 byte payload: 0x01 or 0x00 for whether this is client's home node

+ 38 - 0
derp/derp_client.go

@@ -19,6 +19,7 @@ import (
 	"tailscale.com/types/logger"
 )
 
+// Client is a DERP client.
 type Client struct {
 	serverKey    key.Public // of the DERP server; not a machine or node key
 	privateKey   key.Private
@@ -170,6 +171,9 @@ func (c *Client) sendClientKey() error {
 	return writeFrame(c.bw, frameClientInfo, buf)
 }
 
+// ServerPublicKey returns the server's public key.
+func (c *Client) ServerPublicKey() key.Public { return c.serverKey }
+
 // Send sends a packet to the Tailscale node identified by dstKey.
 //
 // It is an error if the packet is larger than 64KB.
@@ -201,6 +205,40 @@ func (c *Client) send(dstKey key.Public, pkt []byte) (ret error) {
 	return c.bw.Flush()
 }
 
+func (c *Client) ForwardPacket(srcKey, dstKey key.Public, pkt []byte) (err error) {
+	defer func() {
+		if err != nil {
+			err = fmt.Errorf("derp.ForwardPacket: %w", err)
+		}
+	}()
+
+	if len(pkt) > MaxPacketSize {
+		return fmt.Errorf("packet too big: %d", len(pkt))
+	}
+
+	c.wmu.Lock()
+	defer c.wmu.Unlock()
+
+	timer := time.AfterFunc(5*time.Second, c.writeTimeoutFired)
+	defer timer.Stop()
+
+	if err := writeFrameHeader(c.bw, frameForwardPacket, uint32(keyLen*2+len(pkt))); err != nil {
+		return err
+	}
+	if _, err := c.bw.Write(srcKey[:]); err != nil {
+		return err
+	}
+	if _, err := c.bw.Write(dstKey[:]); err != nil {
+		return err
+	}
+	if _, err := c.bw.Write(pkt); err != nil {
+		return err
+	}
+	return c.bw.Flush()
+}
+
+func (c *Client) writeTimeoutFired() { c.nc.Close() }
+
 // NotePreferred sends a packet that tells the server whether this
 // client is the user's preferred server. This is only used in the
 // server for stats.

+ 231 - 22
derp/derp_server.go

@@ -50,30 +50,46 @@ type Server struct {
 	meshKey    string
 
 	// Counters:
-	packetsSent, bytesSent  expvar.Int
-	packetsRecv, bytesRecv  expvar.Int
-	packetsDropped          expvar.Int
-	packetsDroppedReason    metrics.LabelMap
-	packetsDroppedUnknown   *expvar.Int // unknown dst pubkey
-	packetsDroppedGone      *expvar.Int // dst conn shutting down
-	packetsDroppedQueueHead *expvar.Int // queue full, drop head packet
-	packetsDroppedQueueTail *expvar.Int // queue full, drop tail packet
-	packetsDroppedWrite     *expvar.Int // error writing to dst conn
-	peerGoneFrames          expvar.Int  // number of peer gone frames sent
-	accepts                 expvar.Int
-	curClients              expvar.Int
-	curHomeClients          expvar.Int // ones with preferred
-	clientsReplaced         expvar.Int
-	unknownFrames           expvar.Int
-	homeMovesIn             expvar.Int // established clients announce home server moves in
-	homeMovesOut            expvar.Int // established clients announce home server moves out
+	packetsSent, bytesSent   expvar.Int
+	packetsRecv, bytesRecv   expvar.Int
+	packetsDropped           expvar.Int
+	packetsDroppedReason     metrics.LabelMap
+	packetsDroppedUnknown    *expvar.Int // unknown dst pubkey
+	packetsDroppedFwdUnknown *expvar.Int // unknown dst pubkey on forward
+	packetsDroppedGone       *expvar.Int // dst conn shutting down
+	packetsDroppedQueueHead  *expvar.Int // queue full, drop head packet
+	packetsDroppedQueueTail  *expvar.Int // queue full, drop tail packet
+	packetsDroppedWrite      *expvar.Int // error writing to dst conn
+	packetsForwardedOut      expvar.Int
+	packetsForwardedIn       expvar.Int
+	peerGoneFrames           expvar.Int // number of peer gone frames sent
+	accepts                  expvar.Int
+	curClients               expvar.Int
+	curHomeClients           expvar.Int // ones with preferred
+	clientsReplaced          expvar.Int
+	unknownFrames            expvar.Int
+	homeMovesIn              expvar.Int // established clients announce home server moves in
+	homeMovesOut             expvar.Int // established clients announce home server moves out
+	multiForwarderCreated    expvar.Int
+	multiForwarderDeleted    expvar.Int
 
 	mu          sync.Mutex
 	closed      bool
 	netConns    map[Conn]chan struct{} // chan is closed when conn closes
 	clients     map[key.Public]*sclient
-	clientsEver map[key.Public]bool // never deleted from, for stats; fine for now
-	watchers    map[*sclient]bool   // mesh peer -> true
+	clientsEver map[key.Public]bool            // never deleted from, for stats; fine for now
+	watchers    map[*sclient]bool              // mesh peer -> true
+	clientsMesh map[key.Public]PacketForwarder // clients connected to mesh peers; nil means only in clients, not remote
+}
+
+// PacketForwarder is something that can forward packets.
+//
+// It's mostly an inteface for circular dependency reasons; the
+// typical implementation is derphttp.Client. The other implementation
+// is a multiForwarder, which this package creates as needed if a
+// public key gets more than one PacketForwarder registered for it.
+type PacketForwarder interface {
+	ForwardPacket(src, dst key.Public, payload []byte) error
 }
 
 // Conn is the subset of the underlying net.Conn the DERP Server needs.
@@ -101,11 +117,13 @@ func NewServer(privateKey key.Private, logf logger.Logf) *Server {
 		packetsDroppedReason: metrics.LabelMap{Label: "reason"},
 		clients:              make(map[key.Public]*sclient),
 		clientsEver:          make(map[key.Public]bool),
+		clientsMesh:          map[key.Public]PacketForwarder{},
 		netConns:             make(map[Conn]chan struct{}),
 		memSys0:              ms.Sys,
 		watchers:             map[*sclient]bool{},
 	}
 	s.packetsDroppedUnknown = s.packetsDroppedReason.Get("unknown_dest")
+	s.packetsDroppedFwdUnknown = s.packetsDroppedReason.Get("unknown_dest_on_fwd")
 	s.packetsDroppedGone = s.packetsDroppedReason.Get("gone")
 	s.packetsDroppedQueueHead = s.packetsDroppedReason.Get("queue_head")
 	s.packetsDroppedQueueTail = s.packetsDroppedReason.Get("queue_tail")
@@ -210,6 +228,9 @@ func (s *Server) registerClient(c *sclient) {
 	}
 	s.clients[c.key] = c
 	s.clientsEver[c.key] = true
+	if _, ok := s.clientsMesh[c.key]; !ok {
+		s.clientsMesh[c.key] = nil // just for varz of total users in cluster
+	}
 	s.curClients.Add(1)
 	s.broadcastPeerStateChangeLocked(c.key, true)
 }
@@ -238,6 +259,9 @@ func (s *Server) unregisterClient(c *sclient) {
 	if c.canMesh {
 		delete(s.watchers, c)
 	}
+	if v, ok := s.clientsMesh[c.key]; ok && v == nil {
+		delete(s.clientsMesh, c.key)
+	}
 	s.broadcastPeerStateChangeLocked(c.key, false)
 
 	s.curClients.Add(-1)
@@ -271,8 +295,6 @@ func (s *Server) addWatcher(c *sclient) {
 
 	if c.key == s.publicKey {
 		// We're connecting to ourself. Do nothing.
-		// TODO(bradfitz): have client notice and disconnect
-		// so an idle TCP connection isn't kept open.
 		return
 	}
 
@@ -378,6 +400,8 @@ func (c *sclient) run(ctx context.Context) error {
 			err = c.handleFrameNotePreferred(ft, fl)
 		case frameSendPacket:
 			err = c.handleFrameSendPacket(ft, fl)
+		case frameForwardPacket:
+			err = c.handleFrameForwardPacket(ft, fl)
 		case frameWatchConns:
 			err = c.handleFrameWatchConns(ft, fl)
 		default:
@@ -417,6 +441,42 @@ func (c *sclient) handleFrameWatchConns(ft frameType, fl uint32) error {
 	return nil
 }
 
+// handleFrameForwardPacket reads a "forward packet" frame from the client
+// (which must be a trusted client, a peer in our mesh).
+func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error {
+	if !c.canMesh {
+		return fmt.Errorf("insufficient permissions")
+	}
+	s := c.s
+
+	srcKey, dstKey, contents, err := s.recvForwardPacket(c.br, fl)
+	if err != nil {
+		return fmt.Errorf("client %x: recvForwardPacket: %v", c.key, err)
+	}
+	s.packetsForwardedIn.Add(1)
+
+	s.mu.Lock()
+	dst := s.clients[dstKey]
+	// TODO(bradfitz): think about the sentTo/Issue 150 optimization
+	// in the context of DERP meshes.
+	s.mu.Unlock()
+
+	if dst == nil {
+		s.packetsDropped.Add(1)
+		s.packetsDroppedFwdUnknown.Add(1)
+		if debug {
+			c.logf("dropping forwarded packet for unknown %x", dstKey)
+		}
+		return nil
+	}
+
+	return c.sendPkt(dst, pkt{
+		bs:  contents,
+		src: srcKey,
+	})
+}
+
+// handleFrameSendPacket reads a "send packet" frame from the client.
 func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
 	s := c.s
 
@@ -425,9 +485,12 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
 		return fmt.Errorf("client %x: recvPacket: %v", c.key, err)
 	}
 
+	var fwd PacketForwarder
 	s.mu.Lock()
 	dst := s.clients[dstKey]
-	if dst != nil {
+	if dst == nil {
+		fwd = s.clientsMesh[dstKey]
+	} else {
 		// Track that we've sent to this peer, so if/when we
 		// disconnect first, the server can inform all our old
 		// recipients that we're gone. (Issue 150 optimization)
@@ -436,6 +499,14 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
 	s.mu.Unlock()
 
 	if dst == nil {
+		if fwd != nil {
+			s.packetsForwardedOut.Add(1)
+			if err := fwd.ForwardPacket(c.key, dstKey, contents); err != nil {
+				// TODO:
+				return nil
+			}
+			return nil
+		}
 		s.packetsDropped.Add(1)
 		s.packetsDroppedUnknown.Add(1)
 		if debug {
@@ -450,6 +521,13 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
 	if dst.info.Version >= protocolSrcAddrs {
 		p.src = c.key
 	}
+	return c.sendPkt(dst, p)
+}
+
+func (c *sclient) sendPkt(dst *sclient, p pkt) error {
+	s := c.s
+	dstKey := dst.key
+
 	// Attempt to queue for sending up to 3 times. On each attempt, if
 	// the queue is full, try to drop from queue head to prioritize
 	// fresher packets.
@@ -615,6 +693,29 @@ func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.Publi
 // zpub is the key.Public zero value.
 var zpub key.Public
 
+func (s *Server) recvForwardPacket(br *bufio.Reader, frameLen uint32) (srcKey, dstKey key.Public, contents []byte, err error) {
+	if frameLen < keyLen*2 {
+		return zpub, zpub, nil, errors.New("short send packet frame")
+	}
+	if _, err := io.ReadFull(br, srcKey[:]); err != nil {
+		return zpub, zpub, nil, err
+	}
+	if _, err := io.ReadFull(br, dstKey[:]); err != nil {
+		return zpub, zpub, nil, err
+	}
+	packetLen := frameLen - keyLen*2
+	if packetLen > MaxPacketSize {
+		return zpub, zpub, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize)
+	}
+	contents = make([]byte, packetLen)
+	if _, err := io.ReadFull(br, contents); err != nil {
+		return zpub, zpub, nil, err
+	}
+	// TODO: was s.packetsRecv.Add(1)
+	// TODO: was s.bytesRecv.Add(int64(len(contents)))
+	return srcKey, dstKey, contents, nil
+}
+
 // sclient is a client connection to the server.
 //
 // (The "s" prefix is to more explicitly distinguish it from Client in derp_client.go)
@@ -889,6 +990,108 @@ func (c *sclient) sendPacket(srcKey key.Public, contents []byte) (err error) {
 	return err
 }
 
+// AddPacketForwarder registers fwd as a packet forwarder for dst.
+// fwd must be comparable.
+func (s *Server) AddPacketForwarder(dst key.Public, fwd PacketForwarder) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if prev, ok := s.clientsMesh[dst]; ok {
+		if prev == fwd {
+			// Duplicate registration of same forwarder. Ignore.
+			return
+		}
+		if m, ok := prev.(multiForwarder); ok {
+			if _, ok := m[fwd]; !ok {
+				// Duplicate registration of same forwarder in set; ignore.
+				return
+			}
+			m[fwd] = m.maxVal() + 1
+			return
+		}
+		// Otherwise, the existing value is not a set and not a dup, so make it a set.
+		fwd = multiForwarder{
+			prev: 1, // existed 1st, higher priority
+			fwd:  2, // the passed in fwd is in 2nd place
+		}
+		s.multiForwarderCreated.Add(1)
+	}
+	s.clientsMesh[dst] = fwd
+}
+
+// RemovePacketForwarder removes fwd as a packet forwarder for dst.
+// fwd must be comparable.
+func (s *Server) RemovePacketForwarder(dst key.Public, fwd PacketForwarder) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	v, ok := s.clientsMesh[dst]
+	if !ok {
+		return
+	}
+	if m, ok := v.(multiForwarder); ok {
+		if len(m) < 2 {
+			panic("unexpected")
+		}
+		delete(m, fwd)
+		// If fwd was in m and we no longer need to be a
+		// multiForwarder, replace the entry with the
+		// remaining PacketForwarder.
+		if len(m) == 1 {
+			var remain PacketForwarder
+			for k := range m {
+				remain = k
+			}
+			s.clientsMesh[dst] = remain
+			s.multiForwarderDeleted.Add(1)
+		}
+		return
+	}
+	if v != fwd {
+		// Delete of an entry that wasn't in the
+		// map. Harmless, so ignore.
+		// (This might happen if a user is moving around
+		// between nodes and/or the server sent duplicate
+		// connection change broadcasts.)
+		return
+	}
+
+	if _, isLocal := s.clients[dst]; isLocal {
+		s.clientsMesh[dst] = nil
+	} else {
+		delete(s.clientsMesh, dst)
+	}
+}
+
+// multiForwarder is a PacketForwarder that represents a set of
+// forwarding options. It's used in the rare cases that a client is
+// connected to multiple DERP nodes in a region. That shouldn't really
+// happen except for perhaps during brief moments while the client is
+// reconfiguring, in which case we don't want to forget where the
+// client is. The map value is unique connection number; the lowest
+// one has been seen the longest. It's used to make sure we forward
+// packets consistently to the same node and don't pick randomly.
+type multiForwarder map[PacketForwarder]uint8
+
+func (m multiForwarder) maxVal() (max uint8) {
+	for _, v := range m {
+		if v > max {
+			max = v
+		}
+	}
+	return
+}
+
+func (m multiForwarder) ForwardPacket(src, dst key.Public, payload []byte) error {
+	var fwd PacketForwarder
+	var lowest uint8
+	for k, v := range m {
+		if fwd == nil || v < lowest {
+			fwd = k
+			lowest = v
+		}
+	}
+	return fwd.ForwardPacket(src, dst, payload)
+}
+
 func (s *Server) expVarFunc(f func() interface{}) expvar.Func {
 	return expvar.Func(func() interface{} {
 		s.mu.Lock()
@@ -905,6 +1108,8 @@ func (s *Server) ExpVar() expvar.Var {
 	m.Set("gauge_watchers", s.expVarFunc(func() interface{} { return len(s.watchers) }))
 	m.Set("gauge_current_connnections", &s.curClients)
 	m.Set("gauge_current_home_connnections", &s.curHomeClients)
+	m.Set("gauge_clients_total", expvar.Func(func() interface{} { return len(s.clientsMesh) }))
+	m.Set("gauge_clients_remote", expvar.Func(func() interface{} { return len(s.clientsMesh) - len(s.clients) }))
 	m.Set("accepts", &s.accepts)
 	m.Set("clients_replaced", &s.clientsReplaced)
 	m.Set("bytes_received", &s.bytesRecv)
@@ -917,5 +1122,9 @@ func (s *Server) ExpVar() expvar.Var {
 	m.Set("home_moves_in", &s.homeMovesIn)
 	m.Set("home_moves_out", &s.homeMovesOut)
 	m.Set("peer_gone_frames", &s.peerGoneFrames)
+	m.Set("packets_forwarded_out", &s.packetsForwardedOut)
+	m.Set("packets_forwarded_in", &s.packetsForwardedIn)
+	m.Set("multiforwarder_created", &s.multiForwarderCreated)
+	m.Set("multiforwarder_deleted", &s.multiForwarderDeleted)
 	return m
 }

+ 114 - 0
derp/derp_test.go

@@ -13,6 +13,7 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"reflect"
 	"sync"
 	"testing"
 	"time"
@@ -619,3 +620,116 @@ func TestWatch(t *testing.T) {
 	w2.wantGone(t, c1.pub)
 	w3.wantGone(t, c1.pub)
 }
+
+type testFwd int
+
+func (testFwd) ForwardPacket(key.Public, key.Public, []byte) error { panic("not called in tests") }
+
+func pubAll(b byte) (ret key.Public) {
+	for i := range ret {
+		ret[i] = b
+	}
+	return
+}
+
+func TestForwarderRegistration(t *testing.T) {
+	s := &Server{
+		clients:     make(map[key.Public]*sclient),
+		clientsMesh: map[key.Public]PacketForwarder{},
+	}
+	want := func(want map[key.Public]PacketForwarder) {
+		t.Helper()
+		if got := s.clientsMesh; !reflect.DeepEqual(got, want) {
+			t.Fatalf("mismatch\n got: %v\nwant: %v\n", got, want)
+		}
+	}
+	wantCounter := func(c *expvar.Int, want int) {
+		t.Helper()
+		if got := c.Value(); got != int64(want) {
+			t.Errorf("counter = %v; want %v", got, want)
+		}
+	}
+
+	u1 := pubAll(1)
+	u2 := pubAll(2)
+	u3 := pubAll(3)
+
+	s.AddPacketForwarder(u1, testFwd(1))
+	s.AddPacketForwarder(u2, testFwd(2))
+	want(map[key.Public]PacketForwarder{
+		u1: testFwd(1),
+		u2: testFwd(2),
+	})
+
+	// Verify a remove of non-registered forwarder is no-op.
+	s.RemovePacketForwarder(u2, testFwd(999))
+	want(map[key.Public]PacketForwarder{
+		u1: testFwd(1),
+		u2: testFwd(2),
+	})
+
+	// Verify a remove of non-registered user is no-op.
+	s.RemovePacketForwarder(u3, testFwd(1))
+	want(map[key.Public]PacketForwarder{
+		u1: testFwd(1),
+		u2: testFwd(2),
+	})
+
+	// Actual removal.
+	s.RemovePacketForwarder(u2, testFwd(2))
+	want(map[key.Public]PacketForwarder{
+		u1: testFwd(1),
+	})
+
+	// Adding a dup for a user.
+	wantCounter(&s.multiForwarderCreated, 0)
+	s.AddPacketForwarder(u1, testFwd(100))
+	want(map[key.Public]PacketForwarder{
+		u1: multiForwarder{
+			testFwd(1):   1,
+			testFwd(100): 2,
+		},
+	})
+	wantCounter(&s.multiForwarderCreated, 1)
+
+	// Removing a forwarder in a multi set that doesn't exist; does nothing.
+	s.RemovePacketForwarder(u1, testFwd(55))
+	want(map[key.Public]PacketForwarder{
+		u1: multiForwarder{
+			testFwd(1):   1,
+			testFwd(100): 2,
+		},
+	})
+
+	// Removing a forwarder in a multi set that does exist should collapse it away
+	// from being a multiForwarder.
+	wantCounter(&s.multiForwarderDeleted, 0)
+	s.RemovePacketForwarder(u1, testFwd(1))
+	want(map[key.Public]PacketForwarder{
+		u1: testFwd(100),
+	})
+	wantCounter(&s.multiForwarderDeleted, 1)
+
+	// Removing an entry for a client that's still connected locally should result
+	// in a nil forwarder.
+	u1c := &sclient{
+		key:  u1,
+		logf: logger.Discard,
+	}
+	s.clients[u1] = u1c
+	s.RemovePacketForwarder(u1, testFwd(100))
+	want(map[key.Public]PacketForwarder{
+		u1: nil,
+	})
+
+	// But once that client disconnects, it should go away.
+	s.unregisterClient(u1c)
+	want(map[key.Public]PacketForwarder{})
+
+	// But if it already has a forwarder, it's not removed.
+	s.AddPacketForwarder(u1, testFwd(2))
+	s.unregisterClient(u1c)
+	want(map[key.Public]PacketForwarder{
+		u1: testFwd(2),
+	})
+}

+ 15 - 0
derp/derphttp/derphttp_client.go

@@ -114,6 +114,9 @@ func (c *Client) Connect(ctx context.Context) error {
 }
 
 // ServerPublicKey returns the server's public key.
+//
+// It only returns a non-zero value once a connection has succeeded
+// from an earlier call.
 func (c *Client) ServerPublicKey() key.Public {
 	c.mu.Lock()
 	defer c.mu.Unlock()
@@ -293,6 +296,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
 		}
 	}
 
+	c.serverPubKey = derpClient.ServerPublicKey()
 	c.client = derpClient
 	c.netConn = tcpConn
 	c.connGen++
@@ -484,6 +488,17 @@ func (c *Client) Send(dstKey key.Public, b []byte) error {
 	return err
 }
 
+func (c *Client) ForwardPacket(from, to key.Public, b []byte) error {
+	client, _, err := c.connect(context.TODO(), "derphttp.Client.ForwardPacket")
+	if err != nil {
+		return err
+	}
+	if err := client.ForwardPacket(from, to, b); err != nil {
+		c.closeForReconnect(client)
+	}
+	return err
+}
+
 // NotePreferred notes whether this Client is the caller's preferred
 // (home) DERP node. It's only used for stats.
 func (c *Client) NotePreferred(v bool) {