ktls.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. //go:build linux && go1.25 && !without_badtls
  2. package ktls
  3. import (
  4. "bytes"
  5. "context"
  6. "crypto/tls"
  7. "errors"
  8. "io"
  9. "net"
  10. "os"
  11. "syscall"
  12. "github.com/sagernet/sing-box/common/badtls"
  13. E "github.com/sagernet/sing/common/exceptions"
  14. "github.com/sagernet/sing/common/logger"
  15. N "github.com/sagernet/sing/common/network"
  16. aTLS "github.com/sagernet/sing/common/tls"
  17. "golang.org/x/sys/unix"
  18. )
  19. type Conn struct {
  20. aTLS.Conn
  21. ctx context.Context
  22. logger logger.ContextLogger
  23. conn net.Conn
  24. rawConn *badtls.RawConn
  25. syscallConn syscall.Conn
  26. rawSyscallConn syscall.RawConn
  27. readWaitOptions N.ReadWaitOptions
  28. kernelTx bool
  29. kernelRx bool
  30. }
  31. func NewConn(ctx context.Context, logger logger.ContextLogger, conn aTLS.Conn, txOffload, rxOffload bool) (aTLS.Conn, error) {
  32. err := Load()
  33. if err != nil {
  34. return nil, err
  35. }
  36. syscallConn, isSyscallConn := N.CastReader[interface {
  37. io.Reader
  38. syscall.Conn
  39. }](conn.NetConn())
  40. if !isSyscallConn {
  41. return nil, os.ErrInvalid
  42. }
  43. rawSyscallConn, err := syscallConn.SyscallConn()
  44. if err != nil {
  45. return nil, err
  46. }
  47. rawConn, err := badtls.NewRawConn(conn)
  48. if err != nil {
  49. return nil, err
  50. }
  51. if *rawConn.Vers != tls.VersionTLS13 {
  52. return nil, os.ErrInvalid
  53. }
  54. for rawConn.RawInput.Len() > 0 {
  55. err = rawConn.ReadRecord()
  56. if err != nil {
  57. return nil, err
  58. }
  59. for rawConn.Hand.Len() > 0 {
  60. err = rawConn.HandlePostHandshakeMessage()
  61. if err != nil {
  62. return nil, E.Cause(err, "handle post-handshake messages")
  63. }
  64. }
  65. }
  66. kConn := &Conn{
  67. Conn: conn,
  68. ctx: ctx,
  69. logger: logger,
  70. conn: conn.NetConn(),
  71. rawConn: rawConn,
  72. syscallConn: syscallConn,
  73. rawSyscallConn: rawSyscallConn,
  74. }
  75. err = kConn.setupKernel(txOffload, rxOffload)
  76. if err != nil {
  77. return nil, err
  78. }
  79. return kConn, nil
  80. }
  81. func (c *Conn) Upstream() any {
  82. return c.Conn
  83. }
  84. func (c *Conn) SyscallConnForRead() syscall.RawConn {
  85. if !c.kernelRx {
  86. return nil
  87. }
  88. if !*c.rawConn.IsClient {
  89. c.logger.WarnContext(c.ctx, "ktls: RX splice is unavailable on the server size, since it will cause an unknown failure")
  90. return nil
  91. }
  92. c.logger.DebugContext(c.ctx, "ktls: RX splice requested")
  93. return c.rawSyscallConn
  94. }
  95. func (c *Conn) HandleSyscallReadError(inputErr error) ([]byte, error) {
  96. if errors.Is(inputErr, unix.EINVAL) {
  97. err := c.readRecord()
  98. if err != nil {
  99. return nil, E.Cause(err, "ktls: handle non-application-data record")
  100. }
  101. var input bytes.Buffer
  102. if c.rawConn.Input.Len() > 0 {
  103. _, err = c.rawConn.Input.WriteTo(&input)
  104. if err != nil {
  105. return nil, err
  106. }
  107. }
  108. return input.Bytes(), nil
  109. } else if errors.Is(inputErr, unix.EBADMSG) {
  110. return nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertBadRecordMAC))
  111. } else {
  112. return nil, E.Cause(inputErr, "ktls: unexpected errno")
  113. }
  114. }
  115. func (c *Conn) SyscallConnForWrite() syscall.RawConn {
  116. if !c.kernelTx {
  117. return nil
  118. }
  119. c.logger.DebugContext(c.ctx, "ktls: TX splice requested")
  120. return c.rawSyscallConn
  121. }