conn_test.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. // Copyright (c) Tailscale Inc & contributors
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package controlbase
  4. import (
  5. "bufio"
  6. "bytes"
  7. "context"
  8. "encoding/binary"
  9. "fmt"
  10. "io"
  11. "net"
  12. "runtime"
  13. "strings"
  14. "sync"
  15. "testing"
  16. "testing/iotest"
  17. "time"
  18. chp "golang.org/x/crypto/chacha20poly1305"
  19. "golang.org/x/net/nettest"
  20. "tailscale.com/net/memnet"
  21. "tailscale.com/types/key"
  22. )
  23. const testProtocolVersion = 1
  24. func TestMessageSize(t *testing.T) {
  25. // This test is a regression guard against someone looking at
  26. // maxCiphertextSize, going "huh, we could be more efficient if it
  27. // were larger, and accidentally violating the Noise spec. Do not
  28. // change this max value, it's a deliberate limitation of the
  29. // cryptographic protocol we use (see Section 3 "Message Format"
  30. // of the Noise spec).
  31. const max = 65535
  32. if maxCiphertextSize > max {
  33. t.Fatalf("max ciphertext size is %d, which is larger than the maximum noise message size %d", maxCiphertextSize, max)
  34. }
  35. }
  36. func TestConnBasic(t *testing.T) {
  37. client, server := pair(t)
  38. sb := sinkReads(server)
  39. want := "test"
  40. if _, err := io.WriteString(client, want); err != nil {
  41. t.Fatalf("client write failed: %v", err)
  42. }
  43. client.Close()
  44. if got := sb.String(4); got != want {
  45. t.Fatalf("wrong content received: got %q, want %q", got, want)
  46. }
  47. if err := sb.Error(); err != io.EOF {
  48. t.Fatal("client close wasn't seen by server")
  49. }
  50. if sb.Total() != 4 {
  51. t.Fatalf("wrong amount of bytes received: got %d, want 4", sb.Total())
  52. }
  53. }
  54. // bufferedWriteConn wraps a net.Conn and gives control over how
  55. // Writes get batched out.
  56. type bufferedWriteConn struct {
  57. net.Conn
  58. w *bufio.Writer
  59. manualFlush bool
  60. }
  61. func (c *bufferedWriteConn) Write(bs []byte) (int, error) {
  62. n, err := c.w.Write(bs)
  63. if err == nil && !c.manualFlush {
  64. err = c.w.Flush()
  65. }
  66. return n, err
  67. }
  68. // TestFastPath exercises the Read codepath that can receive multiple
  69. // Noise frames at once and decode each in turn without making another
  70. // syscall.
  71. func TestFastPath(t *testing.T) {
  72. s1, s2 := memnet.NewConn("noise", 128000)
  73. b := &bufferedWriteConn{s1, bufio.NewWriterSize(s1, 10000), false}
  74. client, server := pairWithConns(t, b, s2)
  75. b.manualFlush = true
  76. sb := sinkReads(server)
  77. const packets = 10
  78. s := "test"
  79. for range packets {
  80. // Many separate writes, to force separate Noise frames that
  81. // all get buffered up and then all sent as a single slice to
  82. // the server.
  83. if _, err := io.WriteString(client, s); err != nil {
  84. t.Fatalf("client write1 failed: %v", err)
  85. }
  86. }
  87. if err := b.w.Flush(); err != nil {
  88. t.Fatalf("client flush failed: %v", err)
  89. }
  90. client.Close()
  91. want := strings.Repeat(s, packets)
  92. if got := sb.String(len(want)); got != want {
  93. t.Fatalf("wrong content received: got %q, want %q", got, want)
  94. }
  95. if err := sb.Error(); err != io.EOF {
  96. t.Fatalf("client close wasn't seen by server")
  97. }
  98. }
  99. // Writes things larger than a single Noise frame, to check the
  100. // chunking on the encoder and decoder.
  101. func TestBigData(t *testing.T) {
  102. client, server := pair(t)
  103. serverReads := sinkReads(server)
  104. clientReads := sinkReads(client)
  105. const sz = 15 * 1024 // 15KiB
  106. clientStr := strings.Repeat("abcde", sz/5)
  107. serverStr := strings.Repeat("fghij", sz/5*2)
  108. if _, err := io.WriteString(client, clientStr); err != nil {
  109. t.Fatalf("writing client>server: %v", err)
  110. }
  111. if _, err := io.WriteString(server, serverStr); err != nil {
  112. t.Fatalf("writing server>client: %v", err)
  113. }
  114. if serverGot := serverReads.String(sz); serverGot != clientStr {
  115. t.Error("server didn't receive what client sent")
  116. }
  117. if clientGot := clientReads.String(2 * sz); clientGot != serverStr {
  118. t.Error("client didn't receive what server sent")
  119. }
  120. getNonce := func(n [chp.NonceSize]byte) uint64 {
  121. if binary.BigEndian.Uint32(n[:4]) != 0 {
  122. panic("unexpected nonce")
  123. }
  124. return binary.BigEndian.Uint64(n[4:])
  125. }
  126. // Reach into the Conns and verify the cipher nonces advanced as
  127. // expected.
  128. if getNonce(client.tx.nonce) != getNonce(server.rx.nonce) {
  129. t.Error("desynchronized client tx nonce")
  130. }
  131. if getNonce(server.tx.nonce) != getNonce(client.rx.nonce) {
  132. t.Error("desynchronized server tx nonce")
  133. }
  134. if n := getNonce(client.tx.nonce); n != 4 {
  135. t.Errorf("wrong client tx nonce, got %d want 4", n)
  136. }
  137. if n := getNonce(server.tx.nonce); n != 8 {
  138. t.Errorf("wrong client tx nonce, got %d want 8", n)
  139. }
  140. }
  141. // readerConn wraps a net.Conn and routes its Reads through a separate
  142. // io.Reader.
  143. type readerConn struct {
  144. net.Conn
  145. r io.Reader
  146. }
  147. func (c readerConn) Read(bs []byte) (int, error) { return c.r.Read(bs) }
  148. // Check that the receiver can handle not being able to read an entire
  149. // frame in a single syscall.
  150. func TestDataTrickle(t *testing.T) {
  151. s1, s2 := memnet.NewConn("noise", 128000)
  152. client, server := pairWithConns(t, s1, readerConn{s2, iotest.OneByteReader(s2)})
  153. serverReads := sinkReads(server)
  154. const sz = 10000
  155. clientStr := strings.Repeat("abcde", sz/5)
  156. if _, err := io.WriteString(client, clientStr); err != nil {
  157. t.Fatalf("writing client>server: %v", err)
  158. }
  159. serverGot := serverReads.String(sz)
  160. if serverGot != clientStr {
  161. t.Error("server didn't receive what client sent")
  162. }
  163. }
  164. func TestConnStd(t *testing.T) {
  165. // You can run this test manually, and noise.Conn should pass all
  166. // of them except for TestConn/PastTimeout,
  167. // TestConn/FutureTimeout, TestConn/ConcurrentMethods, because
  168. // those tests assume that write errors are recoverable, and
  169. // they're not on our Conn due to cipher security.
  170. t.Skip("not all tests can pass on this Conn, see https://github.com/golang/go/issues/46977")
  171. nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) {
  172. s1, s2 := memnet.NewConn("noise", 4096)
  173. controlKey := key.NewMachine()
  174. machineKey := key.NewMachine()
  175. serverErr := make(chan error, 1)
  176. go func() {
  177. var err error
  178. c2, err = Server(context.Background(), s2, controlKey, nil)
  179. serverErr <- err
  180. }()
  181. c1, err = Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion)
  182. if err != nil {
  183. s1.Close()
  184. s2.Close()
  185. return nil, nil, nil, fmt.Errorf("connecting client: %w", err)
  186. }
  187. if err := <-serverErr; err != nil {
  188. c1.Close()
  189. s1.Close()
  190. s2.Close()
  191. return nil, nil, nil, fmt.Errorf("connecting server: %w", err)
  192. }
  193. return c1, c2, func() {
  194. c1.Close()
  195. c2.Close()
  196. }, nil
  197. })
  198. }
  199. // tests that the idle memory overhead of a Conn blocked in a read is
  200. // reasonable (under 2K). It was previously over 8KB with two 4KB
  201. // buffers for rx/tx. This make sure we don't regress. Hopefully it
  202. // doesn't turn into a flaky test. If so, const max can be adjusted,
  203. // or it can be deleted or reworked.
  204. func TestConnMemoryOverhead(t *testing.T) {
  205. num := 1000
  206. if testing.Short() {
  207. num = 100
  208. }
  209. ng0 := runtime.NumGoroutine()
  210. runtime.GC()
  211. var ms0 runtime.MemStats
  212. runtime.ReadMemStats(&ms0)
  213. var closers []io.Closer
  214. closeAll := func() {
  215. for _, c := range closers {
  216. c.Close()
  217. }
  218. closers = nil
  219. }
  220. defer closeAll()
  221. for range num {
  222. client, server := pair(t)
  223. closers = append(closers, client, server)
  224. go func() {
  225. var buf [1]byte
  226. client.Read(buf[:])
  227. }()
  228. }
  229. t0 := time.Now()
  230. deadline := t0.Add(3 * time.Second)
  231. var ngo int
  232. for time.Now().Before(deadline) {
  233. runtime.GC()
  234. ngo = runtime.NumGoroutine()
  235. if ngo >= num {
  236. break
  237. }
  238. time.Sleep(10 * time.Millisecond)
  239. }
  240. if ngo < num {
  241. t.Fatalf("only %v goroutines; expected %v+", ngo, num)
  242. }
  243. runtime.GC()
  244. var ms runtime.MemStats
  245. runtime.ReadMemStats(&ms)
  246. growthTotal := int64(ms.HeapAlloc) - int64(ms0.HeapAlloc)
  247. growthEach := float64(growthTotal) / float64(num)
  248. t.Logf("Alloced %v bytes, %.2f B/each", growthTotal, growthEach)
  249. const max = 2048
  250. if growthEach > max {
  251. t.Errorf("allocated more than expected; want max %v bytes/each", max)
  252. }
  253. closeAll()
  254. // And make sure our goroutines go away too.
  255. deadline = time.Now().Add(3 * time.Second)
  256. for time.Now().Before(deadline) {
  257. ngo = runtime.NumGoroutine()
  258. if ngo < ng0+num/10 {
  259. break
  260. }
  261. time.Sleep(10 * time.Millisecond)
  262. }
  263. if ngo >= ng0+num/10 {
  264. t.Errorf("goroutines didn't go back down; started at %v, now %v", ng0, ngo)
  265. }
  266. }
  267. type readSink struct {
  268. r io.Reader
  269. cond *sync.Cond
  270. sync.Mutex
  271. bs bytes.Buffer
  272. err error
  273. }
  274. func sinkReads(r io.Reader) *readSink {
  275. ret := &readSink{
  276. r: r,
  277. }
  278. ret.cond = sync.NewCond(&ret.Mutex)
  279. go func() {
  280. var buf [4096]byte
  281. for {
  282. n, err := r.Read(buf[:])
  283. ret.Lock()
  284. ret.bs.Write(buf[:n])
  285. if err != nil {
  286. ret.err = err
  287. }
  288. ret.cond.Broadcast()
  289. ret.Unlock()
  290. if err != nil {
  291. return
  292. }
  293. }
  294. }()
  295. return ret
  296. }
  297. func (s *readSink) String(total int) string {
  298. s.Lock()
  299. defer s.Unlock()
  300. for s.bs.Len() < total && s.err == nil {
  301. s.cond.Wait()
  302. }
  303. if s.err != nil {
  304. total = s.bs.Len()
  305. }
  306. return string(s.bs.Bytes()[:total])
  307. }
  308. func (s *readSink) Error() error {
  309. s.Lock()
  310. defer s.Unlock()
  311. for s.err == nil {
  312. s.cond.Wait()
  313. }
  314. return s.err
  315. }
  316. func (s *readSink) Total() int {
  317. s.Lock()
  318. defer s.Unlock()
  319. return s.bs.Len()
  320. }
  321. func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn) {
  322. var (
  323. controlKey = key.NewMachine()
  324. machineKey = key.NewMachine()
  325. server *Conn
  326. serverErr = make(chan error, 1)
  327. )
  328. go func() {
  329. var err error
  330. server, err = Server(context.Background(), serverConn, controlKey, nil)
  331. serverErr <- err
  332. }()
  333. client, err := Client(context.Background(), clientConn, machineKey, controlKey.Public(), testProtocolVersion)
  334. if err != nil {
  335. t.Fatalf("client connection failed: %v", err)
  336. }
  337. if err := <-serverErr; err != nil {
  338. t.Fatalf("server connection failed: %v", err)
  339. }
  340. return client, server
  341. }
  342. func pair(t *testing.T) (*Conn, *Conn) {
  343. s1, s2 := memnet.NewConn("noise", 128000)
  344. return pairWithConns(t, s1, s2)
  345. }