ktls_key_update.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. //go:build linux && go1.25 && badlinkname
  5. package ktls
  6. import (
  7. "crypto/tls"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "os"
  12. )
  13. // handlePostHandshakeMessage processes a handshake message arrived after the
  14. // handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation.
  15. func (c *Conn) handlePostHandshakeMessage() error {
  16. if *c.rawConn.Vers != tls.VersionTLS13 {
  17. return errors.New("ktls: kernel does not support TLS 1.2 renegotiation")
  18. }
  19. msg, err := c.readHandshake(nil)
  20. if err != nil {
  21. return err
  22. }
  23. //c.retryCount++
  24. //if c.retryCount > maxUselessRecords {
  25. // c.sendAlert(alertUnexpectedMessage)
  26. // return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
  27. //}
  28. switch msg := msg.(type) {
  29. case *newSessionTicketMsgTLS13:
  30. // return errors.New("ktls: received new session ticket")
  31. return nil
  32. case *keyUpdateMsg:
  33. return c.handleKeyUpdate(msg)
  34. }
  35. // The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest
  36. // as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an
  37. // unexpected_message alert here doesn't provide it with enough information to distinguish
  38. // this condition from other unexpected messages. This is probably fine.
  39. c.sendAlert(alertUnexpectedMessage)
  40. return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
  41. }
  42. func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
  43. //if c.quic != nil {
  44. // c.sendAlert(alertUnexpectedMessage)
  45. // return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
  46. //}
  47. cipherSuite := cipherSuiteTLS13ByID(*c.rawConn.CipherSuite)
  48. if cipherSuite == nil {
  49. return c.rawConn.In.SetErrorLocked(c.sendAlert(alertInternalError))
  50. }
  51. newSecret := nextTrafficSecret(cipherSuite, *c.rawConn.In.TrafficSecret)
  52. c.rawConn.In.SetTrafficSecret(cipherSuite, 0 /*tls.QUICEncryptionLevelInitial*/, newSecret)
  53. err := c.resetupRX()
  54. if err != nil {
  55. c.sendAlert(alertInternalError)
  56. return c.rawConn.In.SetErrorLocked(fmt.Errorf("ktls: resetupRX failed: %w", err))
  57. }
  58. if keyUpdate.updateRequested {
  59. c.rawConn.Out.Lock()
  60. defer c.rawConn.Out.Unlock()
  61. resetup, err := c.resetupTX()
  62. if err != nil {
  63. c.sendAlertLocked(alertInternalError)
  64. return c.rawConn.Out.SetErrorLocked(fmt.Errorf("ktls: resetupTX failed: %w", err))
  65. }
  66. msg := &keyUpdateMsg{}
  67. msgBytes, err := msg.marshal()
  68. if err != nil {
  69. return err
  70. }
  71. _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
  72. if err != nil {
  73. // Surface the error at the next write.
  74. c.rawConn.Out.SetErrorLocked(err)
  75. return nil
  76. }
  77. newSecret := nextTrafficSecret(cipherSuite, *c.rawConn.Out.TrafficSecret)
  78. c.rawConn.Out.SetTrafficSecret(cipherSuite, 0 /*QUICEncryptionLevelInitial*/, newSecret)
  79. err = resetup()
  80. if err != nil {
  81. return c.rawConn.Out.SetErrorLocked(fmt.Errorf("ktls: resetupTX failed: %w", err))
  82. }
  83. }
  84. return nil
  85. }
  86. func (c *Conn) readHandshakeBytes(n int) error {
  87. //if c.quic != nil {
  88. // return c.quicReadHandshakeBytes(n)
  89. //}
  90. for c.rawConn.Hand.Len() < n {
  91. if err := c.readRecord(); err != nil {
  92. return err
  93. }
  94. }
  95. return nil
  96. }
  97. func (c *Conn) readHandshake(transcript io.Writer) (any, error) {
  98. if err := c.readHandshakeBytes(4); err != nil {
  99. return nil, err
  100. }
  101. data := c.rawConn.Hand.Bytes()
  102. maxHandshakeSize := maxHandshake
  103. // hasVers indicates we're past the first message, forcing someone trying to
  104. // make us just allocate a large buffer to at least do the initial part of
  105. // the handshake first.
  106. //if c.haveVers && data[0] == typeCertificate {
  107. // Since certificate messages are likely to be the only messages that
  108. // can be larger than maxHandshake, we use a special limit for just
  109. // those messages.
  110. //maxHandshakeSize = maxHandshakeCertificateMsg
  111. //}
  112. n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
  113. if n > maxHandshakeSize {
  114. c.sendAlertLocked(alertInternalError)
  115. return nil, c.rawConn.In.SetErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshakeSize))
  116. }
  117. if err := c.readHandshakeBytes(4 + n); err != nil {
  118. return nil, err
  119. }
  120. data = c.rawConn.Hand.Next(4 + n)
  121. return c.unmarshalHandshakeMessage(data, transcript)
  122. }
  123. func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript io.Writer) (any, error) {
  124. var m handshakeMessage
  125. switch data[0] {
  126. case typeNewSessionTicket:
  127. if *c.rawConn.Vers == tls.VersionTLS13 {
  128. m = new(newSessionTicketMsgTLS13)
  129. } else {
  130. return nil, os.ErrInvalid
  131. }
  132. case typeKeyUpdate:
  133. m = new(keyUpdateMsg)
  134. default:
  135. return nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
  136. }
  137. // The handshake message unmarshalers
  138. // expect to be able to keep references to data,
  139. // so pass in a fresh copy that won't be overwritten.
  140. data = append([]byte(nil), data...)
  141. if !m.unmarshal(data) {
  142. return nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertDecodeError))
  143. }
  144. if transcript != nil {
  145. transcript.Write(data)
  146. }
  147. return m, nil
  148. }