| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235 |
- // Copyright (c) Tailscale Inc & AUTHORS
- // SPDX-License-Identifier: BSD-3-Clause
- package derp
- import (
- "bufio"
- "bytes"
- "io"
- "net"
- "reflect"
- "sync"
- "testing"
- "time"
- "tailscale.com/tstest"
- "tailscale.com/types/key"
- )
- type dummyNetConn struct {
- net.Conn
- }
- func (dummyNetConn) SetReadDeadline(time.Time) error { return nil }
- func TestClientRecv(t *testing.T) {
- tests := []struct {
- name string
- input []byte
- want any
- }{
- {
- name: "ping",
- input: []byte{
- byte(FramePing), 0, 0, 0, 8,
- 1, 2, 3, 4, 5, 6, 7, 8,
- },
- want: PingMessage{1, 2, 3, 4, 5, 6, 7, 8},
- },
- {
- name: "pong",
- input: []byte{
- byte(FramePong), 0, 0, 0, 8,
- 1, 2, 3, 4, 5, 6, 7, 8,
- },
- want: PongMessage{1, 2, 3, 4, 5, 6, 7, 8},
- },
- {
- name: "health_bad",
- input: []byte{
- byte(FrameHealth), 0, 0, 0, 3,
- byte('B'), byte('A'), byte('D'),
- },
- want: HealthMessage{Problem: "BAD"},
- },
- {
- name: "health_ok",
- input: []byte{
- byte(FrameHealth), 0, 0, 0, 0,
- },
- want: HealthMessage{},
- },
- {
- name: "server_restarting",
- input: []byte{
- byte(FrameRestarting), 0, 0, 0, 8,
- 0, 0, 0, 1,
- 0, 0, 0, 2,
- },
- want: ServerRestartingMessage{
- ReconnectIn: 1 * time.Millisecond,
- TryFor: 2 * time.Millisecond,
- },
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- c := &Client{
- nc: dummyNetConn{},
- br: bufio.NewReader(bytes.NewReader(tt.input)),
- logf: t.Logf,
- clock: &tstest.Clock{},
- }
- got, err := c.Recv()
- if err != nil {
- t.Fatal(err)
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("got %#v; want %#v", got, tt.want)
- }
- })
- }
- }
- func TestClientSendPing(t *testing.T) {
- var buf bytes.Buffer
- c := &Client{
- bw: bufio.NewWriter(&buf),
- }
- if err := c.SendPing([8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil {
- t.Fatal(err)
- }
- want := []byte{
- byte(FramePing), 0, 0, 0, 8,
- 1, 2, 3, 4, 5, 6, 7, 8,
- }
- if !bytes.Equal(buf.Bytes(), want) {
- t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want)
- }
- }
- func TestClientSendPong(t *testing.T) {
- var buf bytes.Buffer
- c := &Client{
- bw: bufio.NewWriter(&buf),
- }
- if err := c.SendPong([8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil {
- t.Fatal(err)
- }
- want := []byte{
- byte(FramePong), 0, 0, 0, 8,
- 1, 2, 3, 4, 5, 6, 7, 8,
- }
- if !bytes.Equal(buf.Bytes(), want) {
- t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want)
- }
- }
- func BenchmarkWriteUint32(b *testing.B) {
- w := bufio.NewWriter(io.Discard)
- b.ReportAllocs()
- b.ResetTimer()
- for range b.N {
- writeUint32(w, 0x0ba3a)
- }
- }
- type nopRead struct{}
- func (r nopRead) Read(p []byte) (int, error) {
- return len(p), nil
- }
- var sinkU32 uint32
- func BenchmarkReadUint32(b *testing.B) {
- r := bufio.NewReader(nopRead{})
- var err error
- b.ReportAllocs()
- b.ResetTimer()
- for range b.N {
- sinkU32, err = readUint32(r)
- if err != nil {
- b.Fatal(err)
- }
- }
- }
- type countWriter struct {
- mu sync.Mutex
- writes int
- bytes int64
- }
- func (w *countWriter) Write(p []byte) (n int, err error) {
- w.mu.Lock()
- defer w.mu.Unlock()
- w.writes++
- w.bytes += int64(len(p))
- return len(p), nil
- }
- func (w *countWriter) Stats() (writes int, bytes int64) {
- w.mu.Lock()
- defer w.mu.Unlock()
- return w.writes, w.bytes
- }
- func (w *countWriter) ResetStats() {
- w.mu.Lock()
- defer w.mu.Unlock()
- w.writes, w.bytes = 0, 0
- }
- func TestClientSendRateLimiting(t *testing.T) {
- cw := new(countWriter)
- c := &Client{
- bw: bufio.NewWriter(cw),
- clock: &tstest.Clock{},
- }
- c.setSendRateLimiter(ServerInfoMessage{})
- pkt := make([]byte, 1000)
- if err := c.send(key.NodePublic{}, pkt); err != nil {
- t.Fatal(err)
- }
- writes1, bytes1 := cw.Stats()
- if writes1 != 1 {
- t.Errorf("writes = %v, want 1", writes1)
- }
- // Flood should all succeed.
- cw.ResetStats()
- for range 1000 {
- if err := c.send(key.NodePublic{}, pkt); err != nil {
- t.Fatal(err)
- }
- }
- writes1K, bytes1K := cw.Stats()
- if writes1K != 1000 {
- t.Logf("writes = %v; want 1000", writes1K)
- }
- if got, want := bytes1K, bytes1*1000; got != want {
- t.Logf("bytes = %v; want %v", got, want)
- }
- // Set a rate limiter
- cw.ResetStats()
- c.setSendRateLimiter(ServerInfoMessage{
- TokenBucketBytesPerSecond: 1,
- TokenBucketBytesBurst: int(bytes1 * 2),
- })
- for range 1000 {
- if err := c.send(key.NodePublic{}, pkt); err != nil {
- t.Fatal(err)
- }
- }
- writesLimited, bytesLimited := cw.Stats()
- if writesLimited == 0 || writesLimited == writes1K {
- t.Errorf("limited conn's write count = %v; want non-zero, less than 1k", writesLimited)
- }
- if bytesLimited < bytes1*2 || bytesLimited >= bytes1K {
- t.Errorf("limited conn's bytes count = %v; want >=%v, <%v", bytesLimited, bytes1K*2, bytes1K)
- }
- }
|