conn.go 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979
  1. package dtls
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net"
  8. "sync"
  9. "sync/atomic"
  10. "time"
  11. "github.com/pion/dtls/v2/internal/closer"
  12. "github.com/pion/dtls/v2/pkg/crypto/elliptic"
  13. "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
  14. "github.com/pion/dtls/v2/pkg/protocol"
  15. "github.com/pion/dtls/v2/pkg/protocol/alert"
  16. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  17. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  18. "github.com/pion/logging"
  19. "github.com/pion/transport/connctx"
  20. "github.com/pion/transport/deadline"
  21. "github.com/pion/transport/replaydetector"
  22. )
  23. const (
  24. initialTickerInterval = time.Second
  25. cookieLength = 20
  26. defaultNamedCurve = elliptic.X25519
  27. inboundBufferSize = 8192
  28. // Default replay protection window is specified by RFC 6347 Section 4.1.2.6
  29. defaultReplayProtectionWindow = 64
  30. )
  31. func invalidKeyingLabels() map[string]bool {
  32. return map[string]bool{
  33. "client finished": true,
  34. "server finished": true,
  35. "master secret": true,
  36. "key expansion": true,
  37. }
  38. }
  39. // Conn represents a DTLS connection
  40. type Conn struct {
  41. lock sync.RWMutex // Internal lock (must not be public)
  42. nextConn connctx.ConnCtx // Embedded Conn, typically a udpconn we read/write from
  43. fragmentBuffer *fragmentBuffer // out-of-order and missing fragment handling
  44. handshakeCache *handshakeCache // caching of handshake messages for verifyData generation
  45. decrypted chan interface{} // Decrypted Application Data or error, pull by calling `Read`
  46. state State // Internal state
  47. maximumTransmissionUnit int
  48. handshakeCompletedSuccessfully atomic.Value
  49. encryptedPackets [][]byte
  50. connectionClosedByUser bool
  51. closeLock sync.Mutex
  52. closed *closer.Closer
  53. handshakeLoopsFinished sync.WaitGroup
  54. readDeadline *deadline.Deadline
  55. writeDeadline *deadline.Deadline
  56. log logging.LeveledLogger
  57. reading chan struct{}
  58. handshakeRecv chan chan struct{}
  59. cancelHandshaker func()
  60. cancelHandshakeReader func()
  61. fsm *handshakeFSM
  62. replayProtectionWindow uint
  63. }
  64. func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
  65. err := validateConfig(config)
  66. if err != nil {
  67. return nil, err
  68. }
  69. if nextConn == nil {
  70. return nil, errNilNextConn
  71. }
  72. cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.PSK == nil || len(config.Certificates) > 0, config.PSK != nil)
  73. if err != nil {
  74. return nil, err
  75. }
  76. signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
  77. if err != nil {
  78. return nil, err
  79. }
  80. workerInterval := initialTickerInterval
  81. if config.FlightInterval != 0 {
  82. workerInterval = config.FlightInterval
  83. }
  84. loggerFactory := config.LoggerFactory
  85. if loggerFactory == nil {
  86. loggerFactory = logging.NewDefaultLoggerFactory()
  87. }
  88. logger := loggerFactory.NewLogger("dtls")
  89. mtu := config.MTU
  90. if mtu <= 0 {
  91. mtu = defaultMTU
  92. }
  93. replayProtectionWindow := config.ReplayProtectionWindow
  94. if replayProtectionWindow <= 0 {
  95. replayProtectionWindow = defaultReplayProtectionWindow
  96. }
  97. c := &Conn{
  98. nextConn: connctx.New(nextConn),
  99. fragmentBuffer: newFragmentBuffer(),
  100. handshakeCache: newHandshakeCache(),
  101. maximumTransmissionUnit: mtu,
  102. decrypted: make(chan interface{}, 1),
  103. log: logger,
  104. readDeadline: deadline.New(),
  105. writeDeadline: deadline.New(),
  106. reading: make(chan struct{}, 1),
  107. handshakeRecv: make(chan chan struct{}),
  108. closed: closer.NewCloser(),
  109. cancelHandshaker: func() {},
  110. replayProtectionWindow: uint(replayProtectionWindow),
  111. state: State{
  112. isClient: isClient,
  113. },
  114. }
  115. c.setRemoteEpoch(0)
  116. c.setLocalEpoch(0)
  117. serverName := config.ServerName
  118. // Use host from conn address when serverName is not provided
  119. if isClient && serverName == "" && nextConn.RemoteAddr() != nil {
  120. remoteAddr := nextConn.RemoteAddr().String()
  121. var host string
  122. host, _, err = net.SplitHostPort(remoteAddr)
  123. if err != nil {
  124. serverName = remoteAddr
  125. } else {
  126. serverName = host
  127. }
  128. }
  129. hsCfg := &handshakeConfig{
  130. localPSKCallback: config.PSK,
  131. localPSKIdentityHint: config.PSKIdentityHint,
  132. localCiscoCompatCallback: config.CiscoCompat,
  133. localCipherSuites: cipherSuites,
  134. localSignatureSchemes: signatureSchemes,
  135. extendedMasterSecret: config.ExtendedMasterSecret,
  136. localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
  137. serverName: serverName,
  138. clientAuth: config.ClientAuth,
  139. localCertificates: config.Certificates,
  140. insecureSkipVerify: config.InsecureSkipVerify,
  141. verifyPeerCertificate: config.VerifyPeerCertificate,
  142. rootCAs: config.RootCAs,
  143. clientCAs: config.ClientCAs,
  144. customCipherSuites: config.CustomCipherSuites,
  145. retransmitInterval: workerInterval,
  146. log: logger,
  147. initialEpoch: 0,
  148. keyLogWriter: config.KeyLogWriter,
  149. }
  150. var initialFlight flightVal
  151. var initialFSMState handshakeState
  152. if initialState != nil {
  153. if c.state.isClient {
  154. initialFlight = flight5
  155. } else {
  156. initialFlight = flight6
  157. }
  158. initialFSMState = handshakeFinished
  159. c.state = *initialState
  160. } else {
  161. if c.state.isClient {
  162. initialFlight = flight1
  163. } else {
  164. initialFlight = flight0
  165. }
  166. initialFSMState = handshakePreparing
  167. }
  168. // Do handshake
  169. if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
  170. return nil, err
  171. }
  172. c.log.Trace("Handshake Completed")
  173. return c, nil
  174. }
  175. // Dial connects to the given network address and establishes a DTLS connection on top.
  176. // Connection handshake will timeout using ConnectContextMaker in the Config.
  177. // If you want to specify the timeout duration, use DialWithContext() instead.
  178. func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
  179. ctx, cancel := config.connectContextMaker()
  180. defer cancel()
  181. return DialWithContext(ctx, network, raddr, config)
  182. }
  183. // Client establishes a DTLS connection over an existing connection.
  184. // Connection handshake will timeout using ConnectContextMaker in the Config.
  185. // If you want to specify the timeout duration, use ClientWithContext() instead.
  186. func Client(conn net.Conn, config *Config) (*Conn, error) {
  187. ctx, cancel := config.connectContextMaker()
  188. defer cancel()
  189. return ClientWithContext(ctx, conn, config)
  190. }
  191. // Server listens for incoming DTLS connections.
  192. // Connection handshake will timeout using ConnectContextMaker in the Config.
  193. // If you want to specify the timeout duration, use ServerWithContext() instead.
  194. func Server(conn net.Conn, config *Config) (*Conn, error) {
  195. ctx, cancel := config.connectContextMaker()
  196. defer cancel()
  197. return ServerWithContext(ctx, conn, config)
  198. }
  199. // DialWithContext connects to the given network address and establishes a DTLS connection on top.
  200. func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
  201. pConn, err := net.DialUDP(network, nil, raddr)
  202. if err != nil {
  203. return nil, err
  204. }
  205. return ClientWithContext(ctx, pConn, config)
  206. }
  207. // ClientWithContext establishes a DTLS connection over an existing connection.
  208. func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
  209. switch {
  210. case config == nil:
  211. return nil, errNoConfigProvided
  212. case config.PSK != nil && config.PSKIdentityHint == nil:
  213. return nil, errPSKAndIdentityMustBeSetForClient
  214. }
  215. return createConn(ctx, conn, config, true, nil)
  216. }
  217. // ServerWithContext listens for incoming DTLS connections.
  218. func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
  219. if config == nil {
  220. return nil, errNoConfigProvided
  221. }
  222. return createConn(ctx, conn, config, false, nil)
  223. }
  224. // Read reads data from the connection.
  225. func (c *Conn) Read(p []byte) (n int, err error) {
  226. if !c.isHandshakeCompletedSuccessfully() {
  227. return 0, errHandshakeInProgress
  228. }
  229. select {
  230. case <-c.readDeadline.Done():
  231. return 0, errDeadlineExceeded
  232. default:
  233. }
  234. for {
  235. select {
  236. case <-c.readDeadline.Done():
  237. return 0, errDeadlineExceeded
  238. case out, ok := <-c.decrypted:
  239. if !ok {
  240. return 0, io.EOF
  241. }
  242. switch val := out.(type) {
  243. case ([]byte):
  244. if len(p) < len(val) {
  245. return 0, errBufferTooSmall
  246. }
  247. copy(p, val)
  248. return len(val), nil
  249. case (error):
  250. return 0, val
  251. }
  252. }
  253. }
  254. }
  255. // Write writes len(p) bytes from p to the DTLS connection
  256. func (c *Conn) Write(p []byte) (int, error) {
  257. if c.isConnectionClosed() {
  258. return 0, ErrConnClosed
  259. }
  260. select {
  261. case <-c.writeDeadline.Done():
  262. return 0, errDeadlineExceeded
  263. default:
  264. }
  265. if !c.isHandshakeCompletedSuccessfully() {
  266. return 0, errHandshakeInProgress
  267. }
  268. return len(p), c.writePackets(c.writeDeadline, []*packet{
  269. {
  270. record: &recordlayer.RecordLayer{
  271. Header: recordlayer.Header{
  272. Epoch: c.getLocalEpoch(),
  273. Version: protocol.Version1_2,
  274. },
  275. Content: &protocol.ApplicationData{
  276. Data: p,
  277. },
  278. },
  279. shouldEncrypt: true,
  280. },
  281. })
  282. }
  283. // Close closes the connection.
  284. func (c *Conn) Close() error {
  285. err := c.close(true)
  286. c.handshakeLoopsFinished.Wait()
  287. return err
  288. }
  289. // ConnectionState returns basic DTLS details about the connection.
  290. // Note that this replaced the `Export` function of v1.
  291. func (c *Conn) ConnectionState() State {
  292. c.lock.RLock()
  293. defer c.lock.RUnlock()
  294. return *c.state.clone()
  295. }
  296. // SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile
  297. func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
  298. c.lock.RLock()
  299. defer c.lock.RUnlock()
  300. if c.state.srtpProtectionProfile == 0 {
  301. return 0, false
  302. }
  303. return c.state.srtpProtectionProfile, true
  304. }
  305. func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
  306. c.lock.Lock()
  307. defer c.lock.Unlock()
  308. var rawPackets [][]byte
  309. for _, p := range pkts {
  310. if h, ok := p.record.Content.(*handshake.Handshake); ok {
  311. handshakeRaw, err := p.record.Marshal()
  312. if err != nil {
  313. return err
  314. }
  315. c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)",
  316. srvCliStr(c.state.isClient), h.Header.Type.String(),
  317. p.record.Header.Epoch, h.Header.MessageSequence)
  318. c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
  319. rawHandshakePackets, err := c.processHandshakePacket(p, h)
  320. if err != nil {
  321. return err
  322. }
  323. rawPackets = append(rawPackets, rawHandshakePackets...)
  324. } else {
  325. rawPacket, err := c.processPacket(p)
  326. if err != nil {
  327. return err
  328. }
  329. rawPackets = append(rawPackets, rawPacket)
  330. }
  331. }
  332. if len(rawPackets) == 0 {
  333. return nil
  334. }
  335. compactedRawPackets := c.compactRawPackets(rawPackets)
  336. for _, compactedRawPackets := range compactedRawPackets {
  337. if _, err := c.nextConn.WriteContext(ctx, compactedRawPackets); err != nil {
  338. return netError(err)
  339. }
  340. }
  341. return nil
  342. }
  343. func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
  344. combinedRawPackets := make([][]byte, 0)
  345. currentCombinedRawPacket := make([]byte, 0)
  346. for _, rawPacket := range rawPackets {
  347. if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit {
  348. combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
  349. currentCombinedRawPacket = []byte{}
  350. }
  351. currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...)
  352. }
  353. combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
  354. return combinedRawPackets
  355. }
  356. func (c *Conn) processPacket(p *packet) ([]byte, error) {
  357. epoch := p.record.Header.Epoch
  358. for len(c.state.localSequenceNumber) <= int(epoch) {
  359. c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
  360. }
  361. seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
  362. if seq > recordlayer.MaxSequenceNumber {
  363. // RFC 6347 Section 4.1.0
  364. // The implementation must either abandon an association or rehandshake
  365. // prior to allowing the sequence number to wrap.
  366. return nil, errSequenceNumberOverflow
  367. }
  368. p.record.Header.SequenceNumber = seq
  369. rawPacket, err := p.record.Marshal()
  370. if err != nil {
  371. return nil, err
  372. }
  373. if p.shouldEncrypt {
  374. var err error
  375. rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
  376. if err != nil {
  377. return nil, err
  378. }
  379. }
  380. return rawPacket, nil
  381. }
  382. func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) {
  383. rawPackets := make([][]byte, 0)
  384. handshakeFragments, err := c.fragmentHandshake(h)
  385. if err != nil {
  386. return nil, err
  387. }
  388. epoch := p.record.Header.Epoch
  389. for len(c.state.localSequenceNumber) <= int(epoch) {
  390. c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
  391. }
  392. for _, handshakeFragment := range handshakeFragments {
  393. seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
  394. if seq > recordlayer.MaxSequenceNumber {
  395. return nil, errSequenceNumberOverflow
  396. }
  397. recordlayerHeader := &recordlayer.Header{
  398. Version: p.record.Header.Version,
  399. ContentType: p.record.Header.ContentType,
  400. ContentLen: uint16(len(handshakeFragment)),
  401. Epoch: p.record.Header.Epoch,
  402. SequenceNumber: seq,
  403. }
  404. recordlayerHeaderBytes, err := recordlayerHeader.Marshal()
  405. if err != nil {
  406. return nil, err
  407. }
  408. p.record.Header = *recordlayerHeader
  409. rawPacket := append(recordlayerHeaderBytes, handshakeFragment...)
  410. if p.shouldEncrypt {
  411. var err error
  412. rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
  413. if err != nil {
  414. return nil, err
  415. }
  416. }
  417. rawPackets = append(rawPackets, rawPacket)
  418. }
  419. return rawPackets, nil
  420. }
  421. func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) {
  422. content, err := h.Message.Marshal()
  423. if err != nil {
  424. return nil, err
  425. }
  426. fragmentedHandshakes := make([][]byte, 0)
  427. contentFragments := splitBytes(content, c.maximumTransmissionUnit)
  428. if len(contentFragments) == 0 {
  429. contentFragments = [][]byte{
  430. {},
  431. }
  432. }
  433. offset := 0
  434. for _, contentFragment := range contentFragments {
  435. contentFragmentLen := len(contentFragment)
  436. headerFragment := &handshake.Header{
  437. Type: h.Header.Type,
  438. Length: h.Header.Length,
  439. MessageSequence: h.Header.MessageSequence,
  440. FragmentOffset: uint32(offset),
  441. FragmentLength: uint32(contentFragmentLen),
  442. }
  443. offset += contentFragmentLen
  444. headerFragmentRaw, err := headerFragment.Marshal()
  445. if err != nil {
  446. return nil, err
  447. }
  448. fragmentedHandshake := append(headerFragmentRaw, contentFragment...)
  449. fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake)
  450. }
  451. return fragmentedHandshakes, nil
  452. }
  453. var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
  454. New: func() interface{} {
  455. b := make([]byte, inboundBufferSize)
  456. return &b
  457. },
  458. }
  459. func (c *Conn) readAndBuffer(ctx context.Context) error {
  460. bufptr := poolReadBuffer.Get().(*[]byte)
  461. defer poolReadBuffer.Put(bufptr)
  462. b := *bufptr
  463. i, err := c.nextConn.ReadContext(ctx, b)
  464. if err != nil {
  465. return netError(err)
  466. }
  467. pkts, err := recordlayer.UnpackDatagram(b[:i])
  468. if err != nil {
  469. return err
  470. }
  471. var hasHandshake bool
  472. for _, p := range pkts {
  473. hs, alert, err := c.handleIncomingPacket(p, true)
  474. if alert != nil {
  475. if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
  476. if err == nil {
  477. err = alertErr
  478. }
  479. }
  480. }
  481. if hs {
  482. hasHandshake = true
  483. }
  484. switch e := err.(type) {
  485. case nil:
  486. case *errAlert:
  487. if e.IsFatalOrCloseNotify() {
  488. return e
  489. }
  490. default:
  491. return e
  492. }
  493. }
  494. if hasHandshake {
  495. done := make(chan struct{})
  496. select {
  497. case c.handshakeRecv <- done:
  498. // If the other party may retransmit the flight,
  499. // we should respond even if it not a new message.
  500. <-done
  501. case <-c.fsm.Done():
  502. }
  503. }
  504. return nil
  505. }
  506. func (c *Conn) handleQueuedPackets(ctx context.Context) error {
  507. pkts := c.encryptedPackets
  508. c.encryptedPackets = nil
  509. for _, p := range pkts {
  510. _, alert, err := c.handleIncomingPacket(p, false) // don't re-enqueue
  511. if alert != nil {
  512. if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
  513. if err == nil {
  514. err = alertErr
  515. }
  516. }
  517. }
  518. switch e := err.(type) {
  519. case nil:
  520. case *errAlert:
  521. if e.IsFatalOrCloseNotify() {
  522. return e
  523. }
  524. default:
  525. return e
  526. }
  527. }
  528. return nil
  529. }
  530. func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
  531. h := &recordlayer.Header{}
  532. if err := h.Unmarshal(buf); err != nil {
  533. // Decode error must be silently discarded
  534. // [RFC6347 Section-4.1.2.7]
  535. c.log.Debugf("discarded broken packet: %v", err)
  536. return false, nil, nil
  537. }
  538. // Validate epoch
  539. remoteEpoch := c.getRemoteEpoch()
  540. if h.Epoch > remoteEpoch {
  541. if h.Epoch > remoteEpoch+1 {
  542. c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
  543. h.Epoch, h.SequenceNumber,
  544. )
  545. return false, nil, nil
  546. }
  547. if enqueue {
  548. c.log.Debug("received packet of next epoch, queuing packet")
  549. c.encryptedPackets = append(c.encryptedPackets, buf)
  550. }
  551. return false, nil, nil
  552. }
  553. // Anti-replay protection
  554. for len(c.state.replayDetector) <= int(h.Epoch) {
  555. c.state.replayDetector = append(c.state.replayDetector,
  556. replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber),
  557. )
  558. }
  559. markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber)
  560. if !ok {
  561. c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
  562. h.Epoch, h.SequenceNumber,
  563. )
  564. return false, nil, nil
  565. }
  566. // Decrypt
  567. if h.Epoch != 0 {
  568. if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
  569. if enqueue {
  570. c.encryptedPackets = append(c.encryptedPackets, buf)
  571. c.log.Debug("handshake not finished, queuing packet")
  572. }
  573. return false, nil, nil
  574. }
  575. var err error
  576. buf, err = c.state.cipherSuite.Decrypt(buf)
  577. if err != nil {
  578. c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
  579. return false, nil, nil
  580. }
  581. }
  582. isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...))
  583. if err != nil {
  584. // Decode error must be silently discarded
  585. // [RFC6347 Section-4.1.2.7]
  586. c.log.Debugf("defragment failed: %s", err)
  587. return false, nil, nil
  588. } else if isHandshake {
  589. markPacketAsValid()
  590. for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
  591. rawHandshake := &handshake.Handshake{}
  592. if err := rawHandshake.Unmarshal(out); err != nil {
  593. c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
  594. continue
  595. }
  596. _ = c.handshakeCache.push(out, epoch, rawHandshake.Header.MessageSequence, rawHandshake.Header.Type, !c.state.isClient)
  597. }
  598. return true, nil, nil
  599. }
  600. r := &recordlayer.RecordLayer{}
  601. if err := r.Unmarshal(buf); err != nil {
  602. return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err
  603. }
  604. switch content := r.Content.(type) {
  605. case *alert.Alert:
  606. c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String())
  607. var a *alert.Alert
  608. if content.Description == alert.CloseNotify {
  609. // Respond with a close_notify [RFC5246 Section 7.2.1]
  610. a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
  611. }
  612. markPacketAsValid()
  613. return false, a, &errAlert{content}
  614. case *protocol.ChangeCipherSpec:
  615. if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
  616. if enqueue {
  617. c.encryptedPackets = append(c.encryptedPackets, buf)
  618. c.log.Debugf("CipherSuite not initialized, queuing packet")
  619. }
  620. return false, nil, nil
  621. }
  622. newRemoteEpoch := h.Epoch + 1
  623. c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)
  624. if c.getRemoteEpoch()+1 == newRemoteEpoch {
  625. c.setRemoteEpoch(newRemoteEpoch)
  626. markPacketAsValid()
  627. }
  628. case *protocol.ApplicationData:
  629. if h.Epoch == 0 {
  630. return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero
  631. }
  632. markPacketAsValid()
  633. select {
  634. case c.decrypted <- content.Data:
  635. case <-c.closed.Done():
  636. }
  637. default:
  638. return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())
  639. }
  640. return false, nil, nil
  641. }
  642. func (c *Conn) recvHandshake() <-chan chan struct{} {
  643. return c.handshakeRecv
  644. }
  645. func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
  646. return c.writePackets(ctx, []*packet{
  647. {
  648. record: &recordlayer.RecordLayer{
  649. Header: recordlayer.Header{
  650. Epoch: c.getLocalEpoch(),
  651. Version: protocol.Version1_2,
  652. },
  653. Content: &alert.Alert{
  654. Level: level,
  655. Description: desc,
  656. },
  657. },
  658. shouldEncrypt: c.isHandshakeCompletedSuccessfully(),
  659. },
  660. })
  661. }
  662. func (c *Conn) setHandshakeCompletedSuccessfully() {
  663. c.handshakeCompletedSuccessfully.Store(struct{ bool }{true})
  664. }
  665. func (c *Conn) isHandshakeCompletedSuccessfully() bool {
  666. boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool })
  667. return boolean.bool
  668. }
  669. func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit
  670. c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight)
  671. done := make(chan struct{})
  672. ctxRead, cancelRead := context.WithCancel(context.Background())
  673. c.cancelHandshakeReader = cancelRead
  674. cfg.onFlightState = func(f flightVal, s handshakeState) {
  675. if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
  676. c.setHandshakeCompletedSuccessfully()
  677. close(done)
  678. }
  679. }
  680. ctxHs, cancel := context.WithCancel(context.Background())
  681. c.cancelHandshaker = cancel
  682. firstErr := make(chan error, 1)
  683. c.handshakeLoopsFinished.Add(2)
  684. // Handshake routine should be live until close.
  685. // The other party may request retransmission of the last flight to cope with packet drop.
  686. go func() {
  687. defer c.handshakeLoopsFinished.Done()
  688. err := c.fsm.Run(ctxHs, c, initialState)
  689. if !errors.Is(err, context.Canceled) {
  690. select {
  691. case firstErr <- err:
  692. default:
  693. }
  694. }
  695. }()
  696. go func() {
  697. defer func() {
  698. // Escaping read loop.
  699. // It's safe to close decrypted channnel now.
  700. close(c.decrypted)
  701. // Force stop handshaker when the underlying connection is closed.
  702. cancel()
  703. }()
  704. defer c.handshakeLoopsFinished.Done()
  705. for {
  706. if err := c.readAndBuffer(ctxRead); err != nil {
  707. switch e := err.(type) {
  708. case *errAlert:
  709. if !e.IsFatalOrCloseNotify() {
  710. if c.isHandshakeCompletedSuccessfully() {
  711. // Pass the error to Read()
  712. select {
  713. case c.decrypted <- err:
  714. case <-c.closed.Done():
  715. }
  716. }
  717. continue // non-fatal alert must not stop read loop
  718. }
  719. case error:
  720. switch err {
  721. case context.DeadlineExceeded, context.Canceled, io.EOF:
  722. default:
  723. if c.isHandshakeCompletedSuccessfully() {
  724. // Keep read loop and pass the read error to Read()
  725. select {
  726. case c.decrypted <- err:
  727. case <-c.closed.Done():
  728. }
  729. continue // non-fatal alert must not stop read loop
  730. }
  731. }
  732. }
  733. select {
  734. case firstErr <- err:
  735. default:
  736. }
  737. if e, ok := err.(*errAlert); ok {
  738. if e.IsFatalOrCloseNotify() {
  739. _ = c.close(false)
  740. }
  741. }
  742. return
  743. }
  744. }
  745. }()
  746. select {
  747. case err := <-firstErr:
  748. cancelRead()
  749. cancel()
  750. return c.translateHandshakeCtxError(err)
  751. case <-ctx.Done():
  752. cancelRead()
  753. cancel()
  754. return c.translateHandshakeCtxError(ctx.Err())
  755. case <-done:
  756. return nil
  757. }
  758. }
  759. func (c *Conn) translateHandshakeCtxError(err error) error {
  760. if err == nil {
  761. return nil
  762. }
  763. if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() {
  764. return nil
  765. }
  766. return &HandshakeError{Err: err}
  767. }
  768. func (c *Conn) close(byUser bool) error {
  769. c.cancelHandshaker()
  770. c.cancelHandshakeReader()
  771. if c.isHandshakeCompletedSuccessfully() && byUser {
  772. // Discard error from notify() to return non-error on the first user call of Close()
  773. // even if the underlying connection is already closed.
  774. _ = c.notify(context.Background(), alert.Warning, alert.CloseNotify)
  775. }
  776. c.closeLock.Lock()
  777. // Don't return ErrConnClosed at the first time of the call from user.
  778. closedByUser := c.connectionClosedByUser
  779. if byUser {
  780. c.connectionClosedByUser = true
  781. }
  782. c.closed.Close()
  783. c.closeLock.Unlock()
  784. if closedByUser {
  785. return ErrConnClosed
  786. }
  787. return c.nextConn.Close()
  788. }
  789. func (c *Conn) isConnectionClosed() bool {
  790. select {
  791. case <-c.closed.Done():
  792. return true
  793. default:
  794. return false
  795. }
  796. }
  797. func (c *Conn) setLocalEpoch(epoch uint16) {
  798. c.state.localEpoch.Store(epoch)
  799. }
  800. func (c *Conn) getLocalEpoch() uint16 {
  801. return c.state.localEpoch.Load().(uint16)
  802. }
  803. func (c *Conn) setRemoteEpoch(epoch uint16) {
  804. c.state.remoteEpoch.Store(epoch)
  805. }
  806. func (c *Conn) getRemoteEpoch() uint16 {
  807. return c.state.remoteEpoch.Load().(uint16)
  808. }
  809. // LocalAddr implements net.Conn.LocalAddr
  810. func (c *Conn) LocalAddr() net.Addr {
  811. return c.nextConn.LocalAddr()
  812. }
  813. // RemoteAddr implements net.Conn.RemoteAddr
  814. func (c *Conn) RemoteAddr() net.Addr {
  815. return c.nextConn.RemoteAddr()
  816. }
  817. // SetDeadline implements net.Conn.SetDeadline
  818. func (c *Conn) SetDeadline(t time.Time) error {
  819. c.readDeadline.Set(t)
  820. return c.SetWriteDeadline(t)
  821. }
  822. // SetReadDeadline implements net.Conn.SetReadDeadline
  823. func (c *Conn) SetReadDeadline(t time.Time) error {
  824. c.readDeadline.Set(t)
  825. // Read deadline is fully managed by this layer.
  826. // Don't set read deadline to underlying connection.
  827. return nil
  828. }
  829. // SetWriteDeadline implements net.Conn.SetWriteDeadline
  830. func (c *Conn) SetWriteDeadline(t time.Time) error {
  831. c.writeDeadline.Set(t)
  832. // Write deadline is also fully managed by this layer.
  833. return nil
  834. }