| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365 |
- package vless
- import (
- "bytes"
- "crypto/rand"
- "crypto/tls"
- "io"
- "math/big"
- "net"
- "reflect"
- "time"
- "unsafe"
- C "github.com/sagernet/sing-box/constant"
- "github.com/sagernet/sing/common"
- "github.com/sagernet/sing/common/buf"
- "github.com/sagernet/sing/common/bufio"
- E "github.com/sagernet/sing/common/exceptions"
- "github.com/sagernet/sing/common/logger"
- N "github.com/sagernet/sing/common/network"
- )
// tlsRegistry holds probes that recognise a concrete TLS connection type.
// A probe reports whether conn is of its type and, if so, returns the
// connection underneath TLS, the reflect.Type of the TLS state struct and the
// address of that struct. NewVisionConn uses the type and address to reach the
// struct's unexported buffer fields. Presumably other TLS implementations
// append their own probes elsewhere in the package.
var tlsRegistry []func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr)

// init registers the probe for the standard library's *tls.Conn.
func init() {
	probeStdlibTLS := func(conn net.Conn) (bool, net.Conn, reflect.Type, uintptr) {
		tlsConn, isStdlibTLS := conn.(*tls.Conn)
		if !isStdlibTLS {
			return false, nil, nil, 0
		}
		return true, tlsConn.NetConn(), reflect.TypeOf(tlsConn).Elem(), uintptr(unsafe.Pointer(tlsConn))
	}
	tlsRegistry = append(tlsRegistry, probeStdlibTLS)
}

// xrayChunkSize is the chunk size VisionConn reads in: Read uses a chunked
// reader of this size for destination buffers no larger than it.
const xrayChunkSize = 8192
- type VisionConn struct {
- net.Conn
- reader *bufio.ChunkReader
- writer N.VectorisedWriter
- input *bytes.Reader
- rawInput *bytes.Buffer
- netConn net.Conn
- logger logger.Logger
- userUUID [16]byte
- isTLS bool
- numberOfPacketToFilter int
- isTLS12orAbove bool
- remainingServerHello int32
- cipher uint16
- enableXTLS bool
- isPadding bool
- directWrite bool
- writeUUID bool
- withinPaddingBuffers bool
- remainingContent int
- remainingPadding int
- currentCommand int
- directRead bool
- remainingReader io.Reader
- }
- func NewVisionConn(conn net.Conn, userUUID [16]byte, logger logger.Logger) (*VisionConn, error) {
- var (
- loaded bool
- reflectType reflect.Type
- reflectPointer uintptr
- netConn net.Conn
- )
- for _, tlsCreator := range tlsRegistry {
- loaded, netConn, reflectType, reflectPointer = tlsCreator(conn)
- if loaded {
- break
- }
- }
- if !loaded {
- return nil, C.ErrTLSRequired
- }
- input, _ := reflectType.FieldByName("input")
- rawInput, _ := reflectType.FieldByName("rawInput")
- return &VisionConn{
- Conn: conn,
- reader: bufio.NewChunkReader(conn, xrayChunkSize),
- writer: bufio.NewVectorisedWriter(conn),
- input: (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)),
- rawInput: (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)),
- netConn: netConn,
- logger: logger,
- userUUID: userUUID,
- numberOfPacketToFilter: 8,
- remainingServerHello: -1,
- isPadding: true,
- writeUUID: true,
- withinPaddingBuffers: true,
- remainingContent: -1,
- remainingPadding: -1,
- }, nil
- }
- func (c *VisionConn) Read(p []byte) (n int, err error) {
- if c.remainingReader != nil {
- n, err = c.remainingReader.Read(p)
- if err == io.EOF {
- c.remainingReader = nil
- }
- if n > 0 {
- return
- }
- }
- if c.directRead {
- return c.netConn.Read(p)
- }
- var bufferBytes []byte
- if len(p) > xrayChunkSize {
- n, err = c.Conn.Read(p)
- if err != nil {
- return
- }
- bufferBytes = p[:n]
- } else {
- buffer, err := c.reader.ReadChunk()
- if err != nil {
- return 0, err
- }
- defer buffer.FullReset()
- bufferBytes = buffer.Bytes()
- }
- if c.withinPaddingBuffers || c.numberOfPacketToFilter > 0 {
- buffers := c.unPadding(bufferBytes)
- if c.remainingContent == 0 && c.remainingPadding == 0 {
- if c.currentCommand == 1 {
- c.withinPaddingBuffers = false
- c.remainingContent = -1
- c.remainingPadding = -1
- } else if c.currentCommand == 2 {
- c.withinPaddingBuffers = false
- c.directRead = true
- inputBuffer, err := io.ReadAll(c.input)
- if err != nil {
- return 0, err
- }
- buffers = append(buffers, inputBuffer)
- rawInputBuffer, err := io.ReadAll(c.rawInput)
- if err != nil {
- return 0, err
- }
- buffers = append(buffers, rawInputBuffer)
- c.logger.Trace("XtlsRead readV")
- } else if c.currentCommand == 0 {
- c.withinPaddingBuffers = true
- } else {
- return 0, E.New("unknown command ", c.currentCommand)
- }
- } else if c.remainingContent > 0 || c.remainingPadding > 0 {
- c.withinPaddingBuffers = true
- } else {
- c.withinPaddingBuffers = false
- }
- if c.numberOfPacketToFilter > 0 {
- c.filterTLS(buffers)
- }
- c.remainingReader = io.MultiReader(common.Map(buffers, func(it []byte) io.Reader { return bytes.NewReader(it) })...)
- return c.Read(p)
- } else {
- if c.numberOfPacketToFilter > 0 {
- c.filterTLS([][]byte{bufferBytes})
- }
- return
- }
- }
- func (c *VisionConn) Write(p []byte) (n int, err error) {
- if c.numberOfPacketToFilter > 0 {
- c.filterTLS([][]byte{p})
- }
- if c.isPadding {
- inputLen := len(p)
- buffers := reshapeBuffer(p)
- var specIndex int
- for i, buffer := range buffers {
- if c.isTLS && buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) {
- var command byte = commandPaddingEnd
- if c.enableXTLS {
- c.directWrite = true
- specIndex = i
- command = commandPaddingDirect
- }
- c.isPadding = false
- buffers[i] = c.padding(buffer, command)
- break
- } else if !c.isTLS12orAbove && c.numberOfPacketToFilter <= 1 {
- c.isPadding = false
- buffers[i] = c.padding(buffer, commandPaddingEnd)
- break
- }
- buffers[i] = c.padding(buffer, commandPaddingContinue)
- }
- if c.directWrite {
- encryptedBuffer := buffers[:specIndex+1]
- err = c.writer.WriteVectorised(encryptedBuffer)
- if err != nil {
- return
- }
- buffers = buffers[specIndex+1:]
- c.writer = bufio.NewVectorisedWriter(c.netConn)
- c.logger.Trace("XtlsWrite writeV ", specIndex, " ", buf.LenMulti(encryptedBuffer), " ", len(buffers))
- time.Sleep(5 * time.Millisecond) // wtf
- }
- err = c.writer.WriteVectorised(buffers)
- if err == nil {
- n = inputLen
- }
- return
- }
- if c.directWrite {
- return c.netConn.Write(p)
- } else {
- return c.Conn.Write(p)
- }
- }
- func (c *VisionConn) filterTLS(buffers [][]byte) {
- for _, buffer := range buffers {
- c.numberOfPacketToFilter--
- if len(buffer) > 6 {
- if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 {
- c.isTLS = true
- if buffer[5] == 2 {
- c.isTLS12orAbove = true
- c.remainingServerHello = (int32(buffer[3])<<8 | int32(buffer[4])) + 5
- if len(buffer) >= 79 && c.remainingServerHello >= 79 {
- sessionIdLen := int32(buffer[43])
- cipherSuite := buffer[43+sessionIdLen+1 : 43+sessionIdLen+3]
- c.cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1])
- } else {
- c.logger.Trace("XtlsFilterTls short server hello, tls 1.2 or older? ", len(buffer), " ", c.remainingServerHello)
- }
- }
- } else if bytes.Equal(tlsClientHandShakeStart, buffer[:2]) && buffer[5] == 1 {
- c.isTLS = true
- c.logger.Trace("XtlsFilterTls found tls client hello! ", len(buffer))
- }
- }
- if c.remainingServerHello > 0 {
- end := int(c.remainingServerHello)
- if end > len(buffer) {
- end = len(buffer)
- }
- c.remainingServerHello -= int32(end)
- if bytes.Contains(buffer[:end], tls13SupportedVersions) {
- cipher, ok := tls13CipherSuiteDic[c.cipher]
- if ok && cipher != "TLS_AES_128_CCM_8_SHA256" {
- c.enableXTLS = true
- }
- c.logger.Trace("XtlsFilterTls found tls 1.3! ", len(buffer), " ", c.cipher, " ", c.enableXTLS)
- c.numberOfPacketToFilter = 0
- return
- } else if c.remainingServerHello == 0 {
- c.logger.Trace("XtlsFilterTls found tls 1.2! ", len(buffer))
- c.numberOfPacketToFilter = 0
- return
- }
- }
- if c.numberOfPacketToFilter == 0 {
- c.logger.Trace("XtlsFilterTls stop filtering ", len(buffer))
- }
- }
- }
- func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer {
- contentLen := 0
- paddingLen := 0
- if buffer != nil {
- contentLen = buffer.Len()
- }
- if contentLen < 900 && c.isTLS {
- l, _ := rand.Int(rand.Reader, big.NewInt(500))
- paddingLen = int(l.Int64()) + 900 - contentLen
- } else {
- l, _ := rand.Int(rand.Reader, big.NewInt(256))
- paddingLen = int(l.Int64())
- }
- var bufferLen int
- if c.writeUUID {
- bufferLen += 16
- }
- bufferLen += 5
- if buffer != nil {
- bufferLen += buffer.Len()
- }
- bufferLen += paddingLen
- newBuffer := buf.NewSize(bufferLen)
- if c.writeUUID {
- common.Must1(newBuffer.Write(c.userUUID[:]))
- c.writeUUID = false
- }
- common.Must1(newBuffer.Write([]byte{command, byte(contentLen >> 8), byte(contentLen), byte(paddingLen >> 8), byte(paddingLen)}))
- if buffer != nil {
- common.Must1(newBuffer.Write(buffer.Bytes()))
- buffer.Release()
- }
- newBuffer.Extend(paddingLen)
- c.logger.Trace("XtlsPadding ", contentLen, " ", paddingLen, " ", command)
- return newBuffer
- }
- func (c *VisionConn) unPadding(buffer []byte) [][]byte {
- var bufferIndex int
- if c.remainingContent == -1 && c.remainingPadding == -1 {
- if len(buffer) >= 21 && bytes.Equal(c.userUUID[:], buffer[:16]) {
- bufferIndex = 16
- c.remainingContent = 0
- c.remainingPadding = 0
- c.currentCommand = 0
- }
- }
- if c.remainingContent == -1 && c.remainingPadding == -1 {
- return [][]byte{buffer}
- }
- var buffers [][]byte
- for bufferIndex < len(buffer) {
- if c.remainingContent <= 0 && c.remainingPadding <= 0 {
- if c.currentCommand == 1 {
- buffers = append(buffers, buffer[bufferIndex:])
- break
- } else {
- paddingInfo := buffer[bufferIndex : bufferIndex+5]
- c.currentCommand = int(paddingInfo[0])
- c.remainingContent = int(paddingInfo[1])<<8 | int(paddingInfo[2])
- c.remainingPadding = int(paddingInfo[3])<<8 | int(paddingInfo[4])
- bufferIndex += 5
- c.logger.Trace("Xtls Unpadding new block ", bufferIndex, " ", c.remainingContent, " padding ", c.remainingPadding, " ", c.currentCommand)
- }
- } else if c.remainingContent > 0 {
- end := c.remainingContent
- if end > len(buffer)-bufferIndex {
- end = len(buffer) - bufferIndex
- }
- buffers = append(buffers, buffer[bufferIndex:bufferIndex+end])
- c.remainingContent -= end
- bufferIndex += end
- } else {
- end := c.remainingPadding
- if end > len(buffer)-bufferIndex {
- end = len(buffer) - bufferIndex
- }
- c.remainingPadding -= end
- bufferIndex += end
- }
- if bufferIndex == len(buffer) {
- break
- }
- }
- return buffers
- }
- func (c *VisionConn) Upstream() any {
- return c.Conn
- }
|