badtls.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. //go:build go1.20 && !go1.21
  2. package badtls
  3. import (
  4. "crypto/cipher"
  5. "crypto/rand"
  6. "crypto/tls"
  7. "encoding/binary"
  8. "io"
  9. "net"
  10. "reflect"
  11. "sync"
  12. "sync/atomic"
  13. "unsafe"
  14. "github.com/sagernet/sing-box/log"
  15. "github.com/sagernet/sing/common"
  16. "github.com/sagernet/sing/common/buf"
  17. "github.com/sagernet/sing/common/bufio"
  18. E "github.com/sagernet/sing/common/exceptions"
  19. N "github.com/sagernet/sing/common/network"
  20. aTLS "github.com/sagernet/sing/common/tls"
  21. )
  22. type Conn struct {
  23. *tls.Conn
  24. writer N.ExtendedWriter
  25. isHandshakeComplete *atomic.Bool
  26. activeCall *atomic.Int32
  27. closeNotifySent *bool
  28. version *uint16
  29. rand io.Reader
  30. halfAccess *sync.Mutex
  31. halfError *error
  32. cipher cipher.AEAD
  33. explicitNonceLen int
  34. halfPtr uintptr
  35. halfSeq []byte
  36. halfScratchBuf []byte
  37. }
  38. func TryCreate(conn aTLS.Conn) aTLS.Conn {
  39. tlsConn, ok := conn.(*tls.Conn)
  40. if !ok {
  41. return conn
  42. }
  43. badConn, err := Create(tlsConn)
  44. if err != nil {
  45. log.Warn("initialize badtls: ", err)
  46. return conn
  47. }
  48. return badConn
  49. }
  50. func Create(conn *tls.Conn) (aTLS.Conn, error) {
  51. rawConn := reflect.Indirect(reflect.ValueOf(conn))
  52. rawIsHandshakeComplete := rawConn.FieldByName("isHandshakeComplete")
  53. if !rawIsHandshakeComplete.IsValid() || rawIsHandshakeComplete.Kind() != reflect.Struct {
  54. return nil, E.New("badtls: invalid isHandshakeComplete")
  55. }
  56. isHandshakeComplete := (*atomic.Bool)(unsafe.Pointer(rawIsHandshakeComplete.UnsafeAddr()))
  57. if !isHandshakeComplete.Load() {
  58. return nil, E.New("handshake not finished")
  59. }
  60. rawActiveCall := rawConn.FieldByName("activeCall")
  61. if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Struct {
  62. return nil, E.New("badtls: invalid active call")
  63. }
  64. activeCall := (*atomic.Int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr()))
  65. rawHalfConn := rawConn.FieldByName("out")
  66. if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
  67. return nil, E.New("badtls: invalid half conn")
  68. }
  69. rawVersion := rawConn.FieldByName("vers")
  70. if !rawVersion.IsValid() || rawVersion.Kind() != reflect.Uint16 {
  71. return nil, E.New("badtls: invalid version")
  72. }
  73. version := (*uint16)(unsafe.Pointer(rawVersion.UnsafeAddr()))
  74. rawCloseNotifySent := rawConn.FieldByName("closeNotifySent")
  75. if !rawCloseNotifySent.IsValid() || rawCloseNotifySent.Kind() != reflect.Bool {
  76. return nil, E.New("badtls: invalid notify")
  77. }
  78. closeNotifySent := (*bool)(unsafe.Pointer(rawCloseNotifySent.UnsafeAddr()))
  79. rawConfig := reflect.Indirect(rawConn.FieldByName("config"))
  80. if !rawConfig.IsValid() || rawConfig.Kind() != reflect.Struct {
  81. return nil, E.New("badtls: bad config")
  82. }
  83. config := (*tls.Config)(unsafe.Pointer(rawConfig.UnsafeAddr()))
  84. randReader := config.Rand
  85. if randReader == nil {
  86. randReader = rand.Reader
  87. }
  88. rawHalfMutex := rawHalfConn.FieldByName("Mutex")
  89. if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct {
  90. return nil, E.New("badtls: invalid half mutex")
  91. }
  92. halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr()))
  93. rawHalfError := rawHalfConn.FieldByName("err")
  94. if !rawHalfError.IsValid() || rawHalfError.Kind() != reflect.Interface {
  95. return nil, E.New("badtls: invalid half error")
  96. }
  97. halfError := (*error)(unsafe.Pointer(rawHalfError.UnsafeAddr()))
  98. rawHalfCipherInterface := rawHalfConn.FieldByName("cipher")
  99. if !rawHalfCipherInterface.IsValid() || rawHalfCipherInterface.Kind() != reflect.Interface {
  100. return nil, E.New("badtls: invalid cipher interface")
  101. }
  102. rawHalfCipher := rawHalfCipherInterface.Elem()
  103. aeadCipher, loaded := valueInterface(rawHalfCipher, false).(cipher.AEAD)
  104. if !loaded {
  105. return nil, E.New("badtls: invalid AEAD cipher")
  106. }
  107. var explicitNonceLen int
  108. switch cipherName := reflect.Indirect(rawHalfCipher).Type().String(); cipherName {
  109. case "tls.prefixNonceAEAD":
  110. explicitNonceLen = aeadCipher.NonceSize()
  111. case "tls.xorNonceAEAD":
  112. default:
  113. return nil, E.New("badtls: unknown cipher type: ", cipherName)
  114. }
  115. rawHalfSeq := rawHalfConn.FieldByName("seq")
  116. if !rawHalfSeq.IsValid() || rawHalfSeq.Kind() != reflect.Array {
  117. return nil, E.New("badtls: invalid seq")
  118. }
  119. halfSeq := rawHalfSeq.Bytes()
  120. rawHalfScratchBuf := rawHalfConn.FieldByName("scratchBuf")
  121. if !rawHalfScratchBuf.IsValid() || rawHalfScratchBuf.Kind() != reflect.Array {
  122. return nil, E.New("badtls: invalid scratchBuf")
  123. }
  124. halfScratchBuf := rawHalfScratchBuf.Bytes()
  125. return &Conn{
  126. Conn: conn,
  127. writer: bufio.NewExtendedWriter(conn.NetConn()),
  128. isHandshakeComplete: isHandshakeComplete,
  129. activeCall: activeCall,
  130. closeNotifySent: closeNotifySent,
  131. version: version,
  132. halfAccess: halfAccess,
  133. halfError: halfError,
  134. cipher: aeadCipher,
  135. explicitNonceLen: explicitNonceLen,
  136. rand: randReader,
  137. halfPtr: rawHalfConn.UnsafeAddr(),
  138. halfSeq: halfSeq,
  139. halfScratchBuf: halfScratchBuf,
  140. }, nil
  141. }
  142. func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
  143. if buffer.Len() > maxPlaintext {
  144. defer buffer.Release()
  145. return common.Error(c.Write(buffer.Bytes()))
  146. }
  147. for {
  148. x := c.activeCall.Load()
  149. if x&1 != 0 {
  150. return net.ErrClosed
  151. }
  152. if c.activeCall.CompareAndSwap(x, x+2) {
  153. break
  154. }
  155. }
  156. defer c.activeCall.Add(-2)
  157. c.halfAccess.Lock()
  158. defer c.halfAccess.Unlock()
  159. if err := *c.halfError; err != nil {
  160. return err
  161. }
  162. if *c.closeNotifySent {
  163. return errShutdown
  164. }
  165. dataLen := buffer.Len()
  166. dataBytes := buffer.Bytes()
  167. outBuf := buffer.ExtendHeader(recordHeaderLen + c.explicitNonceLen)
  168. outBuf[0] = 23
  169. version := *c.version
  170. if version == 0 {
  171. version = tls.VersionTLS10
  172. } else if version == tls.VersionTLS13 {
  173. version = tls.VersionTLS12
  174. }
  175. binary.BigEndian.PutUint16(outBuf[1:], version)
  176. var nonce []byte
  177. if c.explicitNonceLen > 0 {
  178. nonce = outBuf[5 : 5+c.explicitNonceLen]
  179. if c.explicitNonceLen < 16 {
  180. copy(nonce, c.halfSeq)
  181. } else {
  182. if _, err := io.ReadFull(c.rand, nonce); err != nil {
  183. return err
  184. }
  185. }
  186. }
  187. if len(nonce) == 0 {
  188. nonce = c.halfSeq
  189. }
  190. if *c.version == tls.VersionTLS13 {
  191. buffer.FreeBytes()[0] = 23
  192. binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+1+c.cipher.Overhead()))
  193. c.cipher.Seal(outBuf, nonce, outBuf[recordHeaderLen:recordHeaderLen+c.explicitNonceLen+dataLen+1], outBuf[:recordHeaderLen])
  194. buffer.Extend(1 + c.cipher.Overhead())
  195. } else {
  196. binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen))
  197. additionalData := append(c.halfScratchBuf[:0], c.halfSeq...)
  198. additionalData = append(additionalData, outBuf[:recordHeaderLen]...)
  199. c.cipher.Seal(outBuf, nonce, dataBytes, additionalData)
  200. buffer.Extend(c.cipher.Overhead())
  201. binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+c.explicitNonceLen+c.cipher.Overhead()))
  202. }
  203. incSeq(c.halfPtr)
  204. log.Trace("badtls write ", buffer.Len())
  205. return c.writer.WriteBuffer(buffer)
  206. }
  207. func (c *Conn) FrontHeadroom() int {
  208. return recordHeaderLen + c.explicitNonceLen
  209. }
  210. func (c *Conn) RearHeadroom() int {
  211. return 1 + c.cipher.Overhead()
  212. }
  213. func (c *Conn) WriterMTU() int {
  214. return maxPlaintext
  215. }
  216. func (c *Conn) Upstream() any {
  217. return c.Conn
  218. }
  219. func (c *Conn) UpstreamWriter() any {
  220. return c.NetConn()
  221. }