// handshake.go (16 KB)
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package controlbase
  4. import (
  5. "context"
  6. "crypto/cipher"
  7. "encoding/binary"
  8. "errors"
  9. "fmt"
  10. "hash"
  11. "io"
  12. "net"
  13. "strconv"
  14. "time"
  15. "go4.org/mem"
  16. "golang.org/x/crypto/blake2s"
  17. chp "golang.org/x/crypto/chacha20poly1305"
  18. "golang.org/x/crypto/curve25519"
  19. "golang.org/x/crypto/hkdf"
  20. "tailscale.com/types/key"
  21. )
  const (
	// protocolName identifies the exact Noise protocol instance used by
	// the control protocol. Its spelling follows the Noise
	// specification's naming rules, so it must only change together with
	// a move to a different Noise instance.
	protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"

	// protocolVersionPrefix is the fixed part of the handshake prologue;
	// protocolVersionPrologue appends the decimal protocol version to it.
	//
	// Both peers mix the prologue into the handshake hash, and that hash
	// is the additional data for every handshake AEAD operation. If the
	// two sides disagree about the version advertised in the cleartext
	// message header, decryption fails and the handshake is rejected.
	protocolVersionPrefix = "Tailscale Control Protocol v"

	// invalidNonce is the largest uint64 value.
	// NOTE(review): not referenced in this file; presumably a sentinel
	// for the transport nonce counters defined elsewhere in the package —
	// confirm against the callers.
	invalidNonce uint64 = 1<<64 - 1
)
  41. func protocolVersionPrologue(version uint16) []byte {
  42. ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers.
  43. ret = append(ret, protocolVersionPrefix...)
  44. return strconv.AppendUint(ret, uint64(version), 10)
  45. }
  46. // HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn
  47. // is assumed to have already sent the client>server handshake
  48. // initiation message.
  49. type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error)
  50. // ClientDeferred initiates a control client handshake, returning the
  51. // initial message to send to the server and a continuation to
  52. // finalize the handshake.
  53. //
  54. // ClientDeferred is split in this way for RTT reduction: we run this
  55. // protocol after negotiating a protocol switch from HTTP/HTTPS. If we
  56. // completely serialized the negotiation followed by the handshake,
  57. // we'd pay an extra RTT to transmit the handshake initiation after
  58. // protocol switching. By splitting the handshake into an initial
  59. // message and a continuation, we can embed the handshake initiation
  60. // into the HTTP protocol switching request and avoid a bit of delay.
  61. func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
  62. var s symmetricState
  63. s.Initialize()
  64. // prologue
  65. s.MixHash(protocolVersionPrologue(protocolVersion))
  66. // <- s
  67. // ...
  68. s.MixHash(controlKey.UntypedBytes())
  69. // -> e, es, s, ss
  70. init := mkInitiationMessage(protocolVersion)
  71. machineEphemeral := key.NewMachine()
  72. machineEphemeralPub := machineEphemeral.Public()
  73. copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes())
  74. s.MixHash(machineEphemeralPub.UntypedBytes())
  75. cipher, err := s.MixDH(machineEphemeral, controlKey)
  76. if err != nil {
  77. return nil, nil, fmt.Errorf("computing es: %w", err)
  78. }
  79. machineKeyPub := machineKey.Public()
  80. s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes())
  81. cipher, err = s.MixDH(machineKey, controlKey)
  82. if err != nil {
  83. return nil, nil, fmt.Errorf("computing ss: %w", err)
  84. }
  85. s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload
  86. cont := func(ctx context.Context, conn net.Conn) (*Conn, error) {
  87. return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion)
  88. }
  89. return init[:], cont, nil
  90. }
  91. // Client wraps ClientDeferred and immediately invokes the returned
  92. // continuation with conn.
  93. //
  94. // This is a helper for when you don't need the fancy
  95. // continuation-style handshake, and just want to synchronously
  96. // upgrade a net.Conn to a secure transport.
  97. func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
  98. init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion)
  99. if err != nil {
  100. return nil, err
  101. }
  102. if _, err := conn.Write(init); err != nil {
  103. return nil, err
  104. }
  105. return cont(ctx, conn)
  106. }
  107. func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
  108. // No matter what, this function can only run once per s. Ensure
  109. // attempted reuse causes a panic.
  110. defer func() {
  111. s.finished = true
  112. }()
  113. if deadline, ok := ctx.Deadline(); ok {
  114. if err := conn.SetDeadline(deadline); err != nil {
  115. return nil, fmt.Errorf("setting conn deadline: %w", err)
  116. }
  117. defer func() {
  118. conn.SetDeadline(time.Time{})
  119. }()
  120. }
  121. // Read in the payload and look for errors/protocol violations from the server.
  122. var resp responseMessage
  123. if _, err := io.ReadFull(conn, resp.Header()); err != nil {
  124. return nil, fmt.Errorf("reading response header: %w", err)
  125. }
  126. if resp.Type() != msgTypeResponse {
  127. if resp.Type() != msgTypeError {
  128. return nil, fmt.Errorf("unexpected response message type %d", resp.Type())
  129. }
  130. msg := make([]byte, resp.Length())
  131. if _, err := io.ReadFull(conn, msg); err != nil {
  132. return nil, err
  133. }
  134. return nil, fmt.Errorf("server error: %q", msg)
  135. }
  136. if resp.Length() != len(resp.Payload()) {
  137. return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length())
  138. }
  139. if _, err := io.ReadFull(conn, resp.Payload()); err != nil {
  140. return nil, err
  141. }
  142. // <- e, ee, se
  143. controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub()))
  144. s.MixHash(controlEphemeralPub.UntypedBytes())
  145. if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
  146. return nil, fmt.Errorf("computing ee: %w", err)
  147. }
  148. cipher, err := s.MixDH(machineKey, controlEphemeralPub)
  149. if err != nil {
  150. return nil, fmt.Errorf("computing se: %w", err)
  151. }
  152. if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil {
  153. return nil, fmt.Errorf("decrypting payload: %w", err)
  154. }
  155. c1, c2, err := s.Split()
  156. if err != nil {
  157. return nil, fmt.Errorf("finalizing handshake: %w", err)
  158. }
  159. c := &Conn{
  160. conn: conn,
  161. version: protocolVersion,
  162. peer: controlKey,
  163. handshakeHash: s.h,
  164. tx: txState{
  165. cipher: c1,
  166. },
  167. rx: rxState{
  168. cipher: c2,
  169. },
  170. }
  171. return c, nil
  172. }
  173. // Server initiates a control server handshake, returning the resulting
  174. // control connection.
  175. //
  176. // optionalInit can be the client's initial handshake message as
  177. // returned by ClientDeferred, or nil in which case the initial
  178. // message is read from conn.
  179. //
  180. // The context deadline, if any, covers the entire handshaking
  181. // process.
  182. func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
  183. if deadline, ok := ctx.Deadline(); ok {
  184. if err := conn.SetDeadline(deadline); err != nil {
  185. return nil, fmt.Errorf("setting conn deadline: %w", err)
  186. }
  187. defer func() {
  188. conn.SetDeadline(time.Time{})
  189. }()
  190. }
  191. // Deliberately does not support formatting, so that we don't echo
  192. // attacker-controlled input back to them.
  193. sendErr := func(msg string) error {
  194. if len(msg) >= 1<<16 {
  195. msg = msg[:1<<16]
  196. }
  197. var hdr [headerLen]byte
  198. hdr[0] = msgTypeError
  199. binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg)))
  200. if _, err := conn.Write(hdr[:]); err != nil {
  201. return fmt.Errorf("sending %q error to client: %w", msg, err)
  202. }
  203. if _, err := io.WriteString(conn, msg); err != nil {
  204. return fmt.Errorf("sending %q error to client: %w", msg, err)
  205. }
  206. return fmt.Errorf("refused client handshake: %q", msg)
  207. }
  208. var s symmetricState
  209. s.Initialize()
  210. var init initiationMessage
  211. if optionalInit != nil {
  212. if len(optionalInit) != len(init) {
  213. return nil, sendErr("wrong handshake initiation size")
  214. }
  215. copy(init[:], optionalInit)
  216. } else if _, err := io.ReadFull(conn, init.Header()); err != nil {
  217. return nil, err
  218. }
  219. // Just a rename to make it more obvious what the value is. In the
  220. // current implementation we don't need to block any protocol
  221. // versions at this layer, it's safe to let the handshake proceed
  222. // and then let the caller make decisions based on the agreed-upon
  223. // protocol version.
  224. clientVersion := init.Version()
  225. if init.Type() != msgTypeInitiation {
  226. return nil, sendErr("unexpected handshake message type")
  227. }
  228. if init.Length() != len(init.Payload()) {
  229. return nil, sendErr("wrong handshake initiation length")
  230. }
  231. // if optionalInit was provided, we have the payload already.
  232. if optionalInit == nil {
  233. if _, err := io.ReadFull(conn, init.Payload()); err != nil {
  234. return nil, err
  235. }
  236. }
  237. // prologue. Can only do this once we at least think the client is
  238. // handshaking using a supported version.
  239. s.MixHash(protocolVersionPrologue(clientVersion))
  240. // <- s
  241. // ...
  242. controlKeyPub := controlKey.Public()
  243. s.MixHash(controlKeyPub.UntypedBytes())
  244. // -> e, es, s, ss
  245. machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub()))
  246. s.MixHash(machineEphemeralPub.UntypedBytes())
  247. cipher, err := s.MixDH(controlKey, machineEphemeralPub)
  248. if err != nil {
  249. return nil, fmt.Errorf("computing es: %w", err)
  250. }
  251. var machineKeyBytes [32]byte
  252. if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil {
  253. return nil, fmt.Errorf("decrypting machine key: %w", err)
  254. }
  255. machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:]))
  256. cipher, err = s.MixDH(controlKey, machineKey)
  257. if err != nil {
  258. return nil, fmt.Errorf("computing ss: %w", err)
  259. }
  260. if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil {
  261. return nil, fmt.Errorf("decrypting initiation tag: %w", err)
  262. }
  263. // <- e, ee, se
  264. resp := mkResponseMessage()
  265. controlEphemeral := key.NewMachine()
  266. controlEphemeralPub := controlEphemeral.Public()
  267. copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes())
  268. s.MixHash(controlEphemeralPub.UntypedBytes())
  269. if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
  270. return nil, fmt.Errorf("computing ee: %w", err)
  271. }
  272. cipher, err = s.MixDH(controlEphemeral, machineKey)
  273. if err != nil {
  274. return nil, fmt.Errorf("computing se: %w", err)
  275. }
  276. s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload
  277. c1, c2, err := s.Split()
  278. if err != nil {
  279. return nil, fmt.Errorf("finalizing handshake: %w", err)
  280. }
  281. if _, err := conn.Write(resp[:]); err != nil {
  282. return nil, err
  283. }
  284. c := &Conn{
  285. conn: conn,
  286. version: clientVersion,
  287. peer: machineKey,
  288. handshakeHash: s.h,
  289. tx: txState{
  290. cipher: c2,
  291. },
  292. rx: rxState{
  293. cipher: c1,
  294. },
  295. }
  296. return c, nil
  297. }
  298. // symmetricState contains the state of an in-flight handshake.
  299. type symmetricState struct {
  300. finished bool
  301. h [blake2s.Size]byte // hash of currently-processed handshake state
  302. ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake
  303. }
  304. func (s *symmetricState) checkFinished() {
  305. if s.finished {
  306. panic("attempted to use symmetricState after Split was called")
  307. }
  308. }
  309. // Initialize sets s to the initial handshake state, prior to
  310. // processing any handshake messages.
  311. func (s *symmetricState) Initialize() {
  312. s.checkFinished()
  313. s.h = blake2s.Sum256([]byte(protocolName))
  314. s.ck = s.h
  315. }
  316. // MixHash updates s.h to be BLAKE2s(s.h || data), where || is
  317. // concatenation.
  318. func (s *symmetricState) MixHash(data []byte) {
  319. s.checkFinished()
  320. h := newBLAKE2s()
  321. h.Write(s.h[:])
  322. h.Write(data)
  323. h.Sum(s.h[:0])
  324. }
  325. // MixDH updates s.ck with the result of X25519(priv, pub) and returns
  326. // a singleUseCHP that can be used to encrypt or decrypt handshake
  327. // data.
  328. //
  329. // MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing
  330. // it as a single function allows for strongly-typed arguments that
  331. // reduce the risk of error in the caller (e.g. invoking X25519 with
  332. // two private keys, or two public keys), and thus producing the wrong
  333. // calculation.
  334. func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) {
  335. s.checkFinished()
  336. keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes())
  337. if err != nil {
  338. return nil, fmt.Errorf("computing X25519: %w", err)
  339. }
  340. r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil)
  341. if _, err := io.ReadFull(r, s.ck[:]); err != nil {
  342. return nil, fmt.Errorf("extracting ck: %w", err)
  343. }
  344. var k [chp.KeySize]byte
  345. if _, err := io.ReadFull(r, k[:]); err != nil {
  346. return nil, fmt.Errorf("extracting k: %w", err)
  347. }
  348. return newSingleUseCHP(k), nil
  349. }
  350. // EncryptAndHash encrypts plaintext into ciphertext (which must be
  351. // the correct size to hold the encrypted plaintext) using cipher,
  352. // mixes the ciphertext into s.h, and returns the ciphertext.
  353. func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) {
  354. s.checkFinished()
  355. if len(ciphertext) != len(plaintext)+chp.Overhead {
  356. panic("ciphertext is wrong size for given plaintext")
  357. }
  358. ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:])
  359. s.MixHash(ret)
  360. }
  361. // DecryptAndHash decrypts the given ciphertext into plaintext (which
  362. // must be the correct size to hold the decrypted ciphertext) using
  363. // cipher. If decryption is successful, it mixes the ciphertext into
  364. // s.h.
  365. func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error {
  366. s.checkFinished()
  367. if len(ciphertext) != len(plaintext)+chp.Overhead {
  368. return errors.New("plaintext is wrong size for given ciphertext")
  369. }
  370. if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil {
  371. return err
  372. }
  373. s.MixHash(ciphertext)
  374. return nil
  375. }
  376. // Split returns two ChaCha20Poly1305 ciphers with keys derived from
  377. // the current handshake state. Methods on s cannot be used again
  378. // after calling Split.
  379. func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) {
  380. s.finished = true
  381. var k1, k2 [chp.KeySize]byte
  382. r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil)
  383. if _, err := io.ReadFull(r, k1[:]); err != nil {
  384. return nil, nil, fmt.Errorf("extracting k1: %w", err)
  385. }
  386. if _, err := io.ReadFull(r, k2[:]); err != nil {
  387. return nil, nil, fmt.Errorf("extracting k2: %w", err)
  388. }
  389. c1, err = chp.New(k1[:])
  390. if err != nil {
  391. return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err)
  392. }
  393. c2, err = chp.New(k2[:])
  394. if err != nil {
  395. return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err)
  396. }
  397. return c1, c2, nil
  398. }
  399. // newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on
  400. // error.
  401. func newBLAKE2s() hash.Hash {
  402. h, err := blake2s.New256(nil)
  403. if err != nil {
  404. // Should never happen, errors only happen when using BLAKE2s
  405. // in MAC mode with a key.
  406. panic(err)
  407. }
  408. return h
  409. }
  410. // newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or
  411. // panics on error.
  412. func newCHP(key [chp.KeySize]byte) cipher.AEAD {
  413. aead, err := chp.New(key[:])
  414. if err != nil {
  415. // Can only happen if we passed a key of the wrong length. The
  416. // function signature prevents that.
  417. panic(err)
  418. }
  419. return aead
  420. }
  // singleUseCHP wraps a ChaCha20-Poly1305 AEAD that may perform exactly
// one operation: a single Seal or a single Open, never both. That one
// operation always uses the all-zeros nonce, and any later Seal or Open
// call panics.
//
// A fixed nonce is acceptable only because each key is used once. This
// type enforces that at runtime.
type singleUseCHP struct {
	// c is the underlying AEAD; it is set to nil after its single use.
	c cipher.AEAD
}
  428. func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP {
  429. return &singleUseCHP{newCHP(key)}
  430. }
  431. func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte {
  432. if c.c == nil {
  433. panic("Attempted reuse of singleUseAEAD")
  434. }
  435. cipher := c.c
  436. c.c = nil
  437. var nonce [chp.NonceSize]byte
  438. return cipher.Seal(dst, nonce[:], plaintext, additionalData)
  439. }
  440. func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) {
  441. if c.c == nil {
  442. panic("Attempted reuse of singleUseAEAD")
  443. }
  444. cipher := c.c
  445. c.c = nil
  446. var nonce [chp.NonceSize]byte
  447. return cipher.Open(dst, nonce[:], ciphertext, additionalData)
  448. }