Explorar el Código

control/noise: implement the base transport for the 2021 control protocol.

Signed-off-by: David Anderson <[email protected]>
David Anderson hace 4 años
padre
commit
da7544bcc5

+ 330 - 0
control/noise/conn.go

@@ -0,0 +1,330 @@
+// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package noise implements the base transport of the Tailscale 2021
+// control protocol.
+//
+// The base transport implements Noise IK, instantiated with
+// Curve25519, ChaCha20Poly1305 and BLAKE2s.
+package noise
+
+import (
+	"crypto/cipher"
+	"encoding/binary"
+	"fmt"
+	"net"
+	"sync"
+	"time"
+
+	"golang.org/x/crypto/blake2s"
+	chp "golang.org/x/crypto/chacha20poly1305"
+	"golang.org/x/crypto/poly1305"
+	"tailscale.com/types/key"
+)
+
+const (
+	maxPlaintextSize  = 4096
+	maxCiphertextSize = maxPlaintextSize + poly1305.TagSize
+	maxPacketSize     = maxCiphertextSize + 2 // ciphertext + length header
+)
+
+// A Conn is a secured Noise connection. It implements the net.Conn
+// interface, with the unusual trait that any write error (including a
+// SetWriteDeadline induced i/o timeout) cause all future writes to
+// fail.
+type Conn struct {
+	conn          net.Conn
+	peer          key.Public
+	handshakeHash [blake2s.Size]byte
+	rx            rxState
+	tx            txState
+}
+
+// rxState is all the Conn state that Read uses.
+type rxState struct {
+	sync.Mutex
+	cipher    cipher.AEAD
+	nonce     [chp.NonceSize]byte
+	buf       [maxPacketSize]byte
+	n         int    // number of valid bytes in buf
+	next      int    // offset of next undecrypted packet
+	plaintext []byte // slice into buf of decrypted bytes
+}
+
+// txState is all the Conn state that Write uses.
+type txState struct {
+	sync.Mutex
+	cipher cipher.AEAD
+	nonce  [chp.NonceSize]byte
+	buf    [maxPacketSize]byte
+	err    error // records the first partial write error for all future calls
+}
+
+// HandshakeHash returns the Noise handshake hash for the connection,
+// which can be used to bind other messages to this connection
+// (i.e. to ensure that the message wasn't replayed from a different
+// connection).
+func (c *Conn) HandshakeHash() [blake2s.Size]byte {
+	return c.handshakeHash
+}
+
+// Peer returns the peer's long-term public key.
+func (c *Conn) Peer() key.Public {
+	return c.peer
+}
+
+// validNonce reports whether nonce is in the valid range for use: 0
+// through 2^64-2.
+func validNonce(nonce []byte) bool {
+	return binary.BigEndian.Uint32(nonce[:4]) == 0 && binary.BigEndian.Uint64(nonce[4:]) != invalidNonce
+}
+
+// readNLocked reads into c.rxBuf until rxBuf contains at least total
+// bytes. Returns a slice of the available bytes in rxBuf, or an
+// error if fewer than total bytes are available.
+func (c *Conn) readNLocked(total int) ([]byte, error) {
+	if total > maxPacketSize {
+		return nil, errReadTooBig{total}
+	}
+	for {
+		if total <= c.rx.n {
+			return c.rx.buf[:c.rx.n], nil
+		}
+
+		n, err := c.conn.Read(c.rx.buf[c.rx.n:])
+		c.rx.n += n
+		if err != nil {
+			return nil, err
+		}
+	}
+}
+
+// decryptLocked decrypts ciphertext in-place and sets c.rx.plaintext
+// to the decrypted bytes. Returns an error if the cipher is exhausted
+// (i.e. can no longer be used safely) or decryption fails.
+func (c *Conn) decryptLocked(ciphertext []byte) (err error) {
+	if !validNonce(c.rx.nonce[:]) {
+		return errCipherExhausted{}
+	}
+
+	c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil)
+
+	// Safe to increment the nonce here, because we checked for nonce
+	// wraparound above.
+	binary.BigEndian.PutUint64(c.rx.nonce[4:], 1+binary.BigEndian.Uint64(c.rx.nonce[4:]))
+
+	if err != nil {
+		// Once a decryption has failed, our Conn is no longer
+		// synchronized with our peer. Nuke the cipher state to be
+		// safe, so that no further decryptions are attempted.
+		c.rx.cipher = nil
+	}
+	return err
+}
+
+// encryptLocked encrypts plaintext into c.tx.buf (including the
+// 2-byte length header) and returns a slice of the ciphertext, or an
+// error if the cipher is exhausted (i.e. can no longer be used safely).
+func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
+	if !validNonce(c.tx.nonce[:]) {
+		// Received 2^64-1 messages on this cipher state. Connection
+		// is no longer usable.
+		return nil, errCipherExhausted{}
+	}
+
+	binary.BigEndian.PutUint16(c.tx.buf[:2], uint16(len(plaintext)+poly1305.TagSize))
+	ret := c.tx.cipher.Seal(c.tx.buf[:2], c.tx.nonce[:], plaintext, nil)
+
+	// Safe to increment the nonce here, because we checked for nonce
+	// wraparound above.
+	binary.BigEndian.PutUint64(c.tx.nonce[4:], 1+binary.BigEndian.Uint64(c.tx.nonce[4:]))
+
+	return ret, nil
+}
+
+// wholeCiphertextLocked returns a slice of one whole Noise frame from
+// c.rx.buf, if one whole ciphertext is available, and advances the
+// read state to the next Noise frame in the buffer. Returns nil
+// without advancing read state if there's not one whole ciphertext in
+// c.rx.buf.
+func (c *Conn) wholeCiphertextLocked() []byte {
+	available := c.rx.n - c.rx.next
+	if available < 2 {
+		return nil
+	}
+	bs := c.rx.buf[c.rx.next:c.rx.n]
+	totalSize := int(binary.BigEndian.Uint16(bs[:2])) + 2
+	if len(bs) < totalSize {
+		return nil
+	}
+	c.rx.next += totalSize
+	return bs[:totalSize]
+}
+
+// decryptOneLocked decrypts one Noise frame, reading from c.conn as needed,
+// and sets c.rx.plaintext to point to the decrypted
+// bytes. c.rx.plaintext is only valid if err == nil.
+func (c *Conn) decryptOneLocked() error {
+	c.rx.plaintext = nil
+
+	// Fast path: do we have one whole ciphertext frame buffered
+	// already?
+	if bs := c.wholeCiphertextLocked(); bs != nil {
+		return c.decryptLocked(bs[2:])
+	}
+
+	if c.rx.next != 0 {
+		// To simplify the read logic, move the remainder of the
+		// buffered bytes back to the head of the buffer, so we can
+		// grow it without worrying about wraparound.
+		copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n])
+		c.rx.n -= c.rx.next
+		c.rx.next = 0
+	}
+
+	bs, err := c.readNLocked(2)
+	if err != nil {
+		return err
+	}
+	totalLen := int(binary.BigEndian.Uint16(bs[:2])) + 2
+	bs, err = c.readNLocked(totalLen)
+	if err != nil {
+		return err
+	}
+
+	c.rx.next = totalLen
+	bs = bs[2:totalLen]
+
+	return c.decryptLocked(bs)
+}
+
+// Read implements io.Reader.
+func (c *Conn) Read(bs []byte) (int, error) {
+	c.rx.Lock()
+	defer c.rx.Unlock()
+
+	if c.rx.cipher == nil {
+		return 0, net.ErrClosed
+	}
+	// Loop to handle receiving a zero-byte Noise message. Just skip
+	// over it and keep decrypting until we find some bytes.
+	for len(c.rx.plaintext) == 0 {
+		if err := c.decryptOneLocked(); err != nil {
+			return 0, err
+		}
+	}
+	n := copy(bs, c.rx.plaintext)
+	c.rx.plaintext = c.rx.plaintext[n:]
+	return n, nil
+}
+
+// Write implements io.Writer.
+func (c *Conn) Write(bs []byte) (n int, err error) {
+	c.tx.Lock()
+	defer c.tx.Unlock()
+
+	if c.tx.err != nil {
+		return 0, c.tx.err
+	}
+	defer func() {
+		if err != nil {
+			// All write errors are fatal for this conn, so clear the
+			// cipher state whenever an error happens.
+			c.tx.cipher = nil
+		}
+		if c.tx.err == nil {
+			// Only set c.tx.err if not nil so that we can return one
+			// error on the first failure, and a different one for
+			// subsequent calls. See the error handling around Write
+			// below for why.
+			c.tx.err = err
+		}
+	}()
+
+	if c.tx.cipher == nil {
+		return 0, net.ErrClosed
+	}
+
+	var sent int
+	for len(bs) > 0 {
+		toSend := bs
+		if len(toSend) > maxPlaintextSize {
+			toSend = bs[:maxPlaintextSize]
+		}
+		bs = bs[len(toSend):]
+
+		ciphertext, err := c.encryptLocked(toSend)
+		if err != nil {
+			return 0, err
+		}
+
+		if n, err := c.conn.Write(ciphertext); err != nil {
+			sent += n
+			// Return the raw error on the Write that actually
+			// failed. For future writes, return that error wrapped in
+			// a desync error.
+			c.tx.err = errPartialWrite{err}
+			return sent, err
+		}
+		sent += len(toSend)
+	}
+	return sent, nil
+}
+
+// Close implements io.Closer.
+func (c *Conn) Close() error {
+	closeErr := c.conn.Close() // unblocks any waiting reads or writes
+	c.rx.Lock()
+	c.rx.cipher = nil
+	c.rx.Unlock()
+	c.tx.Lock()
+	c.tx.cipher = nil
+	c.tx.Unlock()
+	return closeErr
+}
+
+func (c *Conn) LocalAddr() net.Addr                { return c.conn.LocalAddr() }
+func (c *Conn) RemoteAddr() net.Addr               { return c.conn.RemoteAddr() }
+func (c *Conn) SetDeadline(t time.Time) error      { return c.conn.SetDeadline(t) }
+func (c *Conn) SetReadDeadline(t time.Time) error  { return c.conn.SetReadDeadline(t) }
+func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
+
+// errCipherExhausted is the error returned when we run out of nonces
+// on a cipher.
+type errCipherExhausted struct{}
+
+func (errCipherExhausted) Error() string {
+	return "cipher exhausted, no more nonces available for current key"
+}
+func (errCipherExhausted) Timeout() bool   { return false }
+func (errCipherExhausted) Temporary() bool { return false }
+
+// errPartialWrite is the error returned when the cipher state has
+// become unusable due to a past partial write.
+type errPartialWrite struct {
+	err error
+}
+
+func (e errPartialWrite) Error() string {
+	return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err)
+}
+func (e errPartialWrite) Unwrap() error   { return e.err }
+func (e errPartialWrite) Temporary() bool { return false }
+func (e errPartialWrite) Timeout() bool   { return false }
+
+// errReadTooBig is the error returned when the peer sent an
+// unacceptably large Noise frame.
+type errReadTooBig struct {
+	requested int
+}
+
+func (e errReadTooBig) Error() string {
+	return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested)
+}
+func (e errReadTooBig) Temporary() bool {
+	// permanent error because this error only occurs when our peer
+	// sends us a frame so large we're unwilling to ever decode it.
+	return false
+}
+func (e errReadTooBig) Timeout() bool { return false }

+ 339 - 0
control/noise/conn_test.go

@@ -0,0 +1,339 @@
+// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package noise
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"crypto/rand"
+	"encoding/binary"
+	"fmt"
+	"io"
+	"net"
+	"strings"
+	"sync"
+	"testing"
+	"testing/iotest"
+
+	chp "golang.org/x/crypto/chacha20poly1305"
+	"golang.org/x/net/nettest"
+	tsnettest "tailscale.com/net/nettest"
+	"tailscale.com/types/key"
+)
+
+func TestMessageSize(t *testing.T) {
+	// This test is a regression guard against someone looking at
+	// maxCiphertextSize, going "huh, we could be more efficient if it
+	// were larger, and accidentally violating the Noise spec. Do not
+	// change this max value, it's a deliberate limitation of the
+	// cryptographic protocol we use (see Section 3 "Message Format"
+	// of the Noise spec).
+	const max = 65535
+	if maxCiphertextSize > max {
+		t.Fatalf("max ciphertext size is %d, which is larger than the maximum noise message size %d", maxCiphertextSize, max)
+	}
+}
+
+func TestConnBasic(t *testing.T) {
+	client, server := pair(t)
+
+	sb := sinkReads(server)
+
+	want := "test"
+	if _, err := io.WriteString(client, want); err != nil {
+		t.Fatalf("client write failed: %v", err)
+	}
+	client.Close()
+
+	if got := sb.String(4); got != want {
+		t.Fatalf("wrong content received: got %q, want %q", got, want)
+	}
+	if err := sb.Error(); err != io.EOF {
+		t.Fatal("client close wasn't seen by server")
+	}
+	if sb.Total() != 4 {
+		t.Fatalf("wrong amount of bytes received: got %d, want 4", sb.Total())
+	}
+}
+
+// bufferedWriteConn wraps a net.Conn and gives control over how
+// Writes get batched out.
+type bufferedWriteConn struct {
+	net.Conn
+	w           *bufio.Writer
+	manualFlush bool
+}
+
+func (c *bufferedWriteConn) Write(bs []byte) (int, error) {
+	n, err := c.w.Write(bs)
+	if err == nil && !c.manualFlush {
+		err = c.w.Flush()
+	}
+	return n, err
+}
+
+// TestFastPath exercises the Read codepath that can receive multiple
+// Noise frames at once and decode each in turn without making another
+// syscall.
+func TestFastPath(t *testing.T) {
+	s1, s2 := tsnettest.NewConn("noise", 128000)
+	b := &bufferedWriteConn{s1, bufio.NewWriterSize(s1, 10000), false}
+	client, server := pairWithConns(t, b, s2)
+
+	b.manualFlush = true
+
+	sb := sinkReads(server)
+
+	const packets = 10
+	s := "test"
+	for i := 0; i < packets; i++ {
+		// Many separate writes, to force separate Noise frames that
+		// all get buffered up and then all sent as a single slice to
+		// the server.
+		if _, err := io.WriteString(client, s); err != nil {
+			t.Fatalf("client write1 failed: %v", err)
+		}
+	}
+	if err := b.w.Flush(); err != nil {
+		t.Fatalf("client flush failed: %v", err)
+	}
+	client.Close()
+
+	want := strings.Repeat(s, packets)
+	if got := sb.String(len(want)); got != want {
+		t.Fatalf("wrong content received: got %q, want %q", got, want)
+	}
+	if err := sb.Error(); err != io.EOF {
+		t.Fatalf("client close wasn't seen by server")
+	}
+}
+
+// Writes things larger than a single Noise frame, to check the
+// chunking on the encoder and decoder.
+func TestBigData(t *testing.T) {
+	client, server := pair(t)
+
+	serverReads := sinkReads(server)
+	clientReads := sinkReads(client)
+
+	const sz = 15 * 1024 // 15KiB
+	clientStr := strings.Repeat("abcde", sz/5)
+	serverStr := strings.Repeat("fghij", sz/5*2)
+
+	if _, err := io.WriteString(client, clientStr); err != nil {
+		t.Fatalf("writing client>server: %v", err)
+	}
+	if _, err := io.WriteString(server, serverStr); err != nil {
+		t.Fatalf("writing server>client: %v", err)
+	}
+
+	if serverGot := serverReads.String(sz); serverGot != clientStr {
+		t.Error("server didn't receive what client sent")
+	}
+	if clientGot := clientReads.String(2 * sz); clientGot != serverStr {
+		t.Error("client didn't receive what server sent")
+	}
+
+	getNonce := func(n [chp.NonceSize]byte) uint64 {
+		if binary.BigEndian.Uint32(n[:4]) != 0 {
+			panic("unexpected nonce")
+		}
+		return binary.BigEndian.Uint64(n[4:])
+	}
+
+	// Reach into the Conns and verify the cipher nonces advanced as
+	// expected.
+	if getNonce(client.tx.nonce) != getNonce(server.rx.nonce) {
+		t.Error("desynchronized client tx nonce")
+	}
+	if getNonce(server.tx.nonce) != getNonce(client.rx.nonce) {
+		t.Error("desynchronized server tx nonce")
+	}
+	if n := getNonce(client.tx.nonce); n != 4 {
+		t.Errorf("wrong client tx nonce, got %d want 4", n)
+	}
+	if n := getNonce(server.tx.nonce); n != 8 {
+		t.Errorf("wrong client tx nonce, got %d want 8", n)
+	}
+}
+
+// readerConn wraps a net.Conn and routes its Reads through a separate
+// io.Reader.
+type readerConn struct {
+	net.Conn
+	r io.Reader
+}
+
+func (c readerConn) Read(bs []byte) (int, error) { return c.r.Read(bs) }
+
+// Check that the receiver can handle not being able to read an entire
+// frame in a single syscall.
+func TestDataTrickle(t *testing.T) {
+	s1, s2 := tsnettest.NewConn("noise", 128000)
+	client, server := pairWithConns(t, s1, readerConn{s2, iotest.OneByteReader(s2)})
+	serverReads := sinkReads(server)
+
+	const sz = 10000
+	clientStr := strings.Repeat("abcde", sz/5)
+	if _, err := io.WriteString(client, clientStr); err != nil {
+		t.Fatalf("writing client>server: %v", err)
+	}
+
+	serverGot := serverReads.String(sz)
+	if serverGot != clientStr {
+		t.Error("server didn't receive what client sent")
+	}
+}
+
+func TestConnStd(t *testing.T) {
+	// You can run this test manually, and noise.Conn should pass all
+	// of them except for TestConn/PastTimeout,
+	// TestConn/FutureTimeout, TestConn/ConcurrentMethods, because
+	// those tests assume that write errors are recoverable, and
+	// they're not on our Conn due to cipher security.
+	t.Skip("not all tests can pass on this Conn, see https://github.com/golang/go/issues/46977")
+	nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) {
+		s1, s2 := tsnettest.NewConn("noise", 4096)
+		controlKey := key.NewPrivate()
+		machineKey := key.NewPrivate()
+		serverErr := make(chan error, 1)
+		go func() {
+			var err error
+			c2, err = Server(context.Background(), s2, controlKey)
+			serverErr <- err
+		}()
+		c1, err = Client(context.Background(), s1, machineKey, controlKey.Public())
+		if err != nil {
+			s1.Close()
+			s2.Close()
+			return nil, nil, nil, fmt.Errorf("connecting client: %w", err)
+		}
+		if err := <-serverErr; err != nil {
+			c1.Close()
+			s1.Close()
+			s2.Close()
+			return nil, nil, nil, fmt.Errorf("connecting server: %w", err)
+		}
+		return c1, c2, func() {
+			c1.Close()
+			c2.Close()
+		}, nil
+	})
+}
+
+// mkConns creates synthetic Noise Conns wrapping the given net.Conns.
+// This function is for testing just the Conn transport logic without
+// having to muck about with Noise handshakes.
+func mkConns(s1, s2 net.Conn) (*Conn, *Conn) {
+	var k1, k2 [chp.KeySize]byte
+	if _, err := rand.Read(k1[:]); err != nil {
+		panic(err)
+	}
+	if _, err := rand.Read(k2[:]); err != nil {
+		panic(err)
+	}
+
+	ret1 := &Conn{
+		conn: s1,
+		tx:   txState{cipher: newCHP(k1)},
+		rx:   rxState{cipher: newCHP(k2)},
+	}
+	ret2 := &Conn{
+		conn: s2,
+		tx:   txState{cipher: newCHP(k2)},
+		rx:   rxState{cipher: newCHP(k1)},
+	}
+
+	return ret1, ret2
+}
+
+type readSink struct {
+	r io.Reader
+
+	cond *sync.Cond
+	sync.Mutex
+	bs  bytes.Buffer
+	err error
+}
+
+func sinkReads(r io.Reader) *readSink {
+	ret := &readSink{
+		r: r,
+	}
+	ret.cond = sync.NewCond(&ret.Mutex)
+	go func() {
+		var buf [4096]byte
+		for {
+			n, err := r.Read(buf[:])
+			ret.Lock()
+			ret.bs.Write(buf[:n])
+			if err != nil {
+				ret.err = err
+			}
+			ret.cond.Broadcast()
+			ret.Unlock()
+			if err != nil {
+				return
+			}
+		}
+	}()
+	return ret
+}
+
+func (s *readSink) String(total int) string {
+	s.Lock()
+	defer s.Unlock()
+	for s.bs.Len() < total && s.err == nil {
+		s.cond.Wait()
+	}
+	if s.err != nil {
+		total = s.bs.Len()
+	}
+	return string(s.bs.Bytes()[:total])
+}
+
+func (s *readSink) Error() error {
+	s.Lock()
+	defer s.Unlock()
+	for s.err == nil {
+		s.cond.Wait()
+	}
+	return s.err
+}
+
+func (s *readSink) Total() int {
+	s.Lock()
+	defer s.Unlock()
+	return s.bs.Len()
+}
+
+func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn) {
+	var (
+		controlKey = key.NewPrivate()
+		machineKey = key.NewPrivate()
+		server     *Conn
+		serverErr  = make(chan error, 1)
+	)
+	go func() {
+		var err error
+		server, err = Server(context.Background(), serverConn, controlKey)
+		serverErr <- err
+	}()
+
+	client, err := Client(context.Background(), clientConn, machineKey, controlKey.Public())
+	if err != nil {
+		t.Fatalf("client connection failed: %v", err)
+	}
+	if err := <-serverErr; err != nil {
+		t.Fatalf("server connection failed: %v", err)
+	}
+	return client, server
+}
+
+func pair(t *testing.T) (*Conn, *Conn) {
+	s1, s2 := tsnettest.NewConn("noise", 128000)
+	return pairWithConns(t, s1, s2)
+}

+ 361 - 0
control/noise/handshake.go

@@ -0,0 +1,361 @@
+// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package noise
+
+import (
+	"context"
+	"crypto/cipher"
+	"encoding/binary"
+	"fmt"
+	"hash"
+	"io"
+	"net"
+	"time"
+
+	"golang.org/x/crypto/blake2s"
+	chp "golang.org/x/crypto/chacha20poly1305"
+	"golang.org/x/crypto/curve25519"
+	"golang.org/x/crypto/hkdf"
+	"tailscale.com/types/key"
+)
+
+const (
+	protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
+	invalidNonce = ^uint64(0)
+)
+
+// Client initiates a Noise client handshake, returning the resulting
+// Noise connection.
+//
+// The context deadline, if any, covers the entire handshaking
+// process.
+func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlKey key.Public) (*Conn, error) {
+	if deadline, ok := ctx.Deadline(); ok {
+		if err := conn.SetDeadline(deadline); err != nil {
+			return nil, fmt.Errorf("setting conn deadline: %w", err)
+		}
+		defer func() {
+			conn.SetDeadline(time.Time{})
+		}()
+	}
+
+	var s symmetricState
+	s.Initialize()
+
+	// <- s
+	// ...
+	s.MixHash(controlKey[:])
+
+	var init initiationMessage
+	// -> e, es, s, ss
+	machineEphemeral := key.NewPrivate()
+	machineEphemeralPub := machineEphemeral.Public()
+	copy(init.MachineEphemeralPub(), machineEphemeralPub[:])
+	s.MixHash(machineEphemeralPub[:])
+	if err := s.MixDH(machineEphemeral, controlKey); err != nil {
+		return nil, fmt.Errorf("computing es: %w", err)
+	}
+	machineKeyPub := machineKey.Public()
+	copy(init.MachinePub(), s.EncryptAndHash(machineKeyPub[:]))
+	if err := s.MixDH(machineKey, controlKey); err != nil {
+		return nil, fmt.Errorf("computing ss: %w", err)
+	}
+	copy(init.Tag(), s.EncryptAndHash(nil)) // empty message payload
+
+	if _, err := conn.Write(init[:]); err != nil {
+		return nil, fmt.Errorf("writing initiation: %w", err)
+	}
+
+	// <- e, ee, se
+	var resp responseMessage
+	if _, err := io.ReadFull(conn, resp[:]); err != nil {
+		return nil, fmt.Errorf("reading response: %w", err)
+	}
+
+	var controlEphemeralPub key.Public
+	copy(controlEphemeralPub[:], resp.ControlEphemeralPub())
+	s.MixHash(controlEphemeralPub[:])
+	if err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
+		return nil, fmt.Errorf("computing ee: %w", err)
+	}
+	if err := s.MixDH(machineKey, controlEphemeralPub); err != nil {
+		return nil, fmt.Errorf("computing se: %w", err)
+	}
+	if _, err := s.DecryptAndHash(resp.Tag()); err != nil {
+		return nil, fmt.Errorf("decrypting payload: %w", err)
+	}
+
+	c1, c2, err := s.Split()
+	if err != nil {
+		return nil, fmt.Errorf("finalizing handshake: %w", err)
+	}
+
+	return &Conn{
+		conn:          conn,
+		peer:          controlKey,
+		handshakeHash: s.h,
+		tx: txState{
+			cipher: c1,
+		},
+		rx: rxState{
+			cipher: c2,
+		},
+	}, nil
+}
+
+// Server initiates a Noise server handshake, returning the resulting
+// Noise connection.
+//
+// The context deadline, if any, covers the entire handshaking
+// process.
+func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn, error) {
+	if deadline, ok := ctx.Deadline(); ok {
+		if err := conn.SetDeadline(deadline); err != nil {
+			return nil, fmt.Errorf("setting conn deadline: %w", err)
+		}
+		defer func() {
+			conn.SetDeadline(time.Time{})
+		}()
+	}
+
+	var s symmetricState
+	s.Initialize()
+
+	// <- s
+	// ...
+	controlKeyPub := controlKey.Public()
+	s.MixHash(controlKeyPub[:])
+
+	// -> e, es, s, ss
+	var init initiationMessage
+	if _, err := io.ReadFull(conn, init[:]); err != nil {
+		return nil, fmt.Errorf("reading initiation: %w", err)
+	}
+
+	var machineEphemeralPub key.Public
+	copy(machineEphemeralPub[:], init.MachineEphemeralPub())
+	s.MixHash(machineEphemeralPub[:])
+	if err := s.MixDH(controlKey, machineEphemeralPub); err != nil {
+		return nil, fmt.Errorf("computing es: %w", err)
+	}
+	var machineKey key.Public
+	rs, err := s.DecryptAndHash(init.MachinePub())
+	if err != nil {
+		return nil, fmt.Errorf("decrypting machine key: %w", err)
+	}
+	copy(machineKey[:], rs)
+	if err := s.MixDH(controlKey, machineKey); err != nil {
+		return nil, fmt.Errorf("computing ss: %w", err)
+	}
+	if _, err := s.DecryptAndHash(init.Tag()); err != nil {
+		return nil, fmt.Errorf("decrypting initiation tag: %w", err)
+	}
+
+	// <- e, ee, se
+	var resp responseMessage
+	controlEphemeral := key.NewPrivate()
+	controlEphemeralPub := controlEphemeral.Public()
+	copy(resp.ControlEphemeralPub(), controlEphemeralPub[:])
+	s.MixHash(controlEphemeralPub[:])
+	if err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
+		return nil, fmt.Errorf("computing ee: %w", err)
+	}
+	if err := s.MixDH(controlEphemeral, machineKey); err != nil {
+		return nil, fmt.Errorf("computing se: %w", err)
+	}
+	copy(resp.Tag(), s.EncryptAndHash(nil)) // empty message payload
+
+	c1, c2, err := s.Split()
+	if err != nil {
+		return nil, fmt.Errorf("finalizing handshake: %w", err)
+	}
+
+	if _, err := conn.Write(resp[:]); err != nil {
+		return nil, err
+	}
+
+	return &Conn{
+		conn:          conn,
+		peer:          machineKey,
+		handshakeHash: s.h,
+		tx: txState{
+			cipher: c2,
+		},
+		rx: rxState{
+			cipher: c1,
+		},
+	}, nil
+}
+
+// initiationMessage is the Noise protocol message sent from a client
+// machine to a control server.
+type initiationMessage [96]byte
+
+func (m *initiationMessage) MachineEphemeralPub() []byte { return m[:32] }
+func (m *initiationMessage) MachinePub() []byte          { return m[32:80] }
+func (m *initiationMessage) Tag() []byte                 { return m[80:] }
+
+// responseMessage is the Noise protocol message sent from a control
+// server to a client machine.
+type responseMessage [48]byte
+
+func (m *responseMessage) ControlEphemeralPub() []byte { return m[:32] }
+func (m *responseMessage) Tag() []byte                 { return m[32:] }
+
+// symmetricState is the SymmetricState object from the Noise protocol
+// spec. It contains all the symmetric cipher state of an in-flight
+// handshake. Field names match the variable names in the spec.
+type symmetricState struct {
+	h  [blake2s.Size]byte
+	ck [blake2s.Size]byte
+
+	k [chp.KeySize]byte
+	n uint64
+
+	mixer hash.Hash // for updating h
+}
+
+// Initialize sets s to the initial handshake state, prior to
+// processing any Noise messages.
+func (s *symmetricState) Initialize() {
+	if s.mixer != nil {
+		panic("symmetricState cannot be reused")
+	}
+	s.h = blake2s.Sum256([]byte(protocolName))
+	s.ck = s.h
+	s.k = [chp.KeySize]byte{}
+	s.n = invalidNonce
+	s.mixer = newBLAKE2s()
+	// Mix in an empty prologue.
+	s.MixHash(nil)
+}
+
+// MixHash updates s.h to be BLAKE2s(s.h || data), where || is
+// concatenation.
+func (s *symmetricState) MixHash(data []byte) {
+	s.mixer.Reset()
+	s.mixer.Write(s.h[:])
+	s.mixer.Write(data)
+	s.mixer.Sum(s.h[:0]) // TODO: check this actually updates s.h correctly...
+}
+
+// MixDH updates s.ck and s.k with the result of X25519(priv, pub).
+//
+// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing
+// it as a single function allows for strongly-typed arguments that
+// reduce the risk of error in the caller (e.g. invoking X25519 with
+// two private keys, or two public keys), and thus producing the wrong
+// calculation.
+func (s *symmetricState) MixDH(priv key.Private, pub key.Public) error {
+	// TODO(danderson): check that this operation is correct. The docs
+	// for X25519 say that the 2nd arg must be either Basepoint or the
+	// output of another X25519 call.
+	//
+	// I think this is correct, because pub is the result of a
+	// ScalarBaseMult on the private key, and our private key
+	// generation code clamps keys to avoid low order points. I
+	// believe that makes pub equivalent to the output of
+	// X25519(privateKey, Basepoint), and so the contract is
+	// respected.
+	keyData, err := curve25519.X25519(priv[:], pub[:])
+	if err != nil {
+		return fmt.Errorf("computing X25519: %w", err)
+	}
+
+	r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil)
+	if _, err := io.ReadFull(r, s.ck[:]); err != nil {
+		return fmt.Errorf("extracting ck: %w", err)
+	}
+	if _, err := io.ReadFull(r, s.k[:]); err != nil {
+		return fmt.Errorf("extracting k: %w", err)
+	}
+	s.n = 0
+	return nil
+}
+
+// EncryptAndHash encrypts the given plaintext using the current s.k,
+// mixes the ciphertext into s.h, and returns the ciphertext.
+func (s *symmetricState) EncryptAndHash(plaintext []byte) []byte {
+	if s.n == invalidNonce {
+		// Noise in general permits writing "ciphertext" without a
+		// key, but in IK it cannot happen.
+		panic("attempted encryption with uninitialized key")
+	}
+	aead := newCHP(s.k)
+	var nonce [chp.NonceSize]byte
+	binary.BigEndian.PutUint64(nonce[4:], s.n)
+	s.n++
+	ret := aead.Seal(nil, nonce[:], plaintext, s.h[:])
+	s.MixHash(ret)
+	return ret
+}
+
+// DecryptAndHash decrypts the given ciphertext using the current
+// s.k. If decryption is successful, it mixes the ciphertext into s.h
+// and returns the plaintext.
+func (s *symmetricState) DecryptAndHash(ciphertext []byte) ([]byte, error) {
+	if s.n == invalidNonce {
+		// Noise in general permits "ciphertext" without a key, but in
+		// IK it cannot happen.
+		panic("attempted encryption with uninitialized key")
+	}
+	aead := newCHP(s.k)
+	var nonce [chp.NonceSize]byte
+	binary.BigEndian.PutUint64(nonce[4:], s.n)
+	s.n++
+	ret, err := aead.Open(nil, nonce[:], ciphertext, s.h[:])
+	if err != nil {
+		return nil, err
+	}
+	s.MixHash(ciphertext)
+	return ret, nil
+}
+
+// Split returns two ChaCha20Poly1305 ciphers with keys derives from
+// the current handshake state. Methods on s must not be used again
+// after calling Split().
+func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) {
+	var k1, k2 [chp.KeySize]byte
+	r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil)
+	if _, err := io.ReadFull(r, k1[:]); err != nil {
+		return nil, nil, fmt.Errorf("extracting k1: %w", err)
+	}
+	if _, err := io.ReadFull(r, k2[:]); err != nil {
+		return nil, nil, fmt.Errorf("extracting k2: %w", err)
+	}
+	c1, err = chp.New(k1[:])
+	if err != nil {
+		return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err)
+	}
+	c2, err = chp.New(k2[:])
+	if err != nil {
+		return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err)
+	}
+	return c1, c2, nil
+}
+
+// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on
+// error.
+func newBLAKE2s() hash.Hash {
+	h, err := blake2s.New256(nil)
+	if err != nil {
+		// Should never happen, errors only happen when using BLAKE2s
+		// in MAC mode with a key.
+		panic(fmt.Sprintf("blake2s construction: %v", err))
+	}
+	return h
+}
+
+// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or
+// panics on error.
+func newCHP(key [chp.KeySize]byte) cipher.AEAD {
+	aead, err := chp.New(key[:])
+	if err != nil {
+		// Can only happen if we passed a key of the wrong length. The
+		// function signature prevents that.
+		panic(fmt.Sprintf("chacha20poly1305 construction: %v", err))
+	}
+	return aead
+}

+ 290 - 0
control/noise/handshake_test.go

@@ -0,0 +1,290 @@
+// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package noise
+
+import (
+	"bytes"
+	"context"
+	"io"
+	"strings"
+	"testing"
+	"time"
+
+	tsnettest "tailscale.com/net/nettest"
+	"tailscale.com/types/key"
+)
+
+func TestHandshake(t *testing.T) {
+	var (
+		clientConn, serverConn = tsnettest.NewConn("noise", 128000)
+		serverKey              = key.NewPrivate()
+		clientKey              = key.NewPrivate()
+		server                 *Conn
+		serverErr              = make(chan error, 1)
+	)
+	go func() {
+		var err error
+		server, err = Server(context.Background(), serverConn, serverKey)
+		serverErr <- err
+	}()
+
+	client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
+	if err != nil {
+		t.Fatalf("client connection failed: %v", err)
+	}
+	if err := <-serverErr; err != nil {
+		t.Fatalf("server connection failed: %v", err)
+	}
+
+	if client.HandshakeHash() != server.HandshakeHash() {
+		t.Fatal("client and server disagree on handshake hash")
+	}
+
+	if client.Peer() != serverKey.Public() {
+		t.Fatal("client peer key isn't serverKey")
+	}
+	if server.Peer() != clientKey.Public() {
+		t.Fatal("client peer key isn't serverKey")
+	}
+}
+
+// Check that handshaking repeatedly with the same long-term keys
+// result in different handshake hashes and wire traffic.
+func TestNoReuse(t *testing.T) {
+	var (
+		hashes           = map[[32]byte]bool{}
+		clientHandshakes = map[[96]byte]bool{}
+		serverHandshakes = map[[48]byte]bool{}
+		packets          = map[[32]byte]bool{}
+	)
+	for i := 0; i < 10; i++ {
+		var (
+			clientRaw, serverRaw = tsnettest.NewConn("noise", 128000)
+			clientBuf, serverBuf bytes.Buffer
+			clientConn           = &readerConn{clientRaw, io.TeeReader(clientRaw, &clientBuf)}
+			serverConn           = &readerConn{serverRaw, io.TeeReader(serverRaw, &serverBuf)}
+			serverKey            = key.NewPrivate()
+			clientKey            = key.NewPrivate()
+			server               *Conn
+			serverErr            = make(chan error, 1)
+		)
+		go func() {
+			var err error
+			server, err = Server(context.Background(), serverConn, serverKey)
+			serverErr <- err
+		}()
+
+		client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
+		if err != nil {
+			t.Fatalf("client connection failed: %v", err)
+		}
+		if err := <-serverErr; err != nil {
+			t.Fatalf("server connection failed: %v", err)
+		}
+
+		var clientHS [96]byte
+		copy(clientHS[:], serverBuf.Bytes())
+		if clientHandshakes[clientHS] {
+			t.Fatal("client handshake seen twice")
+		}
+		clientHandshakes[clientHS] = true
+
+		var serverHS [48]byte
+		copy(serverHS[:], clientBuf.Bytes())
+		if serverHandshakes[serverHS] {
+			t.Fatal("server handshake seen twice")
+		}
+		serverHandshakes[serverHS] = true
+
+		clientBuf.Reset()
+		serverBuf.Reset()
+		cb := sinkReads(client)
+		sb := sinkReads(server)
+
+		if hashes[client.HandshakeHash()] {
+			t.Fatalf("handshake hash %v seen twice", client.HandshakeHash())
+		}
+		hashes[client.HandshakeHash()] = true
+
+		// Sending 14 bytes turns into 32 bytes on the wire (+16 for
+		// the poly1305 tag, +2 length header)
+		if _, err := io.WriteString(client, strings.Repeat("a", 14)); err != nil {
+			t.Fatalf("client>server write failed: %v", err)
+		}
+		if _, err := io.WriteString(server, strings.Repeat("b", 14)); err != nil {
+			t.Fatalf("server>client write failed: %v", err)
+		}
+
+		// Wait for the bytes to be read, so we know they've traveled end to end
+		cb.String(14)
+		sb.String(14)
+
+		var clientWire, serverWire [32]byte
+		copy(clientWire[:], clientBuf.Bytes())
+		copy(serverWire[:], serverBuf.Bytes())
+
+		if packets[clientWire] {
+			t.Fatalf("client wire traffic seen twice")
+		}
+		packets[clientWire] = true
+		if packets[serverWire] {
+			t.Fatalf("server wire traffic seen twice")
+		}
+		packets[serverWire] = true
+	}
+}
+
+// tamperReader wraps a reader and mutates the Nth byte.
+type tamperReader struct {
+	r     io.Reader
+	n     int
+	total int
+}
+
+func (r *tamperReader) Read(bs []byte) (int, error) {
+	n, err := r.r.Read(bs)
+	if off := r.n - r.total; off >= 0 && off < n {
+		bs[off] += 1
+	}
+	r.total += n
+	return n, err
+}
+
+func TestTampering(t *testing.T) {
+	// Tamper with every byte of the client initiation message.
+	for i := 0; i < 96; i++ {
+		var (
+			clientConn, serverRaw = tsnettest.NewConn("noise", 128000)
+			serverConn            = &readerConn{serverRaw, &tamperReader{serverRaw, i, 0}}
+			serverKey             = key.NewPrivate()
+			clientKey             = key.NewPrivate()
+			serverErr             = make(chan error, 1)
+		)
+		go func() {
+			_, err := Server(context.Background(), serverConn, serverKey)
+			// If the server failed, we have to close the Conn to
+			// unblock the client.
+			if err != nil {
+				serverConn.Close()
+			}
+			serverErr <- err
+		}()
+
+		_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
+		if err == nil {
+			t.Fatal("client connection succeeded despite tampering")
+		}
+		if err := <-serverErr; err == nil {
+			t.Fatalf("server connection succeeded despite tampering")
+		}
+	}
+
+	// Tamper with every byte of the server response message.
+	for i := 0; i < 48; i++ {
+		var (
+			clientRaw, serverConn = tsnettest.NewConn("noise", 128000)
+			clientConn            = &readerConn{clientRaw, &tamperReader{clientRaw, i, 0}}
+			serverKey             = key.NewPrivate()
+			clientKey             = key.NewPrivate()
+			serverErr             = make(chan error, 1)
+		)
+		go func() {
+			_, err := Server(context.Background(), serverConn, serverKey)
+			serverErr <- err
+		}()
+
+		_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
+		if err == nil {
+			t.Fatal("client connection succeeded despite tampering")
+		}
+		// The server shouldn't fail, because the tampering took place
+		// in its response.
+		if err := <-serverErr; err != nil {
+			t.Fatalf("server connection failed despite no tampering: %v", err)
+		}
+	}
+
+	// Tamper with every byte of the first server>client transport message.
+	for i := 0; i < 32; i++ {
+		var (
+			clientRaw, serverConn = tsnettest.NewConn("noise", 128000)
+			clientConn            = &readerConn{clientRaw, &tamperReader{clientRaw, 48 + i, 0}}
+			serverKey             = key.NewPrivate()
+			clientKey             = key.NewPrivate()
+			serverErr             = make(chan error, 1)
+		)
+		go func() {
+			server, err := Server(context.Background(), serverConn, serverKey)
+			serverErr <- err
+			_, err = io.WriteString(server, strings.Repeat("a", 14))
+			serverErr <- err
+		}()
+
+		client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
+		if err != nil {
+			t.Fatalf("client handshake failed: %v", err)
+		}
+		// The server shouldn't fail, because the tampering took place
+		// in its response.
+		if err := <-serverErr; err != nil {
+			t.Fatalf("server handshake failed: %v", err)
+		}
+
+		// The client needs a timeout if the tampering is hitting the length header.
+		if i == 0 || i == 1 {
+			client.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
+		}
+
+		var bs [100]byte
+		n, err := client.Read(bs[:])
+		if err == nil {
+			t.Fatal("read succeeded despite tampering")
+		}
+		if n != 0 {
+			t.Fatal("conn yielded some bytes despite tampering")
+		}
+	}
+
+	// Tamper with every byte of the first client>server transport message.
+	for i := 0; i < 32; i++ {
+		var (
+			clientConn, serverRaw = tsnettest.NewConn("noise", 128000)
+			serverConn            = &readerConn{serverRaw, &tamperReader{serverRaw, 96 + i, 0}}
+			serverKey             = key.NewPrivate()
+			clientKey             = key.NewPrivate()
+			serverErr             = make(chan error, 1)
+		)
+		go func() {
+			server, err := Server(context.Background(), serverConn, serverKey)
+			serverErr <- err
+			var bs [100]byte
+			// The server needs a timeout if the tampering is hitting the length header.
+			if i == 0 || i == 1 {
+				server.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
+			}
+			n, err := server.Read(bs[:])
+			if n != 0 {
+				panic("server got bytes despite tampering")
+			} else {
+				serverErr <- err
+			}
+		}()
+
+		client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
+		if err != nil {
+			t.Fatalf("client handshake failed: %v", err)
+		}
+		if err := <-serverErr; err != nil {
+			t.Fatalf("server handshake failed: %v", err)
+		}
+
+		if _, err := io.WriteString(client, strings.Repeat("a", 14)); err != nil {
+			t.Fatalf("client>server write failed: %v", err)
+		}
+		if err := <-serverErr; err == nil {
+			t.Fatal("server successfully received bytes despite tampering")
+		}
+	}
+}

+ 238 - 0
control/noise/interop_test.go

@@ -0,0 +1,238 @@
+// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package noise
+
+import (
+	"context"
+	"encoding/binary"
+	"errors"
+	"io"
+	"net"
+	"testing"
+
+	tsnettest "tailscale.com/net/nettest"
+	"tailscale.com/types/key"
+)
+
+// Can a reference Noise IK client talk to our server?
+func TestInteropClient(t *testing.T) {
+	var (
+		s1, s2      = tsnettest.NewConn("noise", 128000)
+		controlKey  = key.NewPrivate()
+		machineKey  = key.NewPrivate()
+		serverErr   = make(chan error, 2)
+		serverBytes = make(chan []byte, 1)
+		c2s         = "client>server"
+		s2c         = "server>client"
+	)
+
+	go func() {
+		server, err := Server(context.Background(), s2, controlKey)
+		serverErr <- err
+		if err != nil {
+			return
+		}
+		var buf [1024]byte
+		_, err = io.ReadFull(server, buf[:len(c2s)])
+		serverBytes <- buf[:len(c2s)]
+		if err != nil {
+			serverErr <- err
+			return
+		}
+		_, err = server.Write([]byte(s2c))
+		serverErr <- err
+	}()
+
+	gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s))
+	if err != nil {
+		t.Fatalf("failed client interop: %v", err)
+	}
+	if string(gotS2C) != s2c {
+		t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c)
+	}
+
+	if err := <-serverErr; err != nil {
+		t.Fatalf("server handshake failed: %v", err)
+	}
+	if err := <-serverErr; err != nil {
+		t.Fatalf("server read/write failed: %v", err)
+	}
+	if got := string(<-serverBytes); got != c2s {
+		t.Fatalf("server received %q, want %q", got, c2s)
+	}
+}
+
+// Can our client talk to a reference Noise IK server?
+func TestInteropServer(t *testing.T) {
+	var (
+		s1, s2      = tsnettest.NewConn("noise", 128000)
+		controlKey  = key.NewPrivate()
+		machineKey  = key.NewPrivate()
+		clientErr   = make(chan error, 2)
+		clientBytes = make(chan []byte, 1)
+		c2s         = "client>server"
+		s2c         = "server>client"
+	)
+
+	go func() {
+		client, err := Client(context.Background(), s1, machineKey, controlKey.Public())
+		clientErr <- err
+		if err != nil {
+			return
+		}
+		_, err = client.Write([]byte(c2s))
+		if err != nil {
+			clientErr <- err
+			return
+		}
+		var buf [1024]byte
+		_, err = io.ReadFull(client, buf[:len(s2c)])
+		clientBytes <- buf[:len(s2c)]
+		clientErr <- err
+	}()
+
+	gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c))
+	if err != nil {
+		t.Fatalf("failed server interop: %v", err)
+	}
+	if string(gotC2S) != c2s {
+		t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s)
+	}
+
+	if err := <-clientErr; err != nil {
+		t.Fatalf("client handshake failed: %v", err)
+	}
+	if err := <-clientErr; err != nil {
+		t.Fatalf("client read/write failed: %v", err)
+	}
+	if got := string(<-clientBytes); got != s2c {
+		t.Fatalf("client received %q, want %q", got, s2c)
+	}
+}
+
+// noiseExplorerClient uses the Noise Explorer implementation of Noise
+// IK to handshake as a Noise client on conn, transmit payload, and
+// read+return a payload from the peer.
+func noiseExplorerClient(conn net.Conn, controlKey key.Public, machineKey key.Private, payload []byte) ([]byte, error) {
+	mk := keypair{
+		private_key: machineKey,
+		public_key:  machineKey.Public(),
+	}
+	session := InitSession(true, nil, mk, controlKey)
+
+	_, msg1 := SendMessage(&session, nil)
+	if _, err := conn.Write(msg1.ne[:]); err != nil {
+		return nil, err
+	}
+	if _, err := conn.Write(msg1.ns); err != nil {
+		return nil, err
+	}
+	if _, err := conn.Write(msg1.ciphertext); err != nil {
+		return nil, err
+	}
+
+	var buf [1024]byte
+	if _, err := io.ReadFull(conn, buf[:48]); err != nil {
+		return nil, err
+	}
+	msg2 := messagebuffer{
+		ciphertext: buf[32:48],
+	}
+	copy(msg2.ne[:], buf[:32])
+	_, p, valid := RecvMessage(&session, &msg2)
+	if !valid {
+		return nil, errors.New("handshake failed")
+	}
+	if len(p) != 0 {
+		return nil, errors.New("non-empty payload")
+	}
+
+	_, msg3 := SendMessage(&session, payload)
+	binary.BigEndian.PutUint16(buf[:2], uint16(len(msg3.ciphertext)))
+	if _, err := conn.Write(buf[:2]); err != nil {
+		return nil, err
+	}
+	if _, err := conn.Write(msg3.ciphertext); err != nil {
+		return nil, err
+	}
+
+	if _, err := io.ReadFull(conn, buf[:2]); err != nil {
+		return nil, err
+	}
+	plen := int(binary.BigEndian.Uint16(buf[:2]))
+	if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
+		return nil, err
+	}
+
+	msg4 := messagebuffer{
+		ciphertext: buf[:plen],
+	}
+	_, p, valid = RecvMessage(&session, &msg4)
+	if !valid {
+		return nil, errors.New("transport message decryption failed")
+	}
+
+	return p, nil
+}
+
+func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey key.Public, payload []byte) ([]byte, error) {
+	mk := keypair{
+		private_key: controlKey,
+		public_key:  controlKey.Public(),
+	}
+	session := InitSession(false, nil, mk, [32]byte{})
+
+	var buf [1024]byte
+	if _, err := io.ReadFull(conn, buf[:96]); err != nil {
+		return nil, err
+	}
+	msg1 := messagebuffer{
+		ns:         buf[32:80],
+		ciphertext: buf[80:96],
+	}
+	copy(msg1.ne[:], buf[:32])
+	_, p, valid := RecvMessage(&session, &msg1)
+	if !valid {
+		return nil, errors.New("handshake failed")
+	}
+	if len(p) != 0 {
+		return nil, errors.New("non-empty payload")
+	}
+
+	_, msg2 := SendMessage(&session, nil)
+	if _, err := conn.Write(msg2.ne[:]); err != nil {
+		return nil, err
+	}
+	if _, err := conn.Write(msg2.ciphertext[:]); err != nil {
+		return nil, err
+	}
+
+	if _, err := io.ReadFull(conn, buf[:2]); err != nil {
+		return nil, err
+	}
+	plen := int(binary.BigEndian.Uint16(buf[:2]))
+	if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
+		return nil, err
+	}
+
+	msg3 := messagebuffer{
+		ciphertext: buf[:plen],
+	}
+	_, p, valid = RecvMessage(&session, &msg3)
+	if !valid {
+		return nil, errors.New("transport message decryption failed")
+	}
+
+	_, msg4 := SendMessage(&session, payload)
+	binary.BigEndian.PutUint16(buf[:2], uint16(len(msg4.ciphertext)))
+	if _, err := conn.Write(buf[:2]); err != nil {
+		return nil, err
+	}
+	if _, err := conn.Write(msg4.ciphertext); err != nil {
+		return nil, err
+	}
+
+	return p, nil
+}

+ 475 - 0
control/noise/noiseexplorer_test.go

@@ -0,0 +1,475 @@
+// This file contains the implementation of Noise IK from
+// https://noiseexplorer.com/ . Unlike the rest of this repository,
+// this file is licensed under the terms of the GNU GPL v3. See
+// https://source.symbolic.software/noiseexplorer/noiseexplorer for
+// more information.
+//
+// This file is used here to verify that Tailscale's implementation of
+// Noise IK is interoperable with another implementation.
+//lint:file-ignore SA4006 not our code.
+
+/*
+IK:
+  <- s
+  ...
+  -> e, es, s, ss
+  <- e, ee, se
+  ->
+  <-
+*/
+
+// Implementation Version: 1.0.2
+
+/* ---------------------------------------------------------------- *
+ * PARAMETERS                                                       *
+ * ---------------------------------------------------------------- */
+
+package noise
+
+import (
+	"crypto/rand"
+	"crypto/subtle"
+	"encoding/binary"
+	"hash"
+	"io"
+	"math"
+
+	"golang.org/x/crypto/blake2s"
+	"golang.org/x/crypto/chacha20poly1305"
+	"golang.org/x/crypto/curve25519"
+	"golang.org/x/crypto/hkdf"
+)
+
+/* ---------------------------------------------------------------- *
+ * TYPES                                                            *
+ * ---------------------------------------------------------------- */
+
+type keypair struct {
+	public_key  [32]byte
+	private_key [32]byte
+}
+
+type messagebuffer struct {
+	ne         [32]byte
+	ns         []byte
+	ciphertext []byte
+}
+
+type cipherstate struct {
+	k [32]byte
+	n uint32
+}
+
+type symmetricstate struct {
+	cs cipherstate
+	ck [32]byte
+	h  [32]byte
+}
+
+type handshakestate struct {
+	ss  symmetricstate
+	s   keypair
+	e   keypair
+	rs  [32]byte
+	re  [32]byte
+	psk [32]byte
+}
+
+type noisesession struct {
+	hs  handshakestate
+	h   [32]byte
+	cs1 cipherstate
+	cs2 cipherstate
+	mc  uint64
+	i   bool
+}
+
+/* ---------------------------------------------------------------- *
+ * CONSTANTS                                                        *
+ * ---------------------------------------------------------------- */
+
+var emptyKey = [32]byte{
+	0x00, 0x00, 0x00, 0x00,
+	0x00, 0x00, 0x00, 0x00,
+	0x00, 0x00, 0x00, 0x00,
+	0x00, 0x00, 0x00, 0x00,
+	0x00, 0x00, 0x00, 0x00,
+	0x00, 0x00, 0x00, 0x00,
+	0x00, 0x00, 0x00, 0x00,
+	0x00, 0x00, 0x00, 0x00,
+}
+
+var minNonce = uint32(0)
+
+/* ---------------------------------------------------------------- *
+ * UTILITY FUNCTIONS                                                *
+ * ---------------------------------------------------------------- */
+
+func getPublicKey(kp *keypair) [32]byte {
+	return kp.public_key
+}
+
+func isEmptyKey(k [32]byte) bool {
+	return subtle.ConstantTimeCompare(k[:], emptyKey[:]) == 1
+}
+
+func validatePublicKey(k []byte) bool {
+	forbiddenCurveValues := [12][]byte{
+		{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+		{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+		{224, 235, 122, 124, 59, 65, 184, 174, 22, 86, 227, 250, 241, 159, 196, 106, 218, 9, 141, 235, 156, 50, 177, 253, 134, 98, 5, 22, 95, 73, 184, 0},
+		{95, 156, 149, 188, 163, 80, 140, 36, 177, 208, 177, 85, 156, 131, 239, 91, 4, 68, 92, 196, 88, 28, 142, 134, 216, 34, 78, 221, 208, 159, 17, 87},
+		{236, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127},
+		{237, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127},
+		{238, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127},
+		{205, 235, 122, 124, 59, 65, 184, 174, 22, 86, 227, 250, 241, 159, 196, 106, 218, 9, 141, 235, 156, 50, 177, 253, 134, 98, 5, 22, 95, 73, 184, 128},
+		{76, 156, 149, 188, 163, 80, 140, 36, 177, 208, 177, 85, 156, 131, 239, 91, 4, 68, 92, 196, 88, 28, 142, 134, 216, 34, 78, 221, 208, 159, 17, 215},
+		{217, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255},
+		{218, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255},
+		{219, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 25},
+	}
+
+	for _, testValue := range forbiddenCurveValues {
+		if subtle.ConstantTimeCompare(k[:], testValue[:]) == 1 {
+			panic("Invalid public key")
+		}
+	}
+	return true
+}
+
+/* ---------------------------------------------------------------- *
+ * PRIMITIVES                                                       *
+ * ---------------------------------------------------------------- */
+
+func incrementNonce(n uint32) uint32 {
+	return n + 1
+}
+
+func dh(private_key [32]byte, public_key [32]byte) [32]byte {
+	var ss [32]byte
+	curve25519.ScalarMult(&ss, &private_key, &public_key)
+	return ss
+}
+
+func generateKeypair() keypair {
+	var public_key [32]byte
+	var private_key [32]byte
+	_, _ = rand.Read(private_key[:])
+	curve25519.ScalarBaseMult(&public_key, &private_key)
+	if validatePublicKey(public_key[:]) {
+		return keypair{public_key, private_key}
+	}
+	return generateKeypair()
+}
+
+func generatePublicKey(private_key [32]byte) [32]byte {
+	var public_key [32]byte
+	curve25519.ScalarBaseMult(&public_key, &private_key)
+	return public_key
+}
+
+func encrypt(k [32]byte, n uint32, ad []byte, plaintext []byte) []byte {
+	var nonce [12]byte
+	var ciphertext []byte
+	enc, _ := chacha20poly1305.New(k[:])
+	binary.LittleEndian.PutUint32(nonce[4:], n)
+	ciphertext = enc.Seal(nil, nonce[:], plaintext, ad)
+	return ciphertext
+}
+
+func decrypt(k [32]byte, n uint32, ad []byte, ciphertext []byte) (bool, []byte, []byte) {
+	var nonce [12]byte
+	var plaintext []byte
+	enc, err := chacha20poly1305.New(k[:])
+	binary.LittleEndian.PutUint32(nonce[4:], n)
+	plaintext, err = enc.Open(nil, nonce[:], ciphertext, ad)
+	return (err == nil), ad, plaintext
+}
+
+func getHash(a []byte, b []byte) [32]byte {
+	return blake2s.Sum256(append(a, b...))
+}
+
+func hashProtocolName(protocolName []byte) [32]byte {
+	var h [32]byte
+	if len(protocolName) <= 32 {
+		copy(h[:], protocolName)
+	} else {
+		h = getHash(protocolName, []byte{})
+	}
+	return h
+}
+
+func blake2HkdfInterface() hash.Hash {
+	h, _ := blake2s.New256([]byte{})
+	return h
+}
+
+func getHkdf(ck [32]byte, ikm []byte) ([32]byte, [32]byte, [32]byte) {
+	var k1 [32]byte
+	var k2 [32]byte
+	var k3 [32]byte
+	output := hkdf.New(blake2HkdfInterface, ikm[:], ck[:], []byte{})
+	io.ReadFull(output, k1[:])
+	io.ReadFull(output, k2[:])
+	io.ReadFull(output, k3[:])
+	return k1, k2, k3
+}
+
+/* ---------------------------------------------------------------- *
+ * STATE MANAGEMENT                                                 *
+ * ---------------------------------------------------------------- */
+
+/* CipherState */
+func initializeKey(k [32]byte) cipherstate {
+	return cipherstate{k, minNonce}
+}
+
+func hasKey(cs *cipherstate) bool {
+	return !isEmptyKey(cs.k)
+}
+
+func setNonce(cs *cipherstate, newNonce uint32) *cipherstate {
+	cs.n = newNonce
+	return cs
+}
+
+func encryptWithAd(cs *cipherstate, ad []byte, plaintext []byte) (*cipherstate, []byte) {
+	e := encrypt(cs.k, cs.n, ad, plaintext)
+	cs = setNonce(cs, incrementNonce(cs.n))
+	return cs, e
+}
+
+func decryptWithAd(cs *cipherstate, ad []byte, ciphertext []byte) (*cipherstate, []byte, bool) {
+	valid, ad, plaintext := decrypt(cs.k, cs.n, ad, ciphertext)
+	cs = setNonce(cs, incrementNonce(cs.n))
+	return cs, plaintext, valid
+}
+
+func reKey(cs *cipherstate) *cipherstate {
+	e := encrypt(cs.k, math.MaxUint32, []byte{}, emptyKey[:])
+	copy(cs.k[:], e)
+	return cs
+}
+
+/* SymmetricState */
+
+func initializeSymmetric(protocolName []byte) symmetricstate {
+	h := hashProtocolName(protocolName)
+	ck := h
+	cs := initializeKey(emptyKey)
+	return symmetricstate{cs, ck, h}
+}
+
+func mixKey(ss *symmetricstate, ikm [32]byte) *symmetricstate {
+	ck, tempK, _ := getHkdf(ss.ck, ikm[:])
+	ss.cs = initializeKey(tempK)
+	ss.ck = ck
+	return ss
+}
+
+func mixHash(ss *symmetricstate, data []byte) *symmetricstate {
+	ss.h = getHash(ss.h[:], data)
+	return ss
+}
+
+func mixKeyAndHash(ss *symmetricstate, ikm [32]byte) *symmetricstate {
+	var tempH [32]byte
+	var tempK [32]byte
+	ss.ck, tempH, tempK = getHkdf(ss.ck, ikm[:])
+	ss = mixHash(ss, tempH[:])
+	ss.cs = initializeKey(tempK)
+	return ss
+}
+
+func getHandshakeHash(ss *symmetricstate) [32]byte {
+	return ss.h
+}
+
+func encryptAndHash(ss *symmetricstate, plaintext []byte) (*symmetricstate, []byte) {
+	var ciphertext []byte
+	if hasKey(&ss.cs) {
+		_, ciphertext = encryptWithAd(&ss.cs, ss.h[:], plaintext)
+	} else {
+		ciphertext = plaintext
+	}
+	ss = mixHash(ss, ciphertext)
+	return ss, ciphertext
+}
+
+func decryptAndHash(ss *symmetricstate, ciphertext []byte) (*symmetricstate, []byte, bool) {
+	var plaintext []byte
+	var valid bool
+	if hasKey(&ss.cs) {
+		_, plaintext, valid = decryptWithAd(&ss.cs, ss.h[:], ciphertext)
+	} else {
+		plaintext, valid = ciphertext, true
+	}
+	ss = mixHash(ss, ciphertext)
+	return ss, plaintext, valid
+}
+
+func split(ss *symmetricstate) (cipherstate, cipherstate) {
+	tempK1, tempK2, _ := getHkdf(ss.ck, []byte{})
+	cs1 := initializeKey(tempK1)
+	cs2 := initializeKey(tempK2)
+	return cs1, cs2
+}
+
+/* HandshakeState */
+
+func initializeInitiator(prologue []byte, s keypair, rs [32]byte, psk [32]byte) handshakestate {
+	var ss symmetricstate
+	var e keypair
+	var re [32]byte
+	name := []byte("Noise_IK_25519_ChaChaPoly_BLAKE2s")
+	ss = initializeSymmetric(name)
+	mixHash(&ss, prologue)
+	mixHash(&ss, rs[:])
+	return handshakestate{ss, s, e, rs, re, psk}
+}
+
+func initializeResponder(prologue []byte, s keypair, rs [32]byte, psk [32]byte) handshakestate {
+	var ss symmetricstate
+	var e keypair
+	var re [32]byte
+	name := []byte("Noise_IK_25519_ChaChaPoly_BLAKE2s")
+	ss = initializeSymmetric(name)
+	mixHash(&ss, prologue)
+	mixHash(&ss, s.public_key[:])
+	return handshakestate{ss, s, e, rs, re, psk}
+}
+
+func writeMessageA(hs *handshakestate, payload []byte) (*handshakestate, messagebuffer) {
+	ne, ns, ciphertext := emptyKey, []byte{}, []byte{}
+	hs.e = generateKeypair()
+	ne = hs.e.public_key
+	mixHash(&hs.ss, ne[:])
+	/* No PSK, so skipping mixKey */
+	mixKey(&hs.ss, dh(hs.e.private_key, hs.rs))
+	spk := make([]byte, len(hs.s.public_key))
+	copy(spk[:], hs.s.public_key[:])
+	_, ns = encryptAndHash(&hs.ss, spk)
+	mixKey(&hs.ss, dh(hs.s.private_key, hs.rs))
+	_, ciphertext = encryptAndHash(&hs.ss, payload)
+	messageBuffer := messagebuffer{ne, ns, ciphertext}
+	return hs, messageBuffer
+}
+
+func writeMessageB(hs *handshakestate, payload []byte) ([32]byte, messagebuffer, cipherstate, cipherstate) {
+	ne, ns, ciphertext := emptyKey, []byte{}, []byte{}
+	hs.e = generateKeypair()
+	ne = hs.e.public_key
+	mixHash(&hs.ss, ne[:])
+	/* No PSK, so skipping mixKey */
+	mixKey(&hs.ss, dh(hs.e.private_key, hs.re))
+	mixKey(&hs.ss, dh(hs.e.private_key, hs.rs))
+	_, ciphertext = encryptAndHash(&hs.ss, payload)
+	messageBuffer := messagebuffer{ne, ns, ciphertext}
+	cs1, cs2 := split(&hs.ss)
+	return hs.ss.h, messageBuffer, cs1, cs2
+}
+
+func writeMessageRegular(cs *cipherstate, payload []byte) (*cipherstate, messagebuffer) {
+	ne, ns, ciphertext := emptyKey, []byte{}, []byte{}
+	cs, ciphertext = encryptWithAd(cs, []byte{}, payload)
+	messageBuffer := messagebuffer{ne, ns, ciphertext}
+	return cs, messageBuffer
+}
+
+func readMessageA(hs *handshakestate, message *messagebuffer) (*handshakestate, []byte, bool) {
+	valid1 := true
+	if validatePublicKey(message.ne[:]) {
+		hs.re = message.ne
+	}
+	mixHash(&hs.ss, hs.re[:])
+	/* No PSK, so skipping mixKey */
+	mixKey(&hs.ss, dh(hs.s.private_key, hs.re))
+	_, ns, valid1 := decryptAndHash(&hs.ss, message.ns)
+	if valid1 && len(ns) == 32 && validatePublicKey(message.ns[:]) {
+		copy(hs.rs[:], ns)
+	}
+	mixKey(&hs.ss, dh(hs.s.private_key, hs.rs))
+	_, plaintext, valid2 := decryptAndHash(&hs.ss, message.ciphertext)
+	return hs, plaintext, (valid1 && valid2)
+}
+
+func readMessageB(hs *handshakestate, message *messagebuffer) ([32]byte, []byte, bool, cipherstate, cipherstate) {
+	valid1 := true
+	if validatePublicKey(message.ne[:]) {
+		hs.re = message.ne
+	}
+	mixHash(&hs.ss, hs.re[:])
+	/* No PSK, so skipping mixKey */
+	mixKey(&hs.ss, dh(hs.e.private_key, hs.re))
+	mixKey(&hs.ss, dh(hs.s.private_key, hs.re))
+	_, plaintext, valid2 := decryptAndHash(&hs.ss, message.ciphertext)
+	cs1, cs2 := split(&hs.ss)
+	return hs.ss.h, plaintext, (valid1 && valid2), cs1, cs2
+}
+
+func readMessageRegular(cs *cipherstate, message *messagebuffer) (*cipherstate, []byte, bool) {
+	/* No encrypted keys */
+	_, plaintext, valid2 := decryptWithAd(cs, []byte{}, message.ciphertext)
+	return cs, plaintext, valid2
+}
+
+/* ---------------------------------------------------------------- *
+ * PROCESSES                                                        *
+ * ---------------------------------------------------------------- */
+
+func InitSession(initiator bool, prologue []byte, s keypair, rs [32]byte) noisesession {
+	var session noisesession
+	psk := emptyKey
+	if initiator {
+		session.hs = initializeInitiator(prologue, s, rs, psk)
+	} else {
+		session.hs = initializeResponder(prologue, s, rs, psk)
+	}
+	session.i = initiator
+	session.mc = 0
+	return session
+}
+
+func SendMessage(session *noisesession, message []byte) (*noisesession, messagebuffer) {
+	var messageBuffer messagebuffer
+	if session.mc == 0 {
+		_, messageBuffer = writeMessageA(&session.hs, message)
+	}
+	if session.mc == 1 {
+		session.h, messageBuffer, session.cs1, session.cs2 = writeMessageB(&session.hs, message)
+		session.hs = handshakestate{}
+	}
+	if session.mc > 1 {
+		if session.i {
+			_, messageBuffer = writeMessageRegular(&session.cs1, message)
+		} else {
+			_, messageBuffer = writeMessageRegular(&session.cs2, message)
+		}
+	}
+	session.mc = session.mc + 1
+	return session, messageBuffer
+}
+
+func RecvMessage(session *noisesession, message *messagebuffer) (*noisesession, []byte, bool) {
+	var plaintext []byte
+	var valid bool
+	if session.mc == 0 {
+		_, plaintext, valid = readMessageA(&session.hs, message)
+	}
+	if session.mc == 1 {
+		session.h, plaintext, valid, session.cs1, session.cs2 = readMessageB(&session.hs, message)
+		session.hs = handshakestate{}
+	}
+	if session.mc > 1 {
+		if session.i {
+			_, plaintext, valid = readMessageRegular(&session.cs2, message)
+		} else {
+			_, plaintext, valid = readMessageRegular(&session.cs1, message)
+		}
+	}
+	session.mc = session.mc + 1
+	return session, plaintext, valid
+}
+
+func main() {}

+ 6 - 0
scripts/check_license_headers.sh

@@ -38,6 +38,12 @@ for file in $(find $1 -name '*.go' -not -path '*/.git/*'); do
         $1/wgengine/router/ifconfig_windows.go)
             # WireGuard copyright.
         ;;
+		*_string.go)
+			# Generated file from go:generate stringer
+		;;
+		$1/control/noise/noiseexplorer_test.go)
+			# Noiseexplorer.com copyright.
+		;;
         *)
             header="$(head -3 $file)"
             if ! check_file "$header"; then