| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408 |
- // Copyright (c) Tailscale Inc & AUTHORS
- // SPDX-License-Identifier: BSD-3-Clause
- // Package controlbase implements the base transport of the Tailscale
- // 2021 control protocol.
- //
- // The base transport implements Noise IK, instantiated with
- // Curve25519, ChaCha20Poly1305 and BLAKE2s.
- package controlbase
- import (
- "crypto/cipher"
- "encoding/binary"
- "fmt"
- "net"
- "sync"
- "time"
- "golang.org/x/crypto/blake2s"
- chp "golang.org/x/crypto/chacha20poly1305"
- "tailscale.com/types/key"
- )
- const (
- // maxMessageSize is the maximum size of a protocol frame on the
- // wire, including header and payload.
- maxMessageSize = 4096
- // maxCiphertextSize is the maximum amount of ciphertext bytes
- // that one protocol frame can carry, after framing.
- maxCiphertextSize = maxMessageSize - 3
- // maxPlaintextSize is the maximum amount of plaintext bytes that
- // one protocol frame can carry, after encryption and framing.
- maxPlaintextSize = maxCiphertextSize - chp.Overhead
- )
- // A Conn is a secured Noise connection. It implements the net.Conn
- // interface, with the unusual trait that any write error (including a
- // SetWriteDeadline induced i/o timeout) causes all future writes to
- // fail.
- type Conn struct {
- conn net.Conn
- version uint16
- peer key.MachinePublic
- handshakeHash [blake2s.Size]byte
- rx rxState
- tx txState
- }
- // rxState is all the Conn state that Read uses.
- type rxState struct {
- sync.Mutex
- cipher cipher.AEAD
- nonce nonce
- buf *maxMsgBuffer // or nil when reads exhausted
- n int // number of valid bytes in buf
- next int // offset of next undecrypted packet
- plaintext []byte // slice into buf of decrypted bytes
- hdrBuf [headerLen]byte // small buffer used when buf is nil
- }
- // txState is all the Conn state that Write uses.
- type txState struct {
- sync.Mutex
- cipher cipher.AEAD
- nonce nonce
- err error // records the first partial write error for all future calls
- }
- // ProtocolVersion returns the protocol version that was used to
- // establish this Conn.
- func (c *Conn) ProtocolVersion() int {
- return int(c.version)
- }
- // HandshakeHash returns the Noise handshake hash for the connection,
- // which can be used to bind other messages to this connection
- // (i.e. to ensure that the message wasn't replayed from a different
- // connection).
- func (c *Conn) HandshakeHash() [blake2s.Size]byte {
- return c.handshakeHash
- }
- // Peer returns the peer's long-term public key.
- func (c *Conn) Peer() key.MachinePublic {
- return c.peer
- }
- // readNLocked reads into c.rx.buf until buf contains at least total
- // bytes. Returns a slice of the total bytes in rxBuf, or an
- // error if fewer than total bytes are available.
- //
- // It may be called with a nil c.rx.buf only if total == headerLen.
- //
- // On success, c.rx.buf will be non-nil.
- func (c *Conn) readNLocked(total int) ([]byte, error) {
- if total > maxMessageSize {
- return nil, errReadTooBig{total}
- }
- for {
- if total <= c.rx.n {
- return c.rx.buf[:total], nil
- }
- var n int
- var err error
- if c.rx.buf == nil {
- if c.rx.n != 0 || total != headerLen {
- panic("unexpected")
- }
- // Optimization to reduce memory usage.
- // Most connections are blocked forever waiting for
- // a read, so we don't want c.rx.buf to be allocated until
- // we know there's data to read. Instead, when we're
- // waiting for data to arrive here, read into the
- // 3 byte hdrBuf:
- n, err = c.conn.Read(c.rx.hdrBuf[:])
- if n > 0 {
- c.rx.buf = getMaxMsgBuffer()
- copy(c.rx.buf[:], c.rx.hdrBuf[:n])
- }
- } else {
- n, err = c.conn.Read(c.rx.buf[c.rx.n:])
- }
- c.rx.n += n
- if err != nil {
- return nil, err
- }
- }
- }
- // decryptLocked decrypts msg (which is header+ciphertext) in-place
- // and sets c.rx.plaintext to the decrypted bytes.
- func (c *Conn) decryptLocked(msg []byte) (err error) {
- if msgType := msg[0]; msgType != msgTypeRecord {
- return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord)
- }
- // We don't check the length field here, because the caller
- // already did in order to figure out how big the msg slice should
- // be.
- ciphertext := msg[headerLen:]
- if !c.rx.nonce.Valid() {
- return errCipherExhausted{}
- }
- c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil)
- c.rx.nonce.Increment()
- if err != nil {
- // Once a decryption has failed, our Conn is no longer
- // synchronized with our peer. Nuke the cipher state to be
- // safe, so that no further decryptions are attempted. Future
- // read attempts will return net.ErrClosed.
- c.rx.cipher = nil
- }
- return err
- }
- // encryptLocked encrypts plaintext into buf (including the
- // packet header) and returns a slice of the ciphertext, or an error
- // if the cipher is exhausted (i.e. can no longer be used safely).
- func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) {
- if !c.tx.nonce.Valid() {
- // Received 2^64-1 messages on this cipher state. Connection
- // is no longer usable.
- return nil, errCipherExhausted{}
- }
- buf[0] = msgTypeRecord
- binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead))
- ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil)
- c.tx.nonce.Increment()
- return ret, nil
- }
- // wholeMessageLocked returns a slice of one whole Noise transport
- // message from c.rx.buf, if one whole message is available, and
- // advances the read state to the next Noise message in the
- // buffer. Returns nil without advancing read state if there isn't one
- // whole message in c.rx.buf.
- func (c *Conn) wholeMessageLocked() []byte {
- available := c.rx.n - c.rx.next
- if available < headerLen {
- return nil
- }
- bs := c.rx.buf[c.rx.next:c.rx.n]
- totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
- if len(bs) < totalSize {
- return nil
- }
- c.rx.next += totalSize
- return bs[:totalSize]
- }
- // decryptOneLocked decrypts one Noise transport message, reading from
- // c.conn as needed, and sets c.rx.plaintext to point to the decrypted
- // bytes. c.rx.plaintext is only valid if err == nil.
- func (c *Conn) decryptOneLocked() error {
- c.rx.plaintext = nil
- // Fast path: do we have one whole ciphertext frame buffered
- // already?
- if bs := c.wholeMessageLocked(); bs != nil {
- return c.decryptLocked(bs)
- }
- if c.rx.next != 0 {
- // To simplify the read logic, move the remainder of the
- // buffered bytes back to the head of the buffer, so we can
- // grow it without worrying about wraparound.
- c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n])
- c.rx.next = 0
- }
- // Return our buffer to the pool if it's empty, lest we be
- // blocked in a long Read call, reading the 3 byte header. We
- // don't to keep that buffer unnecessarily alive.
- if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil {
- bufPool.Put(c.rx.buf)
- c.rx.buf = nil
- }
- bs, err := c.readNLocked(headerLen)
- if err != nil {
- return err
- }
- // The rest of the header (besides the length field) gets verified
- // in decryptLocked, not here.
- messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
- bs, err = c.readNLocked(messageLen)
- if err != nil {
- return err
- }
- c.rx.next = len(bs)
- return c.decryptLocked(bs)
- }
- // Read implements io.Reader.
- func (c *Conn) Read(bs []byte) (int, error) {
- c.rx.Lock()
- defer c.rx.Unlock()
- if c.rx.cipher == nil {
- return 0, net.ErrClosed
- }
- // If no plaintext is buffered, decrypt incoming frames until we
- // have some plaintext. Zero-byte Noise frames are allowed in this
- // protocol, which is why we have to loop here rather than decrypt
- // a single additional frame.
- for len(c.rx.plaintext) == 0 {
- if err := c.decryptOneLocked(); err != nil {
- return 0, err
- }
- }
- n := copy(bs, c.rx.plaintext)
- c.rx.plaintext = c.rx.plaintext[n:]
- // Lose slice's underlying array pointer to unneeded memory so
- // GC can collect more.
- if len(c.rx.plaintext) == 0 {
- c.rx.plaintext = nil
- }
- return n, nil
- }
- // Write implements io.Writer.
- func (c *Conn) Write(bs []byte) (n int, err error) {
- c.tx.Lock()
- defer c.tx.Unlock()
- if c.tx.err != nil {
- return 0, c.tx.err
- }
- defer func() {
- if err != nil {
- // All write errors are fatal for this conn, so clear the
- // cipher state whenever an error happens.
- c.tx.cipher = nil
- }
- if c.tx.err == nil {
- // Only set c.tx.err if not nil so that we can return one
- // error on the first failure, and a different one for
- // subsequent calls. See the error handling around Write
- // below for why.
- c.tx.err = err
- }
- }()
- if c.tx.cipher == nil {
- return 0, net.ErrClosed
- }
- buf := getMaxMsgBuffer()
- defer bufPool.Put(buf)
- var sent int
- for len(bs) > 0 {
- toSend := bs
- if len(toSend) > maxPlaintextSize {
- toSend = bs[:maxPlaintextSize]
- }
- bs = bs[len(toSend):]
- ciphertext, err := c.encryptLocked(toSend, buf)
- if err != nil {
- return sent, err
- }
- if _, err := c.conn.Write(ciphertext); err != nil {
- // Return the raw error on the Write that actually
- // failed. For future writes, return that error wrapped in
- // a desync error.
- c.tx.err = errPartialWrite{err}
- return sent, err
- }
- sent += len(toSend)
- }
- return sent, nil
- }
- // Close implements io.Closer.
- func (c *Conn) Close() error {
- closeErr := c.conn.Close() // unblocks any waiting reads or writes
- // Remove references to live cipher state. Strictly speaking this
- // is unnecessary, but we want to try and hand the active cipher
- // state to the garbage collector promptly, to preserve perfect
- // forward secrecy as much as we can.
- c.rx.Lock()
- c.rx.cipher = nil
- c.rx.Unlock()
- c.tx.Lock()
- c.tx.cipher = nil
- c.tx.Unlock()
- return closeErr
- }
- func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
- func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
- func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
- func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
- func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
// errCipherExhausted reports that a cipher has no nonces left to use.
type errCipherExhausted struct{}

// Error returns a fixed description of the condition.
func (errCipherExhausted) Error() string {
	const msg = "cipher exhausted, no more nonces available for current key"
	return msg
}

// Temporary always reports false.
func (errCipherExhausted) Temporary() bool {
	return false
}

// Timeout always reports false.
func (errCipherExhausted) Timeout() bool {
	return false
}
// errPartialWrite reports that the cipher state can no longer be
// used because an earlier write only partially succeeded. It wraps
// the error of that earlier write.
type errPartialWrite struct {
	// err is the wrapped error, exposed through Unwrap.
	err error
}

// Error describes the desynchronization and includes the wrapped
// error in parentheses.
func (e errPartialWrite) Error() string {
	const format = "cipher state desynchronized due to partial write (%v)"
	return fmt.Sprintf(format, e.err)
}

// Unwrap returns the wrapped error.
func (e errPartialWrite) Unwrap() error {
	return e.err
}

// Timeout always reports false.
func (e errPartialWrite) Timeout() bool {
	return false
}

// Temporary always reports false.
func (e errPartialWrite) Temporary() bool {
	return false
}
// errReadTooBig reports that the peer sent a Noise frame that is
// unacceptably large.
type errReadTooBig struct {
	// requested is the size, in bytes, of the read that was refused.
	requested int
}

// Error describes the refused read, including its size.
func (e errReadTooBig) Error() string {
	const format = "requested read of %d bytes exceeds max allowed Noise frame size"
	return fmt.Sprintf(format, e.requested)
}

// Timeout always reports false.
func (e errReadTooBig) Timeout() bool {
	return false
}

// Temporary always reports false. The error is permanent: it only
// arises when the peer sends a frame so large that we are unwilling
// to ever decode it.
func (e errReadTooBig) Temporary() bool {
	return false
}
- type nonce [chp.NonceSize]byte
- func (n *nonce) Valid() bool {
- return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce
- }
- func (n *nonce) Increment() {
- if !n.Valid() {
- panic("increment of invalid nonce")
- }
- binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:]))
- }
- type maxMsgBuffer [maxMessageSize]byte
- // bufPool holds the temporary buffers for Conn.Read & Write.
- var bufPool = &sync.Pool{
- New: func() any {
- return new(maxMsgBuffer)
- },
- }
- func getMaxMsgBuffer() *maxMsgBuffer {
- return bufPool.Get().(*maxMsgBuffer)
- }
|