handshake_test.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package controlbase
  4. import (
  5. "bytes"
  6. "context"
  7. "io"
  8. "strings"
  9. "testing"
  10. "time"
  11. "tailscale.com/net/memnet"
  12. "tailscale.com/types/key"
  13. )
  14. func TestHandshake(t *testing.T) {
  15. var (
  16. clientConn, serverConn = memnet.NewConn("noise", 128000)
  17. serverKey = key.NewMachine()
  18. clientKey = key.NewMachine()
  19. server *Conn
  20. serverErr = make(chan error, 1)
  21. )
  22. go func() {
  23. var err error
  24. server, err = Server(context.Background(), serverConn, serverKey, nil)
  25. serverErr <- err
  26. }()
  27. client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
  28. if err != nil {
  29. t.Fatalf("client connection failed: %v", err)
  30. }
  31. if err := <-serverErr; err != nil {
  32. t.Fatalf("server connection failed: %v", err)
  33. }
  34. if client.HandshakeHash() != server.HandshakeHash() {
  35. t.Fatal("client and server disagree on handshake hash")
  36. }
  37. if client.ProtocolVersion() != int(testProtocolVersion) {
  38. t.Fatalf("client reporting wrong protocol version %d, want %d", client.ProtocolVersion(), testProtocolVersion)
  39. }
  40. if client.ProtocolVersion() != server.ProtocolVersion() {
  41. t.Fatalf("peers disagree on protocol version, client=%d server=%d", client.ProtocolVersion(), server.ProtocolVersion())
  42. }
  43. if client.Peer() != serverKey.Public() {
  44. t.Fatal("client peer key isn't serverKey")
  45. }
  46. if server.Peer() != clientKey.Public() {
  47. t.Fatal("client peer key isn't serverKey")
  48. }
  49. }
  50. // Check that handshaking repeatedly with the same long-term keys
  51. // result in different handshake hashes and wire traffic.
  52. func TestNoReuse(t *testing.T) {
  53. var (
  54. hashes = map[[32]byte]bool{}
  55. clientHandshakes = map[[96]byte]bool{}
  56. serverHandshakes = map[[48]byte]bool{}
  57. packets = map[[32]byte]bool{}
  58. )
  59. for i := 0; i < 10; i++ {
  60. var (
  61. clientRaw, serverRaw = memnet.NewConn("noise", 128000)
  62. clientBuf, serverBuf bytes.Buffer
  63. clientConn = &readerConn{clientRaw, io.TeeReader(clientRaw, &clientBuf)}
  64. serverConn = &readerConn{serverRaw, io.TeeReader(serverRaw, &serverBuf)}
  65. serverKey = key.NewMachine()
  66. clientKey = key.NewMachine()
  67. server *Conn
  68. serverErr = make(chan error, 1)
  69. )
  70. go func() {
  71. var err error
  72. server, err = Server(context.Background(), serverConn, serverKey, nil)
  73. serverErr <- err
  74. }()
  75. client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
  76. if err != nil {
  77. t.Fatalf("client connection failed: %v", err)
  78. }
  79. if err := <-serverErr; err != nil {
  80. t.Fatalf("server connection failed: %v", err)
  81. }
  82. var clientHS [96]byte
  83. copy(clientHS[:], serverBuf.Bytes())
  84. if clientHandshakes[clientHS] {
  85. t.Fatal("client handshake seen twice")
  86. }
  87. clientHandshakes[clientHS] = true
  88. var serverHS [48]byte
  89. copy(serverHS[:], clientBuf.Bytes())
  90. if serverHandshakes[serverHS] {
  91. t.Fatal("server handshake seen twice")
  92. }
  93. serverHandshakes[serverHS] = true
  94. clientBuf.Reset()
  95. serverBuf.Reset()
  96. cb := sinkReads(client)
  97. sb := sinkReads(server)
  98. if hashes[client.HandshakeHash()] {
  99. t.Fatalf("handshake hash %v seen twice", client.HandshakeHash())
  100. }
  101. hashes[client.HandshakeHash()] = true
  102. // Sending 14 bytes turns into 32 bytes on the wire (+16 for
  103. // the chacha20poly1305 overhead, +2 length header)
  104. if _, err := io.WriteString(client, strings.Repeat("a", 14)); err != nil {
  105. t.Fatalf("client>server write failed: %v", err)
  106. }
  107. if _, err := io.WriteString(server, strings.Repeat("b", 14)); err != nil {
  108. t.Fatalf("server>client write failed: %v", err)
  109. }
  110. // Wait for the bytes to be read, so we know they've traveled end to end
  111. cb.String(14)
  112. sb.String(14)
  113. var clientWire, serverWire [32]byte
  114. copy(clientWire[:], clientBuf.Bytes())
  115. copy(serverWire[:], serverBuf.Bytes())
  116. if packets[clientWire] {
  117. t.Fatalf("client wire traffic seen twice")
  118. }
  119. packets[clientWire] = true
  120. if packets[serverWire] {
  121. t.Fatalf("server wire traffic seen twice")
  122. }
  123. packets[serverWire] = true
  124. server.Close()
  125. client.Close()
  126. }
  127. }
  128. // tamperReader wraps a reader and mutates the Nth byte.
  129. type tamperReader struct {
  130. r io.Reader
  131. n int
  132. total int
  133. }
  134. func (r *tamperReader) Read(bs []byte) (int, error) {
  135. n, err := r.r.Read(bs)
  136. if off := r.n - r.total; off >= 0 && off < n {
  137. bs[off] += 1
  138. }
  139. r.total += n
  140. return n, err
  141. }
  142. func TestTampering(t *testing.T) {
  143. // Tamper with every byte of the client initiation message.
  144. for i := 0; i < 101; i++ {
  145. var (
  146. clientConn, serverRaw = memnet.NewConn("noise", 128000)
  147. serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, i, 0}}
  148. serverKey = key.NewMachine()
  149. clientKey = key.NewMachine()
  150. serverErr = make(chan error, 1)
  151. )
  152. go func() {
  153. _, err := Server(context.Background(), serverConn, serverKey, nil)
  154. // If the server failed, we have to close the Conn to
  155. // unblock the client.
  156. if err != nil {
  157. serverConn.Close()
  158. }
  159. serverErr <- err
  160. }()
  161. _, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
  162. if err == nil {
  163. t.Fatal("client connection succeeded despite tampering")
  164. }
  165. if err := <-serverErr; err == nil {
  166. t.Fatalf("server connection succeeded despite tampering")
  167. }
  168. }
  169. // Tamper with every byte of the server response message.
  170. for i := 0; i < 51; i++ {
  171. var (
  172. clientRaw, serverConn = memnet.NewConn("noise", 128000)
  173. clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, i, 0}}
  174. serverKey = key.NewMachine()
  175. clientKey = key.NewMachine()
  176. serverErr = make(chan error, 1)
  177. )
  178. go func() {
  179. _, err := Server(context.Background(), serverConn, serverKey, nil)
  180. serverErr <- err
  181. }()
  182. _, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
  183. if err == nil {
  184. t.Fatal("client connection succeeded despite tampering")
  185. }
  186. // The server shouldn't fail, because the tampering took place
  187. // in its response.
  188. if err := <-serverErr; err != nil {
  189. t.Fatalf("server connection failed despite no tampering: %v", err)
  190. }
  191. }
  192. // Tamper with every byte of the first server>client transport message.
  193. for i := 0; i < 30; i++ {
  194. var (
  195. clientRaw, serverConn = memnet.NewConn("noise", 128000)
  196. clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, 51 + i, 0}}
  197. serverKey = key.NewMachine()
  198. clientKey = key.NewMachine()
  199. serverErr = make(chan error, 1)
  200. )
  201. go func() {
  202. server, err := Server(context.Background(), serverConn, serverKey, nil)
  203. serverErr <- err
  204. _, err = io.WriteString(server, strings.Repeat("a", 14))
  205. serverErr <- err
  206. }()
  207. client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
  208. if err != nil {
  209. t.Fatalf("client handshake failed: %v", err)
  210. }
  211. // The server shouldn't fail, because the tampering took place
  212. // in its response.
  213. if err := <-serverErr; err != nil {
  214. t.Fatalf("server handshake failed: %v", err)
  215. }
  216. // The client needs a timeout if the tampering is hitting the length header.
  217. if i == 1 || i == 2 {
  218. client.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
  219. }
  220. var bs [100]byte
  221. n, err := client.Read(bs[:])
  222. if err == nil {
  223. t.Fatal("read succeeded despite tampering")
  224. }
  225. if n != 0 {
  226. t.Fatal("conn yielded some bytes despite tampering")
  227. }
  228. }
  229. // Tamper with every byte of the first client>server transport message.
  230. for i := 0; i < 30; i++ {
  231. var (
  232. clientConn, serverRaw = memnet.NewConn("noise", 128000)
  233. serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, 101 + i, 0}}
  234. serverKey = key.NewMachine()
  235. clientKey = key.NewMachine()
  236. serverErr = make(chan error, 1)
  237. )
  238. go func() {
  239. server, err := Server(context.Background(), serverConn, serverKey, nil)
  240. serverErr <- err
  241. var bs [100]byte
  242. // The server needs a timeout if the tampering is hitting the length header.
  243. if i == 1 || i == 2 {
  244. server.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
  245. }
  246. n, err := server.Read(bs[:])
  247. if n != 0 {
  248. panic("server got bytes despite tampering")
  249. } else {
  250. serverErr <- err
  251. }
  252. }()
  253. client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
  254. if err != nil {
  255. t.Fatalf("client handshake failed: %v", err)
  256. }
  257. if err := <-serverErr; err != nil {
  258. t.Fatalf("server handshake failed: %v", err)
  259. }
  260. if _, err := io.WriteString(client, strings.Repeat("a", 14)); err != nil {
  261. t.Fatalf("client>server write failed: %v", err)
  262. }
  263. if err := <-serverErr; err == nil {
  264. t.Fatal("server successfully received bytes despite tampering")
  265. }
  266. }
  267. }