badtls.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. //go:build go1.19 && !go1.20
  2. package badtls
  3. import (
  4. "crypto/cipher"
  5. "crypto/rand"
  6. "crypto/tls"
  7. "encoding/binary"
  8. "io"
  9. "net"
  10. "reflect"
  11. "sync"
  12. "sync/atomic"
  13. "unsafe"
  14. "github.com/sagernet/sing/common"
  15. "github.com/sagernet/sing/common/buf"
  16. "github.com/sagernet/sing/common/bufio"
  17. E "github.com/sagernet/sing/common/exceptions"
  18. N "github.com/sagernet/sing/common/network"
  19. )
  20. type Conn struct {
  21. *tls.Conn
  22. writer N.ExtendedWriter
  23. activeCall *int32
  24. closeNotifySent *bool
  25. version *uint16
  26. rand io.Reader
  27. halfAccess *sync.Mutex
  28. halfError *error
  29. cipher cipher.AEAD
  30. explicitNonceLen int
  31. halfPtr uintptr
  32. halfSeq []byte
  33. halfScratchBuf []byte
  34. }
  35. func Create(conn *tls.Conn) (TLSConn, error) {
  36. if !handshakeComplete(conn) {
  37. return nil, E.New("handshake not finished")
  38. }
  39. rawConn := reflect.Indirect(reflect.ValueOf(conn))
  40. rawActiveCall := rawConn.FieldByName("activeCall")
  41. if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Int32 {
  42. return nil, E.New("badtls: invalid active call")
  43. }
  44. activeCall := (*int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr()))
  45. rawHalfConn := rawConn.FieldByName("out")
  46. if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
  47. return nil, E.New("badtls: invalid half conn")
  48. }
  49. rawVersion := rawConn.FieldByName("vers")
  50. if !rawVersion.IsValid() || rawVersion.Kind() != reflect.Uint16 {
  51. return nil, E.New("badtls: invalid version")
  52. }
  53. version := (*uint16)(unsafe.Pointer(rawVersion.UnsafeAddr()))
  54. rawCloseNotifySent := rawConn.FieldByName("closeNotifySent")
  55. if !rawCloseNotifySent.IsValid() || rawCloseNotifySent.Kind() != reflect.Bool {
  56. return nil, E.New("badtls: invalid notify")
  57. }
  58. closeNotifySent := (*bool)(unsafe.Pointer(rawCloseNotifySent.UnsafeAddr()))
  59. rawConfig := reflect.Indirect(rawConn.FieldByName("config"))
  60. if !rawConfig.IsValid() || rawConfig.Kind() != reflect.Struct {
  61. return nil, E.New("badtls: bad config")
  62. }
  63. config := (*tls.Config)(unsafe.Pointer(rawConfig.UnsafeAddr()))
  64. randReader := config.Rand
  65. if randReader == nil {
  66. randReader = rand.Reader
  67. }
  68. rawHalfMutex := rawHalfConn.FieldByName("Mutex")
  69. if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct {
  70. return nil, E.New("badtls: invalid half mutex")
  71. }
  72. halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr()))
  73. rawHalfError := rawHalfConn.FieldByName("err")
  74. if !rawHalfError.IsValid() || rawHalfError.Kind() != reflect.Interface {
  75. return nil, E.New("badtls: invalid half error")
  76. }
  77. halfError := (*error)(unsafe.Pointer(rawHalfError.UnsafeAddr()))
  78. rawHalfCipherInterface := rawHalfConn.FieldByName("cipher")
  79. if !rawHalfCipherInterface.IsValid() || rawHalfCipherInterface.Kind() != reflect.Interface {
  80. return nil, E.New("badtls: invalid cipher interface")
  81. }
  82. rawHalfCipher := rawHalfCipherInterface.Elem()
  83. aeadCipher, loaded := valueInterface(rawHalfCipher, false).(cipher.AEAD)
  84. if !loaded {
  85. return nil, E.New("badtls: invalid AEAD cipher")
  86. }
  87. var explicitNonceLen int
  88. switch cipherName := reflect.Indirect(rawHalfCipher).Type().String(); cipherName {
  89. case "tls.prefixNonceAEAD":
  90. explicitNonceLen = aeadCipher.NonceSize()
  91. case "tls.xorNonceAEAD":
  92. default:
  93. return nil, E.New("badtls: unknown cipher type: ", cipherName)
  94. }
  95. rawHalfSeq := rawHalfConn.FieldByName("seq")
  96. if !rawHalfSeq.IsValid() || rawHalfSeq.Kind() != reflect.Array {
  97. return nil, E.New("badtls: invalid seq")
  98. }
  99. halfSeq := rawHalfSeq.Bytes()
  100. rawHalfScratchBuf := rawHalfConn.FieldByName("scratchBuf")
  101. if !rawHalfScratchBuf.IsValid() || rawHalfScratchBuf.Kind() != reflect.Array {
  102. return nil, E.New("badtls: invalid scratchBuf")
  103. }
  104. halfScratchBuf := rawHalfScratchBuf.Bytes()
  105. return &Conn{
  106. Conn: conn,
  107. writer: bufio.NewExtendedWriter(conn.NetConn()),
  108. activeCall: activeCall,
  109. closeNotifySent: closeNotifySent,
  110. version: version,
  111. halfAccess: halfAccess,
  112. halfError: halfError,
  113. cipher: aeadCipher,
  114. explicitNonceLen: explicitNonceLen,
  115. rand: randReader,
  116. halfPtr: rawHalfConn.UnsafeAddr(),
  117. halfSeq: halfSeq,
  118. halfScratchBuf: halfScratchBuf,
  119. }, nil
  120. }
  121. func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
  122. if buffer.Len() > maxPlaintext {
  123. defer buffer.Release()
  124. return common.Error(c.Write(buffer.Bytes()))
  125. }
  126. for {
  127. x := atomic.LoadInt32(c.activeCall)
  128. if x&1 != 0 {
  129. return net.ErrClosed
  130. }
  131. if atomic.CompareAndSwapInt32(c.activeCall, x, x+2) {
  132. break
  133. }
  134. }
  135. defer atomic.AddInt32(c.activeCall, -2)
  136. c.halfAccess.Lock()
  137. defer c.halfAccess.Unlock()
  138. if err := *c.halfError; err != nil {
  139. return err
  140. }
  141. if *c.closeNotifySent {
  142. return errShutdown
  143. }
  144. dataLen := buffer.Len()
  145. dataBytes := buffer.Bytes()
  146. outBuf := buffer.ExtendHeader(recordHeaderLen + c.explicitNonceLen)
  147. outBuf[0] = 23
  148. version := *c.version
  149. if version == 0 {
  150. version = tls.VersionTLS10
  151. } else if version == tls.VersionTLS13 {
  152. version = tls.VersionTLS12
  153. }
  154. binary.BigEndian.PutUint16(outBuf[1:], version)
  155. var nonce []byte
  156. if c.explicitNonceLen > 0 {
  157. nonce = outBuf[5 : 5+c.explicitNonceLen]
  158. if c.explicitNonceLen < 16 {
  159. copy(nonce, c.halfSeq)
  160. } else {
  161. if _, err := io.ReadFull(c.rand, nonce); err != nil {
  162. return err
  163. }
  164. }
  165. }
  166. if len(nonce) == 0 {
  167. nonce = c.halfSeq
  168. }
  169. if *c.version == tls.VersionTLS13 {
  170. buffer.FreeBytes()[0] = 23
  171. binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+1+c.cipher.Overhead()))
  172. c.cipher.Seal(outBuf, nonce, outBuf[recordHeaderLen:recordHeaderLen+c.explicitNonceLen+dataLen+1], outBuf[:recordHeaderLen])
  173. buffer.Extend(1 + c.cipher.Overhead())
  174. } else {
  175. binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen))
  176. additionalData := append(c.halfScratchBuf[:0], c.halfSeq...)
  177. additionalData = append(additionalData, outBuf[:recordHeaderLen]...)
  178. c.cipher.Seal(outBuf, nonce, dataBytes, additionalData)
  179. buffer.Extend(c.cipher.Overhead())
  180. binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+c.explicitNonceLen+c.cipher.Overhead()))
  181. }
  182. incSeq(c.halfPtr)
  183. return c.writer.WriteBuffer(buffer)
  184. }
  185. func (c *Conn) FrontHeadroom() int {
  186. return recordHeaderLen + c.explicitNonceLen
  187. }
  188. func (c *Conn) RearHeadroom() int {
  189. return 1 + c.cipher.Overhead()
  190. }
  191. func (c *Conn) WriterMTU() int {
  192. return maxPlaintext
  193. }
  194. func (c *Conn) Upstream() any {
  195. return c.Conn
  196. }
  197. func (c *Conn) UpstreamWriter() any {
  198. return c.NetConn()
  199. }