client_test.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package derp
  4. import (
  5. "bufio"
  6. "bytes"
  7. "io"
  8. "net"
  9. "reflect"
  10. "sync"
  11. "testing"
  12. "time"
  13. "tailscale.com/tstest"
  14. "tailscale.com/types/key"
  15. )
  // dummyNetConn is a minimal net.Conn stand-in for tests. It embeds net.Conn to
// satisfy the interface; when built as dummyNetConn{} (as TestClientRecv does)
// the embedded Conn is nil, so any method other than SetReadDeadline would
// panic if called.
type dummyNetConn struct {
	net.Conn
}

// SetReadDeadline discards the requested deadline and always reports success.
func (dummyNetConn) SetReadDeadline(_ time.Time) error {
	return nil
}
  20. func TestClientRecv(t *testing.T) {
  21. tests := []struct {
  22. name string
  23. input []byte
  24. want any
  25. }{
  26. {
  27. name: "ping",
  28. input: []byte{
  29. byte(FramePing), 0, 0, 0, 8,
  30. 1, 2, 3, 4, 5, 6, 7, 8,
  31. },
  32. want: PingMessage{1, 2, 3, 4, 5, 6, 7, 8},
  33. },
  34. {
  35. name: "pong",
  36. input: []byte{
  37. byte(FramePong), 0, 0, 0, 8,
  38. 1, 2, 3, 4, 5, 6, 7, 8,
  39. },
  40. want: PongMessage{1, 2, 3, 4, 5, 6, 7, 8},
  41. },
  42. {
  43. name: "health_bad",
  44. input: []byte{
  45. byte(FrameHealth), 0, 0, 0, 3,
  46. byte('B'), byte('A'), byte('D'),
  47. },
  48. want: HealthMessage{Problem: "BAD"},
  49. },
  50. {
  51. name: "health_ok",
  52. input: []byte{
  53. byte(FrameHealth), 0, 0, 0, 0,
  54. },
  55. want: HealthMessage{},
  56. },
  57. {
  58. name: "server_restarting",
  59. input: []byte{
  60. byte(FrameRestarting), 0, 0, 0, 8,
  61. 0, 0, 0, 1,
  62. 0, 0, 0, 2,
  63. },
  64. want: ServerRestartingMessage{
  65. ReconnectIn: 1 * time.Millisecond,
  66. TryFor: 2 * time.Millisecond,
  67. },
  68. },
  69. }
  70. for _, tt := range tests {
  71. t.Run(tt.name, func(t *testing.T) {
  72. c := &Client{
  73. nc: dummyNetConn{},
  74. br: bufio.NewReader(bytes.NewReader(tt.input)),
  75. logf: t.Logf,
  76. clock: &tstest.Clock{},
  77. }
  78. got, err := c.Recv()
  79. if err != nil {
  80. t.Fatal(err)
  81. }
  82. if !reflect.DeepEqual(got, tt.want) {
  83. t.Errorf("got %#v; want %#v", got, tt.want)
  84. }
  85. })
  86. }
  87. }
  88. func TestClientSendPing(t *testing.T) {
  89. var buf bytes.Buffer
  90. c := &Client{
  91. bw: bufio.NewWriter(&buf),
  92. }
  93. if err := c.SendPing([8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil {
  94. t.Fatal(err)
  95. }
  96. want := []byte{
  97. byte(FramePing), 0, 0, 0, 8,
  98. 1, 2, 3, 4, 5, 6, 7, 8,
  99. }
  100. if !bytes.Equal(buf.Bytes(), want) {
  101. t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want)
  102. }
  103. }
  104. func TestClientSendPong(t *testing.T) {
  105. var buf bytes.Buffer
  106. c := &Client{
  107. bw: bufio.NewWriter(&buf),
  108. }
  109. if err := c.SendPong([8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil {
  110. t.Fatal(err)
  111. }
  112. want := []byte{
  113. byte(FramePong), 0, 0, 0, 8,
  114. 1, 2, 3, 4, 5, 6, 7, 8,
  115. }
  116. if !bytes.Equal(buf.Bytes(), want) {
  117. t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want)
  118. }
  119. }
  120. func BenchmarkWriteUint32(b *testing.B) {
  121. w := bufio.NewWriter(io.Discard)
  122. b.ReportAllocs()
  123. b.ResetTimer()
  124. for range b.N {
  125. writeUint32(w, 0x0ba3a)
  126. }
  127. }
  // nopRead is a reader that never fails and never reaches EOF: every Read
// claims to have filled all of p but leaves its contents untouched.
type nopRead struct{}

// Read reports len(buf) bytes read and a nil error without modifying buf.
func (nopRead) Read(buf []byte) (int, error) {
	n := len(buf)
	return n, nil
}
  132. var sinkU32 uint32
  133. func BenchmarkReadUint32(b *testing.B) {
  134. r := bufio.NewReader(nopRead{})
  135. var err error
  136. b.ReportAllocs()
  137. b.ResetTimer()
  138. for range b.N {
  139. sinkU32, err = readUint32(r)
  140. if err != nil {
  141. b.Fatal(err)
  142. }
  143. }
  144. }
  // countWriter is an io.Writer that discards what it is given while counting
// Write calls and total bytes. All methods take mu, so it may be shared
// between goroutines.
type countWriter struct {
	mu     sync.Mutex // guards writes and bytes
	writes int        // number of Write calls since the last reset
	bytes  int64      // total len(p) across those calls
}

// Write records one call carrying len(p) bytes and reports all of p as
// written with a nil error.
func (w *countWriter) Write(p []byte) (int, error) {
	n := len(p)
	w.mu.Lock()
	w.writes, w.bytes = w.writes+1, w.bytes+int64(n)
	w.mu.Unlock()
	return n, nil
}

// Stats returns the Write-call count and byte total accumulated since the
// writer was created or last reset.
func (w *countWriter) Stats() (int, int64) {
	w.mu.Lock()
	writes, total := w.writes, w.bytes
	w.mu.Unlock()
	return writes, total
}

// ResetStats zeroes both counters.
func (w *countWriter) ResetStats() {
	w.mu.Lock()
	defer w.mu.Unlock()
	w.writes = 0
	w.bytes = 0
}
  167. func TestClientSendRateLimiting(t *testing.T) {
  168. cw := new(countWriter)
  169. c := &Client{
  170. bw: bufio.NewWriter(cw),
  171. clock: &tstest.Clock{},
  172. }
  173. c.setSendRateLimiter(ServerInfoMessage{})
  174. pkt := make([]byte, 1000)
  175. if err := c.send(key.NodePublic{}, pkt); err != nil {
  176. t.Fatal(err)
  177. }
  178. writes1, bytes1 := cw.Stats()
  179. if writes1 != 1 {
  180. t.Errorf("writes = %v, want 1", writes1)
  181. }
  182. // Flood should all succeed.
  183. cw.ResetStats()
  184. for range 1000 {
  185. if err := c.send(key.NodePublic{}, pkt); err != nil {
  186. t.Fatal(err)
  187. }
  188. }
  189. writes1K, bytes1K := cw.Stats()
  190. if writes1K != 1000 {
  191. t.Logf("writes = %v; want 1000", writes1K)
  192. }
  193. if got, want := bytes1K, bytes1*1000; got != want {
  194. t.Logf("bytes = %v; want %v", got, want)
  195. }
  196. // Set a rate limiter
  197. cw.ResetStats()
  198. c.setSendRateLimiter(ServerInfoMessage{
  199. TokenBucketBytesPerSecond: 1,
  200. TokenBucketBytesBurst: int(bytes1 * 2),
  201. })
  202. for range 1000 {
  203. if err := c.send(key.NodePublic{}, pkt); err != nil {
  204. t.Fatal(err)
  205. }
  206. }
  207. writesLimited, bytesLimited := cw.Stats()
  208. if writesLimited == 0 || writesLimited == writes1K {
  209. t.Errorf("limited conn's write count = %v; want non-zero, less than 1k", writesLimited)
  210. }
  211. if bytesLimited < bytes1*2 || bytesLimited >= bytes1K {
  212. t.Errorf("limited conn's bytes count = %v; want >=%v, <%v", bytesLimited, bytes1K*2, bytes1K)
  213. }
  214. }