Просмотр исходного кода

derp: throttle client sends if server advertises rate limits

Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 4 лет назад
Родитель
Сommit
73f177e4d5
3 измененных файлов с 126 добавлено и 7 удалено
  1. 46 7
      derp/derp_client.go
  2. 3 0
      derp/derp_server.go
  3. 77 0
      derp/derp_test.go

+ 46 - 7
derp/derp_client.go

@@ -16,6 +16,7 @@ import (
 	"time"
 
 	"golang.org/x/crypto/nacl/box"
+	"golang.org/x/time/rate"
 	"tailscale.com/types/key"
 	"tailscale.com/types/logger"
 )
@@ -32,8 +33,9 @@ type Client struct {
 	canAckPings bool
 	isProber    bool
 
-	wmu sync.Mutex // hold while writing to bw
-	bw  *bufio.Writer
+	wmu  sync.Mutex // hold while writing to bw
+	bw   *bufio.Writer
+	rate *rate.Limiter // if non-nil, rate limiter to use
 
 	// Owned by Recv:
 	peeked  int   // bytes to discard on next Recv
@@ -217,7 +219,12 @@ func (c *Client) send(dstKey key.Public, pkt []byte) (ret error) {
 
 	c.wmu.Lock()
 	defer c.wmu.Unlock()
-
+	if c.rate != nil {
+		pktLen := frameHeaderLen + len(dstKey) + len(pkt)
+		if !c.rate.AllowN(time.Now(), pktLen) {
+			return nil // drop
+		}
+	}
 	if err := writeFrameHeader(c.bw, frameSendPacket, uint32(len(dstKey)+len(pkt))); err != nil {
 		return err
 	}
@@ -353,7 +360,22 @@ type PeerPresentMessage key.Public
 func (PeerPresentMessage) msg() {}
 
 // ServerInfoMessage is sent by the server upon first connect.
-type ServerInfoMessage struct{}
+type ServerInfoMessage struct {
+	// TokenBucketBytesPerSecond is how many bytes per second the
+	// server says it will accept, including all framing bytes.
+	//
+	// Zero means unspecified. There might be a limit, but the
+	// client need not try to respect it.
+	TokenBucketBytesPerSecond int
+
+	// TokenBucketBytesBurst is how many bytes the server will
+	// allow to burst, temporarily violating
+	// TokenBucketBytesPerSecond.
+	//
+	// Zero means unspecified. There might be a limit, but the
+	// client need not try to respect it.
+	TokenBucketBytesBurst int
+}
 
 func (ServerInfoMessage) msg() {}
 
@@ -475,12 +497,16 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro
 			// needing to wait an RTT to discover the version at startup.
 			// We'd prefer to give the connection to the client (magicsock)
 			// to start writing as soon as possible.
-			_, err := c.parseServerInfo(b)
+			si, err := c.parseServerInfo(b)
 			if err != nil {
 				return nil, fmt.Errorf("invalid server info frame: %v", err)
 			}
-			// TODO: add the results of parseServerInfo to ServerInfoMessage if we ever need it.
-			return ServerInfoMessage{}, nil
+			sm := ServerInfoMessage{
+				TokenBucketBytesPerSecond: si.TokenBucketBytesPerSecond,
+				TokenBucketBytesBurst:     si.TokenBucketBytesBurst,
+			}
+			c.setSendRateLimiter(sm)
+			return sm, nil
 		case frameKeepAlive:
 			// A one-way keep-alive message that doesn't require an acknowledgement.
 			// This predated framePing/framePong.
@@ -537,3 +563,16 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro
 		}
 	}
 }
+
+func (c *Client) setSendRateLimiter(sm ServerInfoMessage) {
+	c.wmu.Lock()
+	defer c.wmu.Unlock()
+
+	if sm.TokenBucketBytesPerSecond == 0 {
+		c.rate = nil
+	} else {
+		c.rate = rate.NewLimiter(
+			rate.Limit(sm.TokenBucketBytesPerSecond),
+			sm.TokenBucketBytesBurst)
+	}
+}

+ 3 - 0
derp/derp_server.go

@@ -1079,6 +1079,9 @@ func (s *Server) noteClientActivity(c *sclient) {
 
 type serverInfo struct {
 	Version int `json:"version,omitempty"`
+
+	TokenBucketBytesPerSecond int `json:",omitempty"`
+	TokenBucketBytesBurst     int `json:",omitempty"`
 }
 
 func (s *Server) sendServerInfo(bw *lazyBufioWriter, clientKey key.Public) error {

+ 77 - 0
derp/derp_test.go

@@ -1244,3 +1244,80 @@ func TestParseSSOutput(t *testing.T) {
 		t.Errorf("parseSSOutput expected non-empty map")
 	}
 }
+
+type countWriter struct {
+	mu     sync.Mutex
+	writes int
+	bytes  int64
+}
+
+func (w *countWriter) Write(p []byte) (n int, err error) {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+	w.writes++
+	w.bytes += int64(len(p))
+	return len(p), nil
+}
+
+func (w *countWriter) Stats() (writes int, bytes int64) {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+	return w.writes, w.bytes
+}
+
+func (w *countWriter) ResetStats() {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+	w.writes, w.bytes = 0, 0
+}
+
+func TestClientSendRateLimiting(t *testing.T) {
+	cw := new(countWriter)
+	c := &Client{
+		bw: bufio.NewWriter(cw),
+	}
+	c.setSendRateLimiter(ServerInfoMessage{})
+
+	pkt := make([]byte, 1000)
+	if err := c.send(key.Public{}, pkt); err != nil {
+		t.Fatal(err)
+	}
+	writes1, bytes1 := cw.Stats()
+	if writes1 != 1 {
+		t.Errorf("writes = %v, want 1", writes1)
+	}
+
+	// Flood should all succeed.
+	cw.ResetStats()
+	for i := 0; i < 1000; i++ {
+		if err := c.send(key.Public{}, pkt); err != nil {
+			t.Fatal(err)
+		}
+	}
+	writes1K, bytes1K := cw.Stats()
+	if writes1K != 1000 {
+		t.Logf("writes = %v; want 1000", writes1K)
+	}
+	if got, want := bytes1K, bytes1*1000; got != want {
+		t.Logf("bytes = %v; want %v", got, want)
+	}
+
+	// Set a rate limiter
+	cw.ResetStats()
+	c.setSendRateLimiter(ServerInfoMessage{
+		TokenBucketBytesPerSecond: 1,
+		TokenBucketBytesBurst:     int(bytes1 * 2),
+	})
+	for i := 0; i < 1000; i++ {
+		if err := c.send(key.Public{}, pkt); err != nil {
+			t.Fatal(err)
+		}
+	}
+	writesLimited, bytesLimited := cw.Stats()
+	if writesLimited == 0 || writesLimited == writes1K {
+		t.Errorf("limited conn's write count = %v; want non-zero, less than 1k", writesLimited)
+	}
+	if bytesLimited < bytes1*2 || bytesLimited >= bytes1K {
+		t.Errorf("limited conn's bytes count = %v; want >=%v, <%v", bytesLimited, bytes1K*2, bytes1K)
+	}
+}