vision.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. package vless
  2. import (
  3. "bytes"
  4. "crypto/rand"
  5. "crypto/tls"
  6. "io"
  7. "math/big"
  8. "net"
  9. "reflect"
  10. "time"
  11. "unsafe"
  12. C "github.com/sagernet/sing-box/constant"
  13. "github.com/sagernet/sing/common"
  14. "github.com/sagernet/sing/common/buf"
  15. "github.com/sagernet/sing/common/bufio"
  16. E "github.com/sagernet/sing/common/exceptions"
  17. "github.com/sagernet/sing/common/logger"
  18. N "github.com/sagernet/sing/common/network"
  19. )
  20. var tlsRegistry []func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr)
  21. func init() {
  22. tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr) {
  23. tlsConn, loaded := conn.(*tls.Conn)
  24. if !loaded {
  25. return
  26. }
  27. return true, tlsConn.NetConn(), reflect.TypeOf(tlsConn).Elem(), uintptr(unsafe.Pointer(tlsConn))
  28. })
  29. }
  30. const xrayChunkSize = 8192
  31. type VisionConn struct {
  32. net.Conn
  33. reader *bufio.ChunkReader
  34. writer N.VectorisedWriter
  35. input *bytes.Reader
  36. rawInput *bytes.Buffer
  37. netConn net.Conn
  38. logger logger.Logger
  39. userUUID [16]byte
  40. isTLS bool
  41. numberOfPacketToFilter int
  42. isTLS12orAbove bool
  43. remainingServerHello int32
  44. cipher uint16
  45. enableXTLS bool
  46. isPadding bool
  47. directWrite bool
  48. writeUUID bool
  49. withinPaddingBuffers bool
  50. remainingContent int
  51. remainingPadding int
  52. currentCommand int
  53. directRead bool
  54. remainingReader io.Reader
  55. }
  56. func NewVisionConn(conn net.Conn, userUUID [16]byte, logger logger.Logger) (*VisionConn, error) {
  57. var (
  58. loaded bool
  59. reflectType reflect.Type
  60. reflectPointer uintptr
  61. netConn net.Conn
  62. )
  63. for _, tlsCreator := range tlsRegistry {
  64. loaded, netConn, reflectType, reflectPointer = tlsCreator(conn)
  65. if loaded {
  66. break
  67. }
  68. }
  69. if !loaded {
  70. return nil, C.ErrTLSRequired
  71. }
  72. input, _ := reflectType.FieldByName("input")
  73. rawInput, _ := reflectType.FieldByName("rawInput")
  74. return &VisionConn{
  75. Conn: conn,
  76. reader: bufio.NewChunkReader(conn, xrayChunkSize),
  77. writer: bufio.NewVectorisedWriter(conn),
  78. input: (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)),
  79. rawInput: (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)),
  80. netConn: netConn,
  81. logger: logger,
  82. userUUID: userUUID,
  83. numberOfPacketToFilter: 8,
  84. remainingServerHello: -1,
  85. isPadding: true,
  86. writeUUID: true,
  87. withinPaddingBuffers: true,
  88. remainingContent: -1,
  89. remainingPadding: -1,
  90. }, nil
  91. }
  92. func (c *VisionConn) Read(p []byte) (n int, err error) {
  93. if c.remainingReader != nil {
  94. n, err = c.remainingReader.Read(p)
  95. if err == io.EOF {
  96. c.remainingReader = nil
  97. }
  98. if n > 0 {
  99. return
  100. }
  101. }
  102. if c.directRead {
  103. return c.netConn.Read(p)
  104. }
  105. var bufferBytes []byte
  106. if len(p) > xrayChunkSize {
  107. n, err = c.Conn.Read(p)
  108. if err != nil {
  109. return
  110. }
  111. bufferBytes = p[:n]
  112. } else {
  113. buffer, err := c.reader.ReadChunk()
  114. if err != nil {
  115. return 0, err
  116. }
  117. defer buffer.FullReset()
  118. bufferBytes = buffer.Bytes()
  119. }
  120. if c.withinPaddingBuffers || c.numberOfPacketToFilter > 0 {
  121. buffers := c.unPadding(bufferBytes)
  122. if c.remainingContent == 0 && c.remainingPadding == 0 {
  123. if c.currentCommand == 1 {
  124. c.withinPaddingBuffers = false
  125. c.remainingContent = -1
  126. c.remainingPadding = -1
  127. } else if c.currentCommand == 2 {
  128. c.withinPaddingBuffers = false
  129. c.directRead = true
  130. inputBuffer, err := io.ReadAll(c.input)
  131. if err != nil {
  132. return 0, err
  133. }
  134. buffers = append(buffers, inputBuffer)
  135. rawInputBuffer, err := io.ReadAll(c.rawInput)
  136. if err != nil {
  137. return 0, err
  138. }
  139. buffers = append(buffers, rawInputBuffer)
  140. c.logger.Trace("XtlsRead readV")
  141. } else if c.currentCommand == 0 {
  142. c.withinPaddingBuffers = true
  143. } else {
  144. return 0, E.New("unknown command ", c.currentCommand)
  145. }
  146. } else if c.remainingContent > 0 || c.remainingPadding > 0 {
  147. c.withinPaddingBuffers = true
  148. } else {
  149. c.withinPaddingBuffers = false
  150. }
  151. if c.numberOfPacketToFilter > 0 {
  152. c.filterTLS(buffers)
  153. }
  154. c.remainingReader = io.MultiReader(common.Map(buffers, func(it []byte) io.Reader { return bytes.NewReader(it) })...)
  155. return c.Read(p)
  156. } else {
  157. if c.numberOfPacketToFilter > 0 {
  158. c.filterTLS([][]byte{bufferBytes})
  159. }
  160. return
  161. }
  162. }
// Write sends p to the peer.
//
// If packets remain to be filtered, p is first passed to filterTLS. While
// outgoing padding is active, p is split by reshapeBuffer and each piece is
// wrapped in a padding frame with the "continue" command. Padding ends at the
// first piece that starts with tlsApplicationDataStart once inner TLS has been
// detected (presumably an inner TLS application-data record — confirm). It
// also ends as soon as the inner traffic is not known to be TLS 1.2+ and at
// most one packet is left to filter. That piece gets the "end" command, or the
// "direct" command when filterTLS enabled XTLS. Pieces after it are written
// unwrapped.
//
// In the direct case, the frames up to and including the direct one go
// through the TLS writer. The rest, and every later Write, go to the raw
// connection. In the padding path n is len(p) on success.
func (c *VisionConn) Write(p []byte) (n int, err error) {
	if c.numberOfPacketToFilter > 0 {
		c.filterTLS([][]byte{p})
	}
	if c.isPadding {
		inputLen := len(p)
		buffers := reshapeBuffer(p)
		// specIndex is the index of the frame carrying the direct command.
		var specIndex int
		for i, buffer := range buffers {
			if c.isTLS && buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) {
				var command byte = commandPaddingEnd
				if c.enableXTLS {
					c.directWrite = true
					specIndex = i
					command = commandPaddingDirect
				}
				c.isPadding = false
				buffers[i] = c.padding(buffer, command)
				break
			} else if !c.isTLS12orAbove && c.numberOfPacketToFilter <= 1 {
				c.isPadding = false
				buffers[i] = c.padding(buffer, commandPaddingEnd)
				break
			}
			buffers[i] = c.padding(buffer, commandPaddingContinue)
		}
		if c.directWrite {
			// Frames up to the direct command still go through TLS.
			encryptedBuffer := buffers[:specIndex+1]
			err = c.writer.WriteVectorised(encryptedBuffer)
			if err != nil {
				return
			}
			// Everything after it is written to the raw connection.
			buffers = buffers[specIndex+1:]
			c.writer = bufio.NewVectorisedWriter(c.netConn)
			c.logger.Trace("XtlsWrite writeV ", specIndex, " ", buf.LenMulti(encryptedBuffer), " ", len(buffers))
			time.Sleep(5 * time.Millisecond) // NOTE(review): unexplained fixed delay before direct writes; presumably gives the peer time to switch — confirm
		}
		err = c.writer.WriteVectorised(buffers)
		if err == nil {
			n = inputLen
		}
		return
	}
	if c.directWrite {
		return c.netConn.Write(p)
	} else {
		return c.Conn.Write(p)
	}
}
  212. func (c *VisionConn) filterTLS(buffers [][]byte) {
  213. for _, buffer := range buffers {
  214. c.numberOfPacketToFilter--
  215. if len(buffer) > 6 {
  216. if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 {
  217. c.isTLS = true
  218. if buffer[5] == 2 {
  219. c.isTLS12orAbove = true
  220. c.remainingServerHello = (int32(buffer[3])<<8 | int32(buffer[4])) + 5
  221. if len(buffer) >= 79 && c.remainingServerHello >= 79 {
  222. sessionIdLen := int32(buffer[43])
  223. cipherSuite := buffer[43+sessionIdLen+1 : 43+sessionIdLen+3]
  224. c.cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1])
  225. } else {
  226. c.logger.Trace("XtlsFilterTls short server hello, tls 1.2 or older? ", len(buffer), " ", c.remainingServerHello)
  227. }
  228. }
  229. } else if bytes.Equal(tlsClientHandShakeStart, buffer[:2]) && buffer[5] == 1 {
  230. c.isTLS = true
  231. c.logger.Trace("XtlsFilterTls found tls client hello! ", len(buffer))
  232. }
  233. }
  234. if c.remainingServerHello > 0 {
  235. end := int(c.remainingServerHello)
  236. if end > len(buffer) {
  237. end = len(buffer)
  238. }
  239. c.remainingServerHello -= int32(end)
  240. if bytes.Contains(buffer[:end], tls13SupportedVersions) {
  241. cipher, ok := tls13CipherSuiteDic[c.cipher]
  242. if ok && cipher != "TLS_AES_128_CCM_8_SHA256" {
  243. c.enableXTLS = true
  244. }
  245. c.logger.Trace("XtlsFilterTls found tls 1.3! ", len(buffer), " ", c.cipher, " ", c.enableXTLS)
  246. c.numberOfPacketToFilter = 0
  247. return
  248. } else if c.remainingServerHello == 0 {
  249. c.logger.Trace("XtlsFilterTls found tls 1.2! ", len(buffer))
  250. c.numberOfPacketToFilter = 0
  251. return
  252. }
  253. }
  254. if c.numberOfPacketToFilter == 0 {
  255. c.logger.Trace("XtlsFilterTls stop filtering ", len(buffer))
  256. }
  257. }
  258. }
  259. func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer {
  260. contentLen := 0
  261. paddingLen := 0
  262. if buffer != nil {
  263. contentLen = buffer.Len()
  264. }
  265. if contentLen < 900 && c.isTLS {
  266. l, _ := rand.Int(rand.Reader, big.NewInt(500))
  267. paddingLen = int(l.Int64()) + 900 - contentLen
  268. } else {
  269. l, _ := rand.Int(rand.Reader, big.NewInt(256))
  270. paddingLen = int(l.Int64())
  271. }
  272. var bufferLen int
  273. if c.writeUUID {
  274. bufferLen += 16
  275. }
  276. bufferLen += 5
  277. if buffer != nil {
  278. bufferLen += buffer.Len()
  279. }
  280. bufferLen += paddingLen
  281. newBuffer := buf.NewSize(bufferLen)
  282. if c.writeUUID {
  283. common.Must1(newBuffer.Write(c.userUUID[:]))
  284. c.writeUUID = false
  285. }
  286. common.Must1(newBuffer.Write([]byte{command, byte(contentLen >> 8), byte(contentLen), byte(paddingLen >> 8), byte(paddingLen)}))
  287. if buffer != nil {
  288. common.Must1(newBuffer.Write(buffer.Bytes()))
  289. buffer.Release()
  290. }
  291. newBuffer.Extend(paddingLen)
  292. c.logger.Trace("XtlsPadding ", contentLen, " ", paddingLen, " ", command)
  293. return newBuffer
  294. }
// unPadding strips Vision padding frames from one chunk of incoming data and
// returns the content it carries as sub-slices of buffer (no copy).
//
// Before any frame has been seen (remainingContent and remainingPadding both
// -1), the chunk is decoded only if it is at least 21 bytes and starts with
// the user's UUID. Otherwise it is returned unchanged. Frame state (command,
// content bytes and padding bytes still owed) is kept in c across calls. Once
// a frame with command 1 has been fully consumed, the rest of the chunk is
// returned as plain content.
func (c *VisionConn) unPadding(buffer []byte) [][]byte {
	var bufferIndex int
	if c.remainingContent == -1 && c.remainingPadding == -1 {
		// The first frame is preceded by the 16-byte UUID; 21 = UUID + 5-byte header.
		if len(buffer) >= 21 && bytes.Equal(c.userUUID[:], buffer[:16]) {
			bufferIndex = 16
			c.remainingContent = 0
			c.remainingPadding = 0
			c.currentCommand = 0
		}
	}
	if c.remainingContent == -1 && c.remainingPadding == -1 {
		return [][]byte{buffer}
	}
	var buffers [][]byte
	for bufferIndex < len(buffer) {
		if c.remainingContent <= 0 && c.remainingPadding <= 0 {
			if c.currentCommand == 1 {
				buffers = append(buffers, buffer[bufferIndex:])
				break
			} else {
				// Header: command, content length (2 bytes, big-endian), padding length (2 bytes, big-endian).
				// NOTE(review): assumes all 5 header bytes are in this chunk. A header split across
				// two reads is sliced past len(buffer), reading stale bytes within capacity or
				// panicking beyond it. TODO: carry a partial header across calls.
				paddingInfo := buffer[bufferIndex : bufferIndex+5]
				c.currentCommand = int(paddingInfo[0])
				c.remainingContent = int(paddingInfo[1])<<8 | int(paddingInfo[2])
				c.remainingPadding = int(paddingInfo[3])<<8 | int(paddingInfo[4])
				bufferIndex += 5
				c.logger.Trace("Xtls Unpadding new block ", bufferIndex, " ", c.remainingContent, " padding ", c.remainingPadding, " ", c.currentCommand)
			}
		} else if c.remainingContent > 0 {
			// Emit as much of the frame's content as this chunk holds.
			end := c.remainingContent
			if end > len(buffer)-bufferIndex {
				end = len(buffer) - bufferIndex
			}
			buffers = append(buffers, buffer[bufferIndex:bufferIndex+end])
			c.remainingContent -= end
			bufferIndex += end
		} else {
			// Skip as much of the frame's padding as this chunk holds.
			end := c.remainingPadding
			if end > len(buffer)-bufferIndex {
				end = len(buffer) - bufferIndex
			}
			c.remainingPadding -= end
			bufferIndex += end
		}
		if bufferIndex == len(buffer) {
			break
		}
	}
	return buffers
}
  344. func (c *VisionConn) Upstream() any {
  345. return c.Conn
  346. }