ktls_linux.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. //go:build linux && go1.25 && badlinkname
  2. package ktls
  3. import (
  4. "crypto/tls"
  5. "errors"
  6. "io"
  7. "os"
  8. "strings"
  9. "sync"
  10. "syscall"
  11. "unsafe"
  12. "github.com/sagernet/sing-box/common/badversion"
  13. "github.com/sagernet/sing/common/control"
  14. E "github.com/sagernet/sing/common/exceptions"
  15. "github.com/sagernet/sing/common/shell"
  16. "golang.org/x/sys/unix"
  17. )
// Adapted from https://gitlab.com/go-extension/tls
  19. const (
  20. TLS_TX = 1
  21. TLS_RX = 2
  22. TLS_TX_ZEROCOPY_RO = 3 // TX zerocopy (only sendfile now)
  23. TLS_RX_EXPECT_NO_PAD = 4 // Attempt opportunistic zero-copy, TLS 1.3 only
  24. TLS_SET_RECORD_TYPE = 1
  25. TLS_GET_RECORD_TYPE = 2
  26. )
  27. type Support struct {
  28. TLS, TLS_RX bool
  29. TLS_Version13, TLS_Version13_RX bool
  30. TLS_TX_ZEROCOPY bool
  31. TLS_RX_NOPADDING bool
  32. TLS_AES_256_GCM bool
  33. TLS_AES_128_CCM bool
  34. TLS_CHACHA20_POLY1305 bool
  35. TLS_SM4 bool
  36. TLS_ARIA_GCM bool
  37. TLS_Version13_KeyUpdate bool
  38. }
  39. var KernelSupport = sync.OnceValues(func() (*Support, error) {
  40. var uname unix.Utsname
  41. err := unix.Uname(&uname)
  42. if err != nil {
  43. return nil, err
  44. }
  45. kernelVersion := badversion.Parse(strings.Trim(string(uname.Release[:]), "\x00"))
  46. if err != nil {
  47. return nil, err
  48. }
  49. var support Support
  50. switch {
  51. case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 6, Minor: 14}):
  52. support.TLS_Version13_KeyUpdate = true
  53. fallthrough
  54. case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 6, Minor: 1}):
  55. support.TLS_ARIA_GCM = true
  56. fallthrough
  57. case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 6}):
  58. support.TLS_Version13_RX = true
  59. support.TLS_RX_NOPADDING = true
  60. fallthrough
  61. case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 5, Minor: 19}):
  62. support.TLS_TX_ZEROCOPY = true
  63. fallthrough
  64. case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 5, Minor: 16}):
  65. support.TLS_SM4 = true
  66. fallthrough
  67. case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 5, Minor: 11}):
  68. support.TLS_CHACHA20_POLY1305 = true
  69. fallthrough
  70. case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 5, Minor: 2}):
  71. support.TLS_AES_128_CCM = true
  72. fallthrough
  73. case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 5, Minor: 1}):
  74. support.TLS_AES_256_GCM = true
  75. support.TLS_Version13 = true
  76. fallthrough
  77. case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 4, Minor: 17}):
  78. support.TLS_RX = true
  79. fallthrough
  80. case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 4, Minor: 13}):
  81. support.TLS = true
  82. }
  83. if support.TLS && support.TLS_Version13 {
  84. _, err := os.Stat("/sys/module/tls")
  85. if err != nil {
  86. if os.Getuid() == 0 {
  87. output, err := shell.Exec("modprobe", "tls").Read()
  88. if err != nil {
  89. return nil, E.Extend(E.Cause(err, "modprobe tls"), output)
  90. }
  91. } else {
  92. return nil, E.New("ktls: kernel TLS module not loaded")
  93. }
  94. }
  95. }
  96. return &support, nil
  97. })
  98. func Load() error {
  99. support, err := KernelSupport()
  100. if err != nil {
  101. return E.Cause(err, "ktls: check availability")
  102. }
  103. if !support.TLS || !support.TLS_Version13 {
  104. return E.New("ktls: kernel does not support TLS 1.3")
  105. }
  106. return nil
  107. }
  108. func (c *Conn) setupKernel(txOffload, rxOffload bool) error {
  109. if !txOffload && !rxOffload {
  110. return os.ErrInvalid
  111. }
  112. support, err := KernelSupport()
  113. if err != nil {
  114. return E.Cause(err, "check availability")
  115. }
  116. if !support.TLS || !support.TLS_Version13 {
  117. return E.New("kernel does not support TLS 1.3")
  118. }
  119. c.rawConn.Out.Lock()
  120. defer c.rawConn.Out.Unlock()
  121. err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
  122. return syscall.SetsockoptString(int(fd), unix.SOL_TCP, unix.TCP_ULP, "tls")
  123. })
  124. if err != nil {
  125. return os.NewSyscallError("setsockopt", err)
  126. }
  127. if txOffload {
  128. txCrypto := kernelCipher(support, c.rawConn.Out, *c.rawConn.CipherSuite, false)
  129. if txCrypto == nil {
  130. return E.New("unsupported cipher suite")
  131. }
  132. err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
  133. return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_TX, txCrypto.String())
  134. })
  135. if err != nil {
  136. return err
  137. }
  138. if support.TLS_TX_ZEROCOPY {
  139. err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
  140. return syscall.SetsockoptInt(int(fd), unix.SOL_TLS, TLS_TX_ZEROCOPY_RO, 1)
  141. })
  142. if err != nil {
  143. return err
  144. }
  145. }
  146. c.kernelTx = true
  147. c.logger.DebugContext(c.ctx, "ktls: kernel TLS TX enabled")
  148. }
  149. if rxOffload {
  150. rxCrypto := kernelCipher(support, c.rawConn.In, *c.rawConn.CipherSuite, true)
  151. if rxCrypto == nil {
  152. return E.New("unsupported cipher suite")
  153. }
  154. err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
  155. return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_RX, rxCrypto.String())
  156. })
  157. if err != nil {
  158. return err
  159. }
  160. if *c.rawConn.Vers >= tls.VersionTLS13 && support.TLS_RX_NOPADDING {
  161. err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
  162. return syscall.SetsockoptInt(int(fd), unix.SOL_TLS, TLS_RX_EXPECT_NO_PAD, 1)
  163. })
  164. if err != nil {
  165. return err
  166. }
  167. }
  168. c.kernelRx = true
  169. c.logger.DebugContext(c.ctx, "ktls: kernel TLS RX enabled")
  170. }
  171. return nil
  172. }
  173. func (c *Conn) resetupTX() (func() error, error) {
  174. if !c.kernelTx {
  175. return nil, nil
  176. }
  177. support, err := KernelSupport()
  178. if err != nil {
  179. return nil, err
  180. }
  181. if !support.TLS_Version13_KeyUpdate {
  182. return nil, errors.New("ktls: kernel does not support rekey")
  183. }
  184. txCrypto := kernelCipher(support, c.rawConn.Out, *c.rawConn.CipherSuite, false)
  185. if txCrypto == nil {
  186. return nil, errors.New("ktls: set kernelCipher on unsupported tls session")
  187. }
  188. return func() error {
  189. return control.Raw(c.rawSyscallConn, func(fd uintptr) error {
  190. return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_TX, txCrypto.String())
  191. })
  192. }, nil
  193. }
  194. func (c *Conn) resetupRX() error {
  195. if !c.kernelRx {
  196. return nil
  197. }
  198. support, err := KernelSupport()
  199. if err != nil {
  200. return err
  201. }
  202. if !support.TLS_Version13_KeyUpdate {
  203. return errors.New("ktls: kernel does not support rekey")
  204. }
  205. rxCrypto := kernelCipher(support, c.rawConn.In, *c.rawConn.CipherSuite, true)
  206. if rxCrypto == nil {
  207. return errors.New("ktls: set kernelCipher on unsupported tls session")
  208. }
  209. return control.Raw(c.rawSyscallConn, func(fd uintptr) error {
  210. return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_RX, rxCrypto.String())
  211. })
  212. }
  213. func (c *Conn) readKernelRecord() (uint8, []byte, error) {
  214. if c.rawConn.RawInput.Len() < maxPlaintext {
  215. c.rawConn.RawInput.Grow(maxPlaintext - c.rawConn.RawInput.Len())
  216. }
  217. data := c.rawConn.RawInput.Bytes()[:maxPlaintext]
  218. // cmsg for record type
  219. buffer := make([]byte, unix.CmsgSpace(1))
  220. cmsg := (*unix.Cmsghdr)(unsafe.Pointer(&buffer[0]))
  221. cmsg.SetLen(unix.CmsgLen(1))
  222. var iov unix.Iovec
  223. iov.Base = &data[0]
  224. iov.SetLen(len(data))
  225. var msg unix.Msghdr
  226. msg.Control = &buffer[0]
  227. msg.Controllen = cmsg.Len
  228. msg.Iov = &iov
  229. msg.Iovlen = 1
  230. var n int
  231. var err error
  232. er := c.rawSyscallConn.Read(func(fd uintptr) bool {
  233. n, err = recvmsg(int(fd), &msg, 0)
  234. return err != unix.EAGAIN || c.pendingRxSplice
  235. })
  236. if er != nil {
  237. return 0, nil, er
  238. }
  239. switch err {
  240. case nil:
  241. case syscall.EINVAL, syscall.EAGAIN:
  242. return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertProtocolVersion))
  243. case syscall.EMSGSIZE:
  244. return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertRecordOverflow))
  245. case syscall.EBADMSG:
  246. return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertDecryptError))
  247. default:
  248. return 0, nil, err
  249. }
  250. if n <= 0 {
  251. return 0, nil, c.rawConn.In.SetErrorLocked(io.EOF)
  252. }
  253. if cmsg.Level == unix.SOL_TLS && cmsg.Type == TLS_GET_RECORD_TYPE {
  254. typ := buffer[unix.CmsgLen(0)]
  255. return typ, data[:n], nil
  256. }
  257. return recordTypeApplicationData, data[:n], nil
  258. }
  259. func (c *Conn) writeKernelRecord(typ uint16, data []byte) (int, error) {
  260. if typ == recordTypeApplicationData {
  261. return c.conn.Write(data)
  262. }
  263. // cmsg for record type
  264. buffer := make([]byte, unix.CmsgSpace(1))
  265. cmsg := (*unix.Cmsghdr)(unsafe.Pointer(&buffer[0]))
  266. cmsg.SetLen(unix.CmsgLen(1))
  267. buffer[unix.CmsgLen(0)] = byte(typ)
  268. cmsg.Level = unix.SOL_TLS
  269. cmsg.Type = TLS_SET_RECORD_TYPE
  270. var iov unix.Iovec
  271. iov.Base = &data[0]
  272. iov.SetLen(len(data))
  273. var msg unix.Msghdr
  274. msg.Control = &buffer[0]
  275. msg.Controllen = cmsg.Len
  276. msg.Iov = &iov
  277. msg.Iovlen = 1
  278. var n int
  279. var err error
  280. ew := c.rawSyscallConn.Write(func(fd uintptr) bool {
  281. n, err = sendmsg(int(fd), &msg, 0)
  282. return err != unix.EAGAIN
  283. })
  284. if ew != nil {
  285. return 0, ew
  286. }
  287. return n, err
  288. }
// recvmsg has no body in this file; the go:linkname directive binds it to
// the unexported golang.org/x/sys/unix.recvmsg, which takes a caller-built
// Msghdr. The file's build constraint requires the badlinkname tag.
//
//go:linkname recvmsg golang.org/x/sys/unix.recvmsg
func recvmsg(fd int, msg *unix.Msghdr, flags int) (n int, err error)

// sendmsg has no body in this file; the go:linkname directive binds it to
// the unexported golang.org/x/sys/unix.sendmsg, which takes a caller-built
// Msghdr.
//
//go:linkname sendmsg golang.org/x/sys/unix.sendmsg
func sendmsg(fd int, msg *unix.Msghdr, flags int) (n int, err error)