vision.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. package vless
  2. import (
  3. "bytes"
  4. "crypto/rand"
  5. "crypto/tls"
  6. "io"
  7. "math/big"
  8. "net"
  9. "reflect"
  10. "time"
  11. "unsafe"
  12. C "github.com/sagernet/sing-box/constant"
  13. "github.com/sagernet/sing/common"
  14. "github.com/sagernet/sing/common/buf"
  15. "github.com/sagernet/sing/common/bufio"
  16. E "github.com/sagernet/sing/common/exceptions"
  17. "github.com/sagernet/sing/common/logger"
  18. N "github.com/sagernet/sing/common/network"
  19. )
  // tlsRegistry holds probes that recognise a concrete TLS connection type.
// A probe reports loaded=false for connections it does not handle. Otherwise it
// returns the connection beneath TLS, the struct type (not the pointer type)
// of the TLS connection, and the address of that struct as a uintptr.
var tlsRegistry []func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr)

func init() {
	// Probe for the standard library's *tls.Conn.
	stdProbe := func(conn net.Conn) (bool, net.Conn, reflect.Type, uintptr) {
		stdConn, isStd := conn.(*tls.Conn)
		if !isStd {
			return false, nil, nil, 0
		}
		structType := reflect.TypeOf(stdConn).Elem()
		return true, stdConn.NetConn(), structType, uintptr(unsafe.Pointer(stdConn))
	}
	tlsRegistry = append(tlsRegistry, stdProbe)
}
  30. const xrayChunkSize = 8192
  31. type VisionConn struct {
  32. net.Conn
  33. reader *bufio.ChunkReader
  34. writer N.VectorisedWriter
  35. input *bytes.Reader
  36. rawInput *bytes.Buffer
  37. netConn net.Conn
  38. logger logger.Logger
  39. userUUID [16]byte
  40. isTLS bool
  41. numberOfPacketToFilter int
  42. isTLS12orAbove bool
  43. remainingServerHello int32
  44. cipher uint16
  45. enableXTLS bool
  46. isPadding bool
  47. directWrite bool
  48. writeUUID bool
  49. withinPaddingBuffers bool
  50. remainingContent int
  51. remainingPadding int
  52. currentCommand byte
  53. directRead bool
  54. remainingReader io.Reader
  55. }
  56. func NewVisionConn(conn net.Conn, tlsConn net.Conn, userUUID [16]byte, logger logger.Logger) (*VisionConn, error) {
  57. var (
  58. loaded bool
  59. reflectType reflect.Type
  60. reflectPointer uintptr
  61. netConn net.Conn
  62. )
  63. for _, tlsCreator := range tlsRegistry {
  64. loaded, netConn, reflectType, reflectPointer = tlsCreator(tlsConn)
  65. if loaded {
  66. break
  67. }
  68. }
  69. if !loaded {
  70. return nil, C.ErrTLSRequired
  71. }
  72. input, _ := reflectType.FieldByName("input")
  73. rawInput, _ := reflectType.FieldByName("rawInput")
  74. return &VisionConn{
  75. Conn: conn,
  76. reader: bufio.NewChunkReader(conn, xrayChunkSize),
  77. writer: bufio.NewVectorisedWriter(conn),
  78. input: (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)),
  79. rawInput: (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)),
  80. netConn: netConn,
  81. logger: logger,
  82. userUUID: userUUID,
  83. numberOfPacketToFilter: 8,
  84. remainingServerHello: -1,
  85. isPadding: true,
  86. writeUUID: true,
  87. withinPaddingBuffers: true,
  88. remainingContent: -1,
  89. remainingPadding: -1,
  90. }, nil
  91. }
  92. func (c *VisionConn) Read(p []byte) (n int, err error) {
  93. if c.remainingReader != nil {
  94. n, err = c.remainingReader.Read(p)
  95. if err == io.EOF {
  96. err = nil
  97. c.remainingReader = nil
  98. }
  99. if n > 0 {
  100. return
  101. }
  102. }
  103. if c.directRead {
  104. return c.netConn.Read(p)
  105. }
  106. var bufferBytes []byte
  107. var chunkBuffer *buf.Buffer
  108. if len(p) > xrayChunkSize {
  109. n, err = c.Conn.Read(p)
  110. if err != nil {
  111. return
  112. }
  113. bufferBytes = p[:n]
  114. } else {
  115. chunkBuffer, err = c.reader.ReadChunk()
  116. if err != nil {
  117. return 0, err
  118. }
  119. bufferBytes = chunkBuffer.Bytes()
  120. }
  121. if c.withinPaddingBuffers || c.numberOfPacketToFilter > 0 {
  122. buffers := c.unPadding(bufferBytes)
  123. if chunkBuffer != nil {
  124. buffers = common.Map(buffers, func(it *buf.Buffer) *buf.Buffer {
  125. return it.ToOwned()
  126. })
  127. chunkBuffer.FullReset()
  128. }
  129. if c.remainingContent == 0 && c.remainingPadding == 0 {
  130. if c.currentCommand == commandPaddingEnd {
  131. c.withinPaddingBuffers = false
  132. c.remainingContent = -1
  133. c.remainingPadding = -1
  134. } else if c.currentCommand == commandPaddingDirect {
  135. c.withinPaddingBuffers = false
  136. c.directRead = true
  137. inputBuffer, err := io.ReadAll(c.input)
  138. if err != nil {
  139. return 0, err
  140. }
  141. buffers = append(buffers, buf.As(inputBuffer))
  142. rawInputBuffer, err := io.ReadAll(c.rawInput)
  143. if err != nil {
  144. return 0, err
  145. }
  146. buffers = append(buffers, buf.As(rawInputBuffer))
  147. c.logger.Trace("XtlsRead readV")
  148. } else if c.currentCommand == commandPaddingContinue {
  149. c.withinPaddingBuffers = true
  150. } else {
  151. return 0, E.New("unknown command ", c.currentCommand)
  152. }
  153. } else if c.remainingContent > 0 || c.remainingPadding > 0 {
  154. c.withinPaddingBuffers = true
  155. } else {
  156. c.withinPaddingBuffers = false
  157. }
  158. if c.numberOfPacketToFilter > 0 {
  159. c.filterTLS(buf.ToSliceMulti(buffers))
  160. }
  161. c.remainingReader = io.MultiReader(common.Map(buffers, func(it *buf.Buffer) io.Reader { return it })...)
  162. return c.Read(p)
  163. } else {
  164. if c.numberOfPacketToFilter > 0 {
  165. c.filterTLS([][]byte{bufferBytes})
  166. }
  167. if chunkBuffer != nil {
  168. n = copy(p, bufferBytes)
  169. chunkBuffer.Advance(n)
  170. }
  171. return
  172. }
  173. }
  174. func (c *VisionConn) Write(p []byte) (n int, err error) {
  175. if c.numberOfPacketToFilter > 0 {
  176. c.filterTLS([][]byte{p})
  177. }
  178. if c.isPadding {
  179. inputLen := len(p)
  180. buffers := reshapeBuffer(p)
  181. var specIndex int
  182. for i, buffer := range buffers {
  183. if c.isTLS && buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) {
  184. var command byte = commandPaddingEnd
  185. if c.enableXTLS {
  186. c.directWrite = true
  187. specIndex = i
  188. command = commandPaddingDirect
  189. }
  190. c.isPadding = false
  191. buffers[i] = c.padding(buffer, command)
  192. break
  193. } else if !c.isTLS12orAbove && c.numberOfPacketToFilter <= 1 {
  194. c.isPadding = false
  195. buffers[i] = c.padding(buffer, commandPaddingEnd)
  196. break
  197. }
  198. buffers[i] = c.padding(buffer, commandPaddingContinue)
  199. }
  200. if c.directWrite {
  201. encryptedBuffer := buffers[:specIndex+1]
  202. err = c.writer.WriteVectorised(encryptedBuffer)
  203. if err != nil {
  204. return
  205. }
  206. buffers = buffers[specIndex+1:]
  207. c.writer = bufio.NewVectorisedWriter(c.netConn)
  208. c.logger.Trace("XtlsWrite writeV ", specIndex, " ", buf.LenMulti(encryptedBuffer), " ", len(buffers))
  209. time.Sleep(5 * time.Millisecond) // wtf
  210. }
  211. err = c.writer.WriteVectorised(buffers)
  212. if err == nil {
  213. n = inputLen
  214. }
  215. return
  216. }
  217. if c.directWrite {
  218. return c.netConn.Write(p)
  219. } else {
  220. return c.Conn.Write(p)
  221. }
  222. }
  223. func (c *VisionConn) filterTLS(buffers [][]byte) {
  224. for _, buffer := range buffers {
  225. c.numberOfPacketToFilter--
  226. if len(buffer) > 6 {
  227. if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 {
  228. c.isTLS = true
  229. if buffer[5] == 2 {
  230. c.isTLS12orAbove = true
  231. c.remainingServerHello = (int32(buffer[3])<<8 | int32(buffer[4])) + 5
  232. if len(buffer) >= 79 && c.remainingServerHello >= 79 {
  233. sessionIdLen := int32(buffer[43])
  234. cipherSuite := buffer[43+sessionIdLen+1 : 43+sessionIdLen+3]
  235. c.cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1])
  236. } else {
  237. c.logger.Trace("XtlsFilterTls short server hello, tls 1.2 or older? ", len(buffer), " ", c.remainingServerHello)
  238. }
  239. }
  240. } else if bytes.Equal(tlsClientHandShakeStart, buffer[:2]) && buffer[5] == 1 {
  241. c.isTLS = true
  242. c.logger.Trace("XtlsFilterTls found tls client hello! ", len(buffer))
  243. }
  244. }
  245. if c.remainingServerHello > 0 {
  246. end := int(c.remainingServerHello)
  247. if end > len(buffer) {
  248. end = len(buffer)
  249. }
  250. c.remainingServerHello -= int32(end)
  251. if bytes.Contains(buffer[:end], tls13SupportedVersions) {
  252. cipher, ok := tls13CipherSuiteDic[c.cipher]
  253. if ok && cipher != "TLS_AES_128_CCM_8_SHA256" {
  254. c.enableXTLS = true
  255. }
  256. c.logger.Trace("XtlsFilterTls found tls 1.3! ", len(buffer), " ", c.cipher, " ", c.enableXTLS)
  257. c.numberOfPacketToFilter = 0
  258. return
  259. } else if c.remainingServerHello == 0 {
  260. c.logger.Trace("XtlsFilterTls found tls 1.2! ", len(buffer))
  261. c.numberOfPacketToFilter = 0
  262. return
  263. }
  264. }
  265. if c.numberOfPacketToFilter == 0 {
  266. c.logger.Trace("XtlsFilterTls stop filtering ", len(buffer))
  267. }
  268. }
  269. }
  270. func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer {
  271. contentLen := 0
  272. paddingLen := 0
  273. if buffer != nil {
  274. contentLen = buffer.Len()
  275. }
  276. if contentLen < 900 && c.isTLS {
  277. l, _ := rand.Int(rand.Reader, big.NewInt(500))
  278. paddingLen = int(l.Int64()) + 900 - contentLen
  279. } else {
  280. l, _ := rand.Int(rand.Reader, big.NewInt(256))
  281. paddingLen = int(l.Int64())
  282. }
  283. var bufferLen int
  284. if c.writeUUID {
  285. bufferLen += 16
  286. }
  287. bufferLen += 5
  288. if buffer != nil {
  289. bufferLen += buffer.Len()
  290. }
  291. bufferLen += paddingLen
  292. newBuffer := buf.NewSize(bufferLen)
  293. if c.writeUUID {
  294. common.Must1(newBuffer.Write(c.userUUID[:]))
  295. c.writeUUID = false
  296. }
  297. common.Must1(newBuffer.Write([]byte{command, byte(contentLen >> 8), byte(contentLen), byte(paddingLen >> 8), byte(paddingLen)}))
  298. if buffer != nil {
  299. common.Must1(newBuffer.Write(buffer.Bytes()))
  300. buffer.Release()
  301. }
  302. newBuffer.Extend(paddingLen)
  303. c.logger.Trace("XtlsPadding ", contentLen, " ", paddingLen, " ", command)
  304. return newBuffer
  305. }
  306. func (c *VisionConn) unPadding(buffer []byte) []*buf.Buffer {
  307. var bufferIndex int
  308. if c.remainingContent == -1 && c.remainingPadding == -1 {
  309. if len(buffer) >= 21 && bytes.Equal(c.userUUID[:], buffer[:16]) {
  310. bufferIndex = 16
  311. c.remainingContent = 0
  312. c.remainingPadding = 0
  313. c.currentCommand = 0
  314. }
  315. }
  316. if c.remainingContent == -1 && c.remainingPadding == -1 {
  317. return []*buf.Buffer{buf.As(buffer)}
  318. }
  319. var buffers []*buf.Buffer
  320. for bufferIndex < len(buffer) {
  321. if c.remainingContent <= 0 && c.remainingPadding <= 0 {
  322. if c.currentCommand == 1 {
  323. buffers = append(buffers, buf.As(buffer[bufferIndex:]))
  324. break
  325. } else {
  326. paddingInfo := buffer[bufferIndex : bufferIndex+5]
  327. c.currentCommand = paddingInfo[0]
  328. c.remainingContent = int(paddingInfo[1])<<8 | int(paddingInfo[2])
  329. c.remainingPadding = int(paddingInfo[3])<<8 | int(paddingInfo[4])
  330. bufferIndex += 5
  331. c.logger.Trace("Xtls Unpadding new block ", bufferIndex, " ", c.remainingContent, " padding ", c.remainingPadding, " ", c.currentCommand)
  332. }
  333. } else if c.remainingContent > 0 {
  334. end := c.remainingContent
  335. if end > len(buffer)-bufferIndex {
  336. end = len(buffer) - bufferIndex
  337. }
  338. buffers = append(buffers, buf.As(buffer[bufferIndex:bufferIndex+end]))
  339. c.remainingContent -= end
  340. bufferIndex += end
  341. } else {
  342. end := c.remainingPadding
  343. if end > len(buffer)-bufferIndex {
  344. end = len(buffer) - bufferIndex
  345. }
  346. c.remainingPadding -= end
  347. bufferIndex += end
  348. }
  349. if bufferIndex == len(buffer) {
  350. break
  351. }
  352. }
  353. return buffers
  354. }
  355. func (c *VisionConn) NeedAdditionalReadDeadline() bool {
  356. return true
  357. }
  358. func (c *VisionConn) Upstream() any {
  359. return c.Conn
  360. }