ktls.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. //go:build linux && go1.25 && badlinkname
  2. package ktls
  3. import (
  4. "bytes"
  5. "context"
  6. "crypto/tls"
  7. "errors"
  8. "io"
  9. "net"
  10. "os"
  11. "syscall"
  12. "github.com/sagernet/sing-box/common/badtls"
  13. E "github.com/sagernet/sing/common/exceptions"
  14. "github.com/sagernet/sing/common/logger"
  15. N "github.com/sagernet/sing/common/network"
  16. aTLS "github.com/sagernet/sing/common/tls"
  17. "golang.org/x/sys/unix"
  18. )
  19. type Conn struct {
  20. aTLS.Conn
  21. ctx context.Context
  22. logger logger.ContextLogger
  23. conn net.Conn
  24. rawConn *badtls.RawConn
  25. syscallConn syscall.Conn
  26. rawSyscallConn syscall.RawConn
  27. readWaitOptions N.ReadWaitOptions
  28. kernelTx bool
  29. kernelRx bool
  30. pendingRxSplice bool
  31. }
  32. func NewConn(ctx context.Context, logger logger.ContextLogger, conn aTLS.Conn, txOffload, rxOffload bool) (aTLS.Conn, error) {
  33. err := Load()
  34. if err != nil {
  35. return nil, err
  36. }
  37. syscallConn, isSyscallConn := N.CastReader[interface {
  38. io.Reader
  39. syscall.Conn
  40. }](conn.NetConn())
  41. if !isSyscallConn {
  42. return nil, os.ErrInvalid
  43. }
  44. rawSyscallConn, err := syscallConn.SyscallConn()
  45. if err != nil {
  46. return nil, err
  47. }
  48. rawConn, err := badtls.NewRawConn(conn)
  49. if err != nil {
  50. return nil, err
  51. }
  52. if *rawConn.Vers != tls.VersionTLS13 {
  53. return nil, os.ErrInvalid
  54. }
  55. for rawConn.RawInput.Len() > 0 {
  56. err = rawConn.ReadRecord()
  57. if err != nil {
  58. return nil, err
  59. }
  60. for rawConn.Hand.Len() > 0 {
  61. err = rawConn.HandlePostHandshakeMessage()
  62. if err != nil {
  63. return nil, E.Cause(err, "handle post-handshake messages")
  64. }
  65. }
  66. }
  67. kConn := &Conn{
  68. Conn: conn,
  69. ctx: ctx,
  70. logger: logger,
  71. conn: conn.NetConn(),
  72. rawConn: rawConn,
  73. syscallConn: syscallConn,
  74. rawSyscallConn: rawSyscallConn,
  75. }
  76. err = kConn.setupKernel(txOffload, rxOffload)
  77. if err != nil {
  78. return nil, err
  79. }
  80. return kConn, nil
  81. }
  82. func (c *Conn) Upstream() any {
  83. return c.Conn
  84. }
  85. func (c *Conn) SyscallConnForRead() syscall.RawConn {
  86. if !c.kernelRx {
  87. return nil
  88. }
  89. if !*c.rawConn.IsClient {
  90. c.logger.WarnContext(c.ctx, "ktls: RX splice is unavailable on the server size, since it will cause an unknown failure")
  91. return nil
  92. }
  93. c.logger.DebugContext(c.ctx, "ktls: RX splice requested")
  94. return c.rawSyscallConn
  95. }
  96. func (c *Conn) HandleSyscallReadError(inputErr error) ([]byte, error) {
  97. if errors.Is(inputErr, unix.EINVAL) {
  98. c.pendingRxSplice = true
  99. err := c.readRecord()
  100. if err != nil {
  101. return nil, E.Cause(err, "ktls: handle non-application-data record")
  102. }
  103. var input bytes.Buffer
  104. if c.rawConn.Input.Len() > 0 {
  105. _, err = c.rawConn.Input.WriteTo(&input)
  106. if err != nil {
  107. return nil, err
  108. }
  109. }
  110. return input.Bytes(), nil
  111. } else if errors.Is(inputErr, unix.EBADMSG) {
  112. return nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertBadRecordMAC))
  113. } else {
  114. return nil, E.Cause(inputErr, "ktls: unexpected errno")
  115. }
  116. }
  117. func (c *Conn) SyscallConnForWrite() syscall.RawConn {
  118. if !c.kernelTx {
  119. return nil
  120. }
  121. c.logger.DebugContext(c.ctx, "ktls: TX splice requested")
  122. return c.rawSyscallConn
  123. }