conn.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. // Package controlbase implements the base transport of the Tailscale
  4. // 2021 control protocol.
  5. //
  6. // The base transport implements Noise IK, instantiated with
  7. // Curve25519, ChaCha20Poly1305 and BLAKE2s.
  8. package controlbase
  9. import (
  10. "crypto/cipher"
  11. "encoding/binary"
  12. "fmt"
  13. "net"
  14. "sync"
  15. "time"
  16. "golang.org/x/crypto/blake2s"
  17. chp "golang.org/x/crypto/chacha20poly1305"
  18. "tailscale.com/types/key"
  19. )
  20. const (
  21. // maxMessageSize is the maximum size of a protocol frame on the
  22. // wire, including header and payload.
  23. maxMessageSize = 4096
  24. // maxCiphertextSize is the maximum amount of ciphertext bytes
  25. // that one protocol frame can carry, after framing.
  26. maxCiphertextSize = maxMessageSize - 3
  27. // maxPlaintextSize is the maximum amount of plaintext bytes that
  28. // one protocol frame can carry, after encryption and framing.
  29. maxPlaintextSize = maxCiphertextSize - chp.Overhead
  30. )
  31. // A Conn is a secured Noise connection. It implements the net.Conn
  32. // interface, with the unusual trait that any write error (including a
  33. // SetWriteDeadline induced i/o timeout) causes all future writes to
  34. // fail.
  35. type Conn struct {
  36. conn net.Conn
  37. version uint16
  38. peer key.MachinePublic
  39. handshakeHash [blake2s.Size]byte
  40. rx rxState
  41. tx txState
  42. }
  43. // rxState is all the Conn state that Read uses.
  44. type rxState struct {
  45. sync.Mutex
  46. cipher cipher.AEAD
  47. nonce nonce
  48. buf *maxMsgBuffer // or nil when reads exhausted
  49. n int // number of valid bytes in buf
  50. next int // offset of next undecrypted packet
  51. plaintext []byte // slice into buf of decrypted bytes
  52. hdrBuf [headerLen]byte // small buffer used when buf is nil
  53. }
  54. // txState is all the Conn state that Write uses.
  55. type txState struct {
  56. sync.Mutex
  57. cipher cipher.AEAD
  58. nonce nonce
  59. err error // records the first partial write error for all future calls
  60. }
  61. // ProtocolVersion returns the protocol version that was used to
  62. // establish this Conn.
  63. func (c *Conn) ProtocolVersion() int {
  64. return int(c.version)
  65. }
  66. // HandshakeHash returns the Noise handshake hash for the connection,
  67. // which can be used to bind other messages to this connection
  68. // (i.e. to ensure that the message wasn't replayed from a different
  69. // connection).
  70. func (c *Conn) HandshakeHash() [blake2s.Size]byte {
  71. return c.handshakeHash
  72. }
  73. // Peer returns the peer's long-term public key.
  74. func (c *Conn) Peer() key.MachinePublic {
  75. return c.peer
  76. }
  77. // readNLocked reads into c.rx.buf until buf contains at least total
  78. // bytes. Returns a slice of the total bytes in rxBuf, or an
  79. // error if fewer than total bytes are available.
  80. //
  81. // It may be called with a nil c.rx.buf only if total == headerLen.
  82. //
  83. // On success, c.rx.buf will be non-nil.
  84. func (c *Conn) readNLocked(total int) ([]byte, error) {
  85. if total > maxMessageSize {
  86. return nil, errReadTooBig{total}
  87. }
  88. for {
  89. if total <= c.rx.n {
  90. return c.rx.buf[:total], nil
  91. }
  92. var n int
  93. var err error
  94. if c.rx.buf == nil {
  95. if c.rx.n != 0 || total != headerLen {
  96. panic("unexpected")
  97. }
  98. // Optimization to reduce memory usage.
  99. // Most connections are blocked forever waiting for
  100. // a read, so we don't want c.rx.buf to be allocated until
  101. // we know there's data to read. Instead, when we're
  102. // waiting for data to arrive here, read into the
  103. // 3 byte hdrBuf:
  104. n, err = c.conn.Read(c.rx.hdrBuf[:])
  105. if n > 0 {
  106. c.rx.buf = getMaxMsgBuffer()
  107. copy(c.rx.buf[:], c.rx.hdrBuf[:n])
  108. }
  109. } else {
  110. n, err = c.conn.Read(c.rx.buf[c.rx.n:])
  111. }
  112. c.rx.n += n
  113. if err != nil {
  114. return nil, err
  115. }
  116. }
  117. }
  118. // decryptLocked decrypts msg (which is header+ciphertext) in-place
  119. // and sets c.rx.plaintext to the decrypted bytes.
  120. func (c *Conn) decryptLocked(msg []byte) (err error) {
  121. if msgType := msg[0]; msgType != msgTypeRecord {
  122. return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord)
  123. }
  124. // We don't check the length field here, because the caller
  125. // already did in order to figure out how big the msg slice should
  126. // be.
  127. ciphertext := msg[headerLen:]
  128. if !c.rx.nonce.Valid() {
  129. return errCipherExhausted{}
  130. }
  131. c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil)
  132. c.rx.nonce.Increment()
  133. if err != nil {
  134. // Once a decryption has failed, our Conn is no longer
  135. // synchronized with our peer. Nuke the cipher state to be
  136. // safe, so that no further decryptions are attempted. Future
  137. // read attempts will return net.ErrClosed.
  138. c.rx.cipher = nil
  139. }
  140. return err
  141. }
  142. // encryptLocked encrypts plaintext into buf (including the
  143. // packet header) and returns a slice of the ciphertext, or an error
  144. // if the cipher is exhausted (i.e. can no longer be used safely).
  145. func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) {
  146. if !c.tx.nonce.Valid() {
  147. // Received 2^64-1 messages on this cipher state. Connection
  148. // is no longer usable.
  149. return nil, errCipherExhausted{}
  150. }
  151. buf[0] = msgTypeRecord
  152. binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead))
  153. ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil)
  154. c.tx.nonce.Increment()
  155. return ret, nil
  156. }
  157. // wholeMessageLocked returns a slice of one whole Noise transport
  158. // message from c.rx.buf, if one whole message is available, and
  159. // advances the read state to the next Noise message in the
  160. // buffer. Returns nil without advancing read state if there isn't one
  161. // whole message in c.rx.buf.
  162. func (c *Conn) wholeMessageLocked() []byte {
  163. available := c.rx.n - c.rx.next
  164. if available < headerLen {
  165. return nil
  166. }
  167. bs := c.rx.buf[c.rx.next:c.rx.n]
  168. totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
  169. if len(bs) < totalSize {
  170. return nil
  171. }
  172. c.rx.next += totalSize
  173. return bs[:totalSize]
  174. }
  175. // decryptOneLocked decrypts one Noise transport message, reading from
  176. // c.conn as needed, and sets c.rx.plaintext to point to the decrypted
  177. // bytes. c.rx.plaintext is only valid if err == nil.
  178. func (c *Conn) decryptOneLocked() error {
  179. c.rx.plaintext = nil
  180. // Fast path: do we have one whole ciphertext frame buffered
  181. // already?
  182. if bs := c.wholeMessageLocked(); bs != nil {
  183. return c.decryptLocked(bs)
  184. }
  185. if c.rx.next != 0 {
  186. // To simplify the read logic, move the remainder of the
  187. // buffered bytes back to the head of the buffer, so we can
  188. // grow it without worrying about wraparound.
  189. c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n])
  190. c.rx.next = 0
  191. }
  192. // Return our buffer to the pool if it's empty, lest we be
  193. // blocked in a long Read call, reading the 3 byte header. We
  194. // don't to keep that buffer unnecessarily alive.
  195. if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil {
  196. bufPool.Put(c.rx.buf)
  197. c.rx.buf = nil
  198. }
  199. bs, err := c.readNLocked(headerLen)
  200. if err != nil {
  201. return err
  202. }
  203. // The rest of the header (besides the length field) gets verified
  204. // in decryptLocked, not here.
  205. messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
  206. bs, err = c.readNLocked(messageLen)
  207. if err != nil {
  208. return err
  209. }
  210. c.rx.next = len(bs)
  211. return c.decryptLocked(bs)
  212. }
  213. // Read implements io.Reader.
  214. func (c *Conn) Read(bs []byte) (int, error) {
  215. c.rx.Lock()
  216. defer c.rx.Unlock()
  217. if c.rx.cipher == nil {
  218. return 0, net.ErrClosed
  219. }
  220. // If no plaintext is buffered, decrypt incoming frames until we
  221. // have some plaintext. Zero-byte Noise frames are allowed in this
  222. // protocol, which is why we have to loop here rather than decrypt
  223. // a single additional frame.
  224. for len(c.rx.plaintext) == 0 {
  225. if err := c.decryptOneLocked(); err != nil {
  226. return 0, err
  227. }
  228. }
  229. n := copy(bs, c.rx.plaintext)
  230. c.rx.plaintext = c.rx.plaintext[n:]
  231. // Lose slice's underlying array pointer to unneeded memory so
  232. // GC can collect more.
  233. if len(c.rx.plaintext) == 0 {
  234. c.rx.plaintext = nil
  235. }
  236. return n, nil
  237. }
  238. // Write implements io.Writer.
  239. func (c *Conn) Write(bs []byte) (n int, err error) {
  240. c.tx.Lock()
  241. defer c.tx.Unlock()
  242. if c.tx.err != nil {
  243. return 0, c.tx.err
  244. }
  245. defer func() {
  246. if err != nil {
  247. // All write errors are fatal for this conn, so clear the
  248. // cipher state whenever an error happens.
  249. c.tx.cipher = nil
  250. }
  251. if c.tx.err == nil {
  252. // Only set c.tx.err if not nil so that we can return one
  253. // error on the first failure, and a different one for
  254. // subsequent calls. See the error handling around Write
  255. // below for why.
  256. c.tx.err = err
  257. }
  258. }()
  259. if c.tx.cipher == nil {
  260. return 0, net.ErrClosed
  261. }
  262. buf := getMaxMsgBuffer()
  263. defer bufPool.Put(buf)
  264. var sent int
  265. for len(bs) > 0 {
  266. toSend := bs
  267. if len(toSend) > maxPlaintextSize {
  268. toSend = bs[:maxPlaintextSize]
  269. }
  270. bs = bs[len(toSend):]
  271. ciphertext, err := c.encryptLocked(toSend, buf)
  272. if err != nil {
  273. return sent, err
  274. }
  275. if _, err := c.conn.Write(ciphertext); err != nil {
  276. // Return the raw error on the Write that actually
  277. // failed. For future writes, return that error wrapped in
  278. // a desync error.
  279. c.tx.err = errPartialWrite{err}
  280. return sent, err
  281. }
  282. sent += len(toSend)
  283. }
  284. return sent, nil
  285. }
  286. // Close implements io.Closer.
  287. func (c *Conn) Close() error {
  288. closeErr := c.conn.Close() // unblocks any waiting reads or writes
  289. // Remove references to live cipher state. Strictly speaking this
  290. // is unnecessary, but we want to try and hand the active cipher
  291. // state to the garbage collector promptly, to preserve perfect
  292. // forward secrecy as much as we can.
  293. c.rx.Lock()
  294. c.rx.cipher = nil
  295. c.rx.Unlock()
  296. c.tx.Lock()
  297. c.tx.cipher = nil
  298. c.tx.Unlock()
  299. return closeErr
  300. }
  301. func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
  302. func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
  303. func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
  304. func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
  305. func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
  306. // errCipherExhausted is the error returned when we run out of nonces
  307. // on a cipher.
  308. type errCipherExhausted struct{}
  309. func (errCipherExhausted) Error() string {
  310. return "cipher exhausted, no more nonces available for current key"
  311. }
  312. func (errCipherExhausted) Timeout() bool { return false }
  313. func (errCipherExhausted) Temporary() bool { return false }
  314. // errPartialWrite is the error returned when the cipher state has
  315. // become unusable due to a past partial write.
  316. type errPartialWrite struct {
  317. err error
  318. }
  319. func (e errPartialWrite) Error() string {
  320. return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err)
  321. }
  322. func (e errPartialWrite) Unwrap() error { return e.err }
  323. func (e errPartialWrite) Temporary() bool { return false }
  324. func (e errPartialWrite) Timeout() bool { return false }
  325. // errReadTooBig is the error returned when the peer sent an
  326. // unacceptably large Noise frame.
  327. type errReadTooBig struct {
  328. requested int
  329. }
  330. func (e errReadTooBig) Error() string {
  331. return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested)
  332. }
  333. func (e errReadTooBig) Temporary() bool {
  334. // permanent error because this error only occurs when our peer
  335. // sends us a frame so large we're unwilling to ever decode it.
  336. return false
  337. }
  338. func (e errReadTooBig) Timeout() bool { return false }
  339. type nonce [chp.NonceSize]byte
  340. func (n *nonce) Valid() bool {
  341. return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce
  342. }
  343. func (n *nonce) Increment() {
  344. if !n.Valid() {
  345. panic("increment of invalid nonce")
  346. }
  347. binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:]))
  348. }
  349. type maxMsgBuffer [maxMessageSize]byte
  350. // bufPool holds the temporary buffers for Conn.Read & Write.
  351. var bufPool = &sync.Pool{
  352. New: func() any {
  353. return new(maxMsgBuffer)
  354. },
  355. }
  356. func getMaxMsgBuffer() *maxMsgBuffer {
  357. return bufPool.Get().(*maxMsgBuffer)
  358. }