interop_test.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package controlbase
  4. import (
  5. "context"
  6. "encoding/binary"
  7. "errors"
  8. "io"
  9. "net"
  10. "testing"
  11. "tailscale.com/net/memnet"
  12. "tailscale.com/types/key"
  13. )
  14. // Can a reference Noise IK client talk to our server?
  15. func TestInteropClient(t *testing.T) {
  16. var (
  17. s1, s2 = memnet.NewConn("noise", 128000)
  18. controlKey = key.NewMachine()
  19. machineKey = key.NewMachine()
  20. serverErr = make(chan error, 2)
  21. serverBytes = make(chan []byte, 1)
  22. c2s = "client>server"
  23. s2c = "server>client"
  24. )
  25. go func() {
  26. server, err := Server(context.Background(), s2, controlKey, nil)
  27. serverErr <- err
  28. if err != nil {
  29. return
  30. }
  31. var buf [1024]byte
  32. _, err = io.ReadFull(server, buf[:len(c2s)])
  33. serverBytes <- buf[:len(c2s)]
  34. if err != nil {
  35. serverErr <- err
  36. return
  37. }
  38. _, err = server.Write([]byte(s2c))
  39. serverErr <- err
  40. }()
  41. gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s))
  42. if err != nil {
  43. t.Fatalf("failed client interop: %v", err)
  44. }
  45. if string(gotS2C) != s2c {
  46. t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c)
  47. }
  48. if err := <-serverErr; err != nil {
  49. t.Fatalf("server handshake failed: %v", err)
  50. }
  51. if err := <-serverErr; err != nil {
  52. t.Fatalf("server read/write failed: %v", err)
  53. }
  54. if got := string(<-serverBytes); got != c2s {
  55. t.Fatalf("server received %q, want %q", got, c2s)
  56. }
  57. }
  58. // Can our client talk to a reference Noise IK server?
  59. func TestInteropServer(t *testing.T) {
  60. var (
  61. s1, s2 = memnet.NewConn("noise", 128000)
  62. controlKey = key.NewMachine()
  63. machineKey = key.NewMachine()
  64. clientErr = make(chan error, 2)
  65. clientBytes = make(chan []byte, 1)
  66. c2s = "client>server"
  67. s2c = "server>client"
  68. )
  69. go func() {
  70. client, err := Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion)
  71. clientErr <- err
  72. if err != nil {
  73. return
  74. }
  75. _, err = client.Write([]byte(c2s))
  76. if err != nil {
  77. clientErr <- err
  78. return
  79. }
  80. var buf [1024]byte
  81. _, err = io.ReadFull(client, buf[:len(s2c)])
  82. clientBytes <- buf[:len(s2c)]
  83. clientErr <- err
  84. }()
  85. gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c))
  86. if err != nil {
  87. t.Fatalf("failed server interop: %v", err)
  88. }
  89. if string(gotC2S) != c2s {
  90. t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s)
  91. }
  92. if err := <-clientErr; err != nil {
  93. t.Fatalf("client handshake failed: %v", err)
  94. }
  95. if err := <-clientErr; err != nil {
  96. t.Fatalf("client read/write failed: %v", err)
  97. }
  98. if got := string(<-clientBytes); got != s2c {
  99. t.Fatalf("client received %q, want %q", got, s2c)
  100. }
  101. }
  102. // noiseExplorerClient uses the Noise Explorer implementation of Noise
  103. // IK to handshake as a Noise client on conn, transmit payload, and
  104. // read+return a payload from the peer.
  105. func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey key.MachinePrivate, payload []byte) ([]byte, error) {
  106. var mk keypair
  107. copy(mk.private_key[:], machineKey.UntypedBytes())
  108. copy(mk.public_key[:], machineKey.Public().UntypedBytes())
  109. var peerKey [32]byte
  110. copy(peerKey[:], controlKey.UntypedBytes())
  111. session := InitSession(true, protocolVersionPrologue(testProtocolVersion), mk, peerKey)
  112. _, msg1 := SendMessage(&session, nil)
  113. var hdr [initiationHeaderLen]byte
  114. binary.BigEndian.PutUint16(hdr[:2], testProtocolVersion)
  115. hdr[2] = msgTypeInitiation
  116. binary.BigEndian.PutUint16(hdr[3:5], 96)
  117. if _, err := conn.Write(hdr[:]); err != nil {
  118. return nil, err
  119. }
  120. if _, err := conn.Write(msg1.ne[:]); err != nil {
  121. return nil, err
  122. }
  123. if _, err := conn.Write(msg1.ns); err != nil {
  124. return nil, err
  125. }
  126. if _, err := conn.Write(msg1.ciphertext); err != nil {
  127. return nil, err
  128. }
  129. var buf [1024]byte
  130. if _, err := io.ReadFull(conn, buf[:51]); err != nil {
  131. return nil, err
  132. }
  133. // ignore the header for this test, we're only checking the noise
  134. // implementation.
  135. msg2 := messagebuffer{
  136. ciphertext: buf[35:51],
  137. }
  138. copy(msg2.ne[:], buf[3:35])
  139. _, p, valid := RecvMessage(&session, &msg2)
  140. if !valid {
  141. return nil, errors.New("handshake failed")
  142. }
  143. if len(p) != 0 {
  144. return nil, errors.New("non-empty payload")
  145. }
  146. _, msg3 := SendMessage(&session, payload)
  147. hdr[0] = msgTypeRecord
  148. binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg3.ciphertext)))
  149. if _, err := conn.Write(hdr[:3]); err != nil {
  150. return nil, err
  151. }
  152. if _, err := conn.Write(msg3.ciphertext); err != nil {
  153. return nil, err
  154. }
  155. if _, err := io.ReadFull(conn, buf[:3]); err != nil {
  156. return nil, err
  157. }
  158. // Ignore all of the header except the payload length
  159. plen := int(binary.BigEndian.Uint16(buf[1:3]))
  160. if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
  161. return nil, err
  162. }
  163. msg4 := messagebuffer{
  164. ciphertext: buf[:plen],
  165. }
  166. _, p, valid = RecvMessage(&session, &msg4)
  167. if !valid {
  168. return nil, errors.New("transport message decryption failed")
  169. }
  170. return p, nil
  171. }
  172. func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachineKey key.MachinePublic, payload []byte) ([]byte, error) {
  173. var mk keypair
  174. copy(mk.private_key[:], controlKey.UntypedBytes())
  175. copy(mk.public_key[:], controlKey.Public().UntypedBytes())
  176. session := InitSession(false, protocolVersionPrologue(testProtocolVersion), mk, [32]byte{})
  177. var buf [1024]byte
  178. if _, err := io.ReadFull(conn, buf[:101]); err != nil {
  179. return nil, err
  180. }
  181. // Ignore the header, we're just checking the noise implementation.
  182. msg1 := messagebuffer{
  183. ns: buf[37:85],
  184. ciphertext: buf[85:101],
  185. }
  186. copy(msg1.ne[:], buf[5:37])
  187. _, p, valid := RecvMessage(&session, &msg1)
  188. if !valid {
  189. return nil, errors.New("handshake failed")
  190. }
  191. if len(p) != 0 {
  192. return nil, errors.New("non-empty payload")
  193. }
  194. _, msg2 := SendMessage(&session, nil)
  195. var hdr [headerLen]byte
  196. hdr[0] = msgTypeResponse
  197. binary.BigEndian.PutUint16(hdr[1:3], 48)
  198. if _, err := conn.Write(hdr[:]); err != nil {
  199. return nil, err
  200. }
  201. if _, err := conn.Write(msg2.ne[:]); err != nil {
  202. return nil, err
  203. }
  204. if _, err := conn.Write(msg2.ciphertext[:]); err != nil {
  205. return nil, err
  206. }
  207. if _, err := io.ReadFull(conn, buf[:3]); err != nil {
  208. return nil, err
  209. }
  210. plen := int(binary.BigEndian.Uint16(buf[1:3]))
  211. if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
  212. return nil, err
  213. }
  214. msg3 := messagebuffer{
  215. ciphertext: buf[:plen],
  216. }
  217. _, p, valid = RecvMessage(&session, &msg3)
  218. if !valid {
  219. return nil, errors.New("transport message decryption failed")
  220. }
  221. _, msg4 := SendMessage(&session, payload)
  222. hdr[0] = msgTypeRecord
  223. binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg4.ciphertext)))
  224. if _, err := conn.Write(hdr[:]); err != nil {
  225. return nil, err
  226. }
  227. if _, err := conn.Write(msg4.ciphertext); err != nil {
  228. return nil, err
  229. }
  230. return p, nil
  231. }