conn.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. // Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // Package controlbase implements the base transport of the Tailscale
  5. // 2021 control protocol.
  6. //
  7. // The base transport implements Noise IK, instantiated with
  8. // Curve25519, ChaCha20Poly1305 and BLAKE2s.
  9. package controlbase
  10. import (
  11. "crypto/cipher"
  12. "encoding/binary"
  13. "fmt"
  14. "net"
  15. "sync"
  16. "time"
  17. "golang.org/x/crypto/blake2s"
  18. chp "golang.org/x/crypto/chacha20poly1305"
  19. "tailscale.com/types/key"
  20. )
const (
	// maxMessageSize is the maximum size of a protocol frame on the
	// wire, including header and payload. It is also the size of the
	// pooled buffers used by Read and Write (see maxMsgBuffer).
	maxMessageSize = 4096
	// maxCiphertextSize is the maximum amount of ciphertext bytes
	// that one protocol frame can carry, after framing. The 3
	// subtracted here is the size of the frame header: one message
	// type byte followed by a 2-byte big-endian length.
	maxCiphertextSize = maxMessageSize - 3
	// maxPlaintextSize is the maximum amount of plaintext bytes that
	// one protocol frame can carry, after encryption and framing.
	// chp.Overhead is the number of bytes that sealing adds on top of
	// the plaintext in every frame.
	maxPlaintextSize = maxCiphertextSize - chp.Overhead
)
// A Conn is a secured Noise connection. It implements the net.Conn
// interface, with the unusual trait that any write error (including a
// SetWriteDeadline induced i/o timeout) causes all future writes to
// fail.
//
// The read side and the write side keep separate state, each with
// its own mutex (rx and tx), so Read and Write do not contend on a
// shared lock.
type Conn struct {
	conn          net.Conn           // underlying connection carrying the encrypted frames
	version       uint16             // protocol version, reported by ProtocolVersion
	peer          key.MachinePublic  // peer's long-term public key, reported by Peer
	handshakeHash [blake2s.Size]byte // Noise handshake hash, reported by HandshakeHash
	rx            rxState            // state used by Read
	tx            txState            // state used by Write
}
// rxState is all the Conn state that Read uses. The embedded mutex is
// held by Read for the whole call, and by Close while it clears
// cipher.
type rxState struct {
	sync.Mutex
	cipher    cipher.AEAD     // nil after Close or after a failed decryption; Read then returns net.ErrClosed
	nonce     nonce           // nonce for the next frame to decrypt
	buf       *maxMsgBuffer   // pooled receive buffer, or nil when reads exhausted (no bytes buffered)
	n         int             // number of valid bytes in buf
	next      int             // offset of next undecrypted packet
	plaintext []byte          // slice into buf of decrypted bytes not yet handed to the caller
	hdrBuf    [headerLen]byte // small buffer used when buf is nil
}
// txState is all the Conn state that Write uses. The embedded mutex
// is held by Write for the whole call, and by Close while it clears
// cipher.
type txState struct {
	sync.Mutex
	// cipher is nil after Close or after any Write error.
	cipher cipher.AEAD
	// nonce is the nonce for the next frame to encrypt.
	nonce nonce
	// err records the first Write failure and is returned by all
	// future calls. When the failure was an error from the
	// underlying connection's Write, it is stored wrapped in
	// errPartialWrite.
	err error
}
  62. // ProtocolVersion returns the protocol version that was used to
  63. // establish this Conn.
  64. func (c *Conn) ProtocolVersion() int {
  65. return int(c.version)
  66. }
  67. // HandshakeHash returns the Noise handshake hash for the connection,
  68. // which can be used to bind other messages to this connection
  69. // (i.e. to ensure that the message wasn't replayed from a different
  70. // connection).
  71. func (c *Conn) HandshakeHash() [blake2s.Size]byte {
  72. return c.handshakeHash
  73. }
  74. // Peer returns the peer's long-term public key.
  75. func (c *Conn) Peer() key.MachinePublic {
  76. return c.peer
  77. }
  78. // readNLocked reads into c.rx.buf until buf contains at least total
  79. // bytes. Returns a slice of the total bytes in rxBuf, or an
  80. // error if fewer than total bytes are available.
  81. //
  82. // It may be called with a nil c.rx.buf only if total == headerLen.
  83. //
  84. // On success, c.rx.buf will be non-nil.
  85. func (c *Conn) readNLocked(total int) ([]byte, error) {
  86. if total > maxMessageSize {
  87. return nil, errReadTooBig{total}
  88. }
  89. for {
  90. if total <= c.rx.n {
  91. return c.rx.buf[:total], nil
  92. }
  93. var n int
  94. var err error
  95. if c.rx.buf == nil {
  96. if c.rx.n != 0 || total != headerLen {
  97. panic("unexpected")
  98. }
  99. // Optimization to reduce memory usage.
  100. // Most connections are blocked forever waiting for
  101. // a read, so we don't want c.rx.buf to be allocated until
  102. // we know there's data to read. Instead, when we're
  103. // waiting for data to arrive here, read into the
  104. // 3 byte hdrBuf:
  105. n, err = c.conn.Read(c.rx.hdrBuf[:])
  106. if n > 0 {
  107. c.rx.buf = getMaxMsgBuffer()
  108. copy(c.rx.buf[:], c.rx.hdrBuf[:n])
  109. }
  110. } else {
  111. n, err = c.conn.Read(c.rx.buf[c.rx.n:])
  112. }
  113. c.rx.n += n
  114. if err != nil {
  115. return nil, err
  116. }
  117. }
  118. }
  119. // decryptLocked decrypts msg (which is header+ciphertext) in-place
  120. // and sets c.rx.plaintext to the decrypted bytes.
  121. func (c *Conn) decryptLocked(msg []byte) (err error) {
  122. if msgType := msg[0]; msgType != msgTypeRecord {
  123. return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord)
  124. }
  125. // We don't check the length field here, because the caller
  126. // already did in order to figure out how big the msg slice should
  127. // be.
  128. ciphertext := msg[headerLen:]
  129. if !c.rx.nonce.Valid() {
  130. return errCipherExhausted{}
  131. }
  132. c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil)
  133. c.rx.nonce.Increment()
  134. if err != nil {
  135. // Once a decryption has failed, our Conn is no longer
  136. // synchronized with our peer. Nuke the cipher state to be
  137. // safe, so that no further decryptions are attempted. Future
  138. // read attempts will return net.ErrClosed.
  139. c.rx.cipher = nil
  140. }
  141. return err
  142. }
  143. // encryptLocked encrypts plaintext into buf (including the
  144. // packet header) and returns a slice of the ciphertext, or an error
  145. // if the cipher is exhausted (i.e. can no longer be used safely).
  146. func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) {
  147. if !c.tx.nonce.Valid() {
  148. // Received 2^64-1 messages on this cipher state. Connection
  149. // is no longer usable.
  150. return nil, errCipherExhausted{}
  151. }
  152. buf[0] = msgTypeRecord
  153. binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead))
  154. ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil)
  155. c.tx.nonce.Increment()
  156. return ret, nil
  157. }
  158. // wholeMessageLocked returns a slice of one whole Noise transport
  159. // message from c.rx.buf, if one whole message is available, and
  160. // advances the read state to the next Noise message in the
  161. // buffer. Returns nil without advancing read state if there isn't one
  162. // whole message in c.rx.buf.
  163. func (c *Conn) wholeMessageLocked() []byte {
  164. available := c.rx.n - c.rx.next
  165. if available < headerLen {
  166. return nil
  167. }
  168. bs := c.rx.buf[c.rx.next:c.rx.n]
  169. totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
  170. if len(bs) < totalSize {
  171. return nil
  172. }
  173. c.rx.next += totalSize
  174. return bs[:totalSize]
  175. }
  176. // decryptOneLocked decrypts one Noise transport message, reading from
  177. // c.conn as needed, and sets c.rx.plaintext to point to the decrypted
  178. // bytes. c.rx.plaintext is only valid if err == nil.
  179. func (c *Conn) decryptOneLocked() error {
  180. c.rx.plaintext = nil
  181. // Fast path: do we have one whole ciphertext frame buffered
  182. // already?
  183. if bs := c.wholeMessageLocked(); bs != nil {
  184. return c.decryptLocked(bs)
  185. }
  186. if c.rx.next != 0 {
  187. // To simplify the read logic, move the remainder of the
  188. // buffered bytes back to the head of the buffer, so we can
  189. // grow it without worrying about wraparound.
  190. c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n])
  191. c.rx.next = 0
  192. }
  193. // Return our buffer to the pool if it's empty, lest we be
  194. // blocked in a long Read call, reading the 3 byte header. We
  195. // don't to keep that buffer unnecessarily alive.
  196. if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil {
  197. bufPool.Put(c.rx.buf)
  198. c.rx.buf = nil
  199. }
  200. bs, err := c.readNLocked(headerLen)
  201. if err != nil {
  202. return err
  203. }
  204. // The rest of the header (besides the length field) gets verified
  205. // in decryptLocked, not here.
  206. messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
  207. bs, err = c.readNLocked(messageLen)
  208. if err != nil {
  209. return err
  210. }
  211. c.rx.next = len(bs)
  212. return c.decryptLocked(bs)
  213. }
  214. // Read implements io.Reader.
  215. func (c *Conn) Read(bs []byte) (int, error) {
  216. c.rx.Lock()
  217. defer c.rx.Unlock()
  218. if c.rx.cipher == nil {
  219. return 0, net.ErrClosed
  220. }
  221. // If no plaintext is buffered, decrypt incoming frames until we
  222. // have some plaintext. Zero-byte Noise frames are allowed in this
  223. // protocol, which is why we have to loop here rather than decrypt
  224. // a single additional frame.
  225. for len(c.rx.plaintext) == 0 {
  226. if err := c.decryptOneLocked(); err != nil {
  227. return 0, err
  228. }
  229. }
  230. n := copy(bs, c.rx.plaintext)
  231. c.rx.plaintext = c.rx.plaintext[n:]
  232. // Lose slice's underlying array pointer to unneeded memory so
  233. // GC can collect more.
  234. if len(c.rx.plaintext) == 0 {
  235. c.rx.plaintext = nil
  236. }
  237. return n, nil
  238. }
  239. // Write implements io.Writer.
  240. func (c *Conn) Write(bs []byte) (n int, err error) {
  241. c.tx.Lock()
  242. defer c.tx.Unlock()
  243. if c.tx.err != nil {
  244. return 0, c.tx.err
  245. }
  246. defer func() {
  247. if err != nil {
  248. // All write errors are fatal for this conn, so clear the
  249. // cipher state whenever an error happens.
  250. c.tx.cipher = nil
  251. }
  252. if c.tx.err == nil {
  253. // Only set c.tx.err if not nil so that we can return one
  254. // error on the first failure, and a different one for
  255. // subsequent calls. See the error handling around Write
  256. // below for why.
  257. c.tx.err = err
  258. }
  259. }()
  260. if c.tx.cipher == nil {
  261. return 0, net.ErrClosed
  262. }
  263. buf := getMaxMsgBuffer()
  264. defer bufPool.Put(buf)
  265. var sent int
  266. for len(bs) > 0 {
  267. toSend := bs
  268. if len(toSend) > maxPlaintextSize {
  269. toSend = bs[:maxPlaintextSize]
  270. }
  271. bs = bs[len(toSend):]
  272. ciphertext, err := c.encryptLocked(toSend, buf)
  273. if err != nil {
  274. return sent, err
  275. }
  276. if _, err := c.conn.Write(ciphertext); err != nil {
  277. // Return the raw error on the Write that actually
  278. // failed. For future writes, return that error wrapped in
  279. // a desync error.
  280. c.tx.err = errPartialWrite{err}
  281. return sent, err
  282. }
  283. sent += len(toSend)
  284. }
  285. return sent, nil
  286. }
  287. // Close implements io.Closer.
  288. func (c *Conn) Close() error {
  289. closeErr := c.conn.Close() // unblocks any waiting reads or writes
  290. // Remove references to live cipher state. Strictly speaking this
  291. // is unnecessary, but we want to try and hand the active cipher
  292. // state to the garbage collector promptly, to preserve perfect
  293. // forward secrecy as much as we can.
  294. c.rx.Lock()
  295. c.rx.cipher = nil
  296. c.rx.Unlock()
  297. c.tx.Lock()
  298. c.tx.cipher = nil
  299. c.tx.Unlock()
  300. return closeErr
  301. }
  302. func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
  303. func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
  304. func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
  305. func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
  306. func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
  307. // errCipherExhausted is the error returned when we run out of nonces
  308. // on a cipher.
  309. type errCipherExhausted struct{}
  310. func (errCipherExhausted) Error() string {
  311. return "cipher exhausted, no more nonces available for current key"
  312. }
  313. func (errCipherExhausted) Timeout() bool { return false }
  314. func (errCipherExhausted) Temporary() bool { return false }
  315. // errPartialWrite is the error returned when the cipher state has
  316. // become unusable due to a past partial write.
  317. type errPartialWrite struct {
  318. err error
  319. }
  320. func (e errPartialWrite) Error() string {
  321. return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err)
  322. }
  323. func (e errPartialWrite) Unwrap() error { return e.err }
  324. func (e errPartialWrite) Temporary() bool { return false }
  325. func (e errPartialWrite) Timeout() bool { return false }
  326. // errReadTooBig is the error returned when the peer sent an
  327. // unacceptably large Noise frame.
  328. type errReadTooBig struct {
  329. requested int
  330. }
  331. func (e errReadTooBig) Error() string {
  332. return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested)
  333. }
  334. func (e errReadTooBig) Temporary() bool {
  335. // permanent error because this error only occurs when our peer
  336. // sends us a frame so large we're unwilling to ever decode it.
  337. return false
  338. }
  339. func (e errReadTooBig) Timeout() bool { return false }
  340. type nonce [chp.NonceSize]byte
  341. func (n *nonce) Valid() bool {
  342. return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce
  343. }
  344. func (n *nonce) Increment() {
  345. if !n.Valid() {
  346. panic("increment of invalid nonce")
  347. }
  348. binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:]))
  349. }
  350. type maxMsgBuffer [maxMessageSize]byte
  351. // bufPool holds the temporary buffers for Conn.Read & Write.
  352. var bufPool = &sync.Pool{
  353. New: func() interface{} {
  354. return new(maxMsgBuffer)
  355. },
  356. }
  357. func getMaxMsgBuffer() *maxMsgBuffer {
  358. return bufPool.Get().(*maxMsgBuffer)
  359. }