Przeglądaj źródła

Update vision protocol

世界 2 lat temu
rodzic
commit
e4bff0460d

+ 1 - 1
inbound/vless.go

@@ -50,7 +50,7 @@ func NewVLESS(ctx context.Context, router adapter.Router, logger log.ContextLogg
 		ctx:   ctx,
 		users: options.Users,
 	}
-	service := vless.NewService[int](adapter.NewUpstreamContextHandler(inbound.newConnection, inbound.newPacketConnection, inbound))
+	service := vless.NewService[int](logger, adapter.NewUpstreamContextHandler(inbound.newConnection, inbound.newPacketConnection, inbound))
 	service.UpdateUsers(common.MapIndexed(inbound.users, func(index int, _ option.VLESSUser) int {
 		return index
 	}), common.Map(inbound.users, func(it option.VLESSUser) string {

+ 1 - 1
outbound/vless.go

@@ -67,7 +67,7 @@ func NewVLESS(ctx context.Context, router adapter.Router, logger log.ContextLogg
 	default:
 		return nil, E.New("unknown packet encoding: ", options.PacketEncoding)
 	}
-	outbound.client, err = vless.NewClient(options.UUID, options.Flow)
+	outbound.client, err = vless.NewClient(options.UUID, options.Flow, logger)
 	if err != nil {
 		return nil, err
 	}

+ 7 - 5
transport/vless/client.go

@@ -9,6 +9,7 @@ import (
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/buf"
 	E "github.com/sagernet/sing/common/exceptions"
+	"github.com/sagernet/sing/common/logger"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 
@@ -16,11 +17,12 @@ import (
 )
 
 type Client struct {
-	key  [16]byte
-	flow string
+	key    [16]byte
+	flow   string
+	logger logger.Logger
 }
 
-func NewClient(userId string, flow string) (*Client, error) {
+func NewClient(userId string, flow string, logger logger.Logger) (*Client, error) {
 	user := uuid.FromStringOrNil(userId)
 	if user == uuid.Nil {
 		user = uuid.NewV5(user, userId)
@@ -30,12 +32,12 @@ func NewClient(userId string, flow string) (*Client, error) {
 	default:
 		return nil, E.New("unsupported flow: " + flow)
 	}
-	return &Client{user, flow}, nil
+	return &Client{user, flow, logger}, nil
 }
 
 func (c *Client) prepareConn(conn net.Conn) (net.Conn, error) {
 	if c.flow == FlowVision {
-		vConn, err := NewVisionConn(conn, c.key)
+		vConn, err := NewVisionConn(conn, c.key, c.logger)
 		if err != nil {
 			return nil, E.Cause(err, "initialize vision")
 		}

+ 4 - 0
transport/vless/constant.go

@@ -11,6 +11,10 @@ var (
 	tlsClientHandShakeStart = []byte{0x16, 0x03}
 	tlsServerHandShakeStart = []byte{0x16, 0x03, 0x03}
 	tlsApplicationDataStart = []byte{0x17, 0x03, 0x03}
+
+	commandPaddingContinue byte = 0
+	commandPaddingEnd      byte = 1
+	commandPaddingDirect   byte = 2
 )
 
 var tls13CipherSuiteDic = map[uint16]string{

+ 5 - 2
transport/vless/service.go

@@ -11,6 +11,7 @@ import (
 	"github.com/sagernet/sing/common/buf"
 	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"
+	"github.com/sagernet/sing/common/logger"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 
@@ -19,6 +20,7 @@ import (
 
 type Service[T any] struct {
 	userMap map[[16]byte]T
+	logger  logger.Logger
 	handler Handler
 }
 
@@ -28,8 +30,9 @@ type Handler interface {
 	E.Handler
 }
 
-func NewService[T any](handler Handler) *Service[T] {
+func NewService[T any](logger logger.Logger, handler Handler) *Service[T] {
 	return &Service[T]{
+		logger:  logger,
 		handler: handler,
 	}
 }
@@ -64,7 +67,7 @@ func (s *Service[T]) NewConnection(ctx context.Context, conn net.Conn, metadata
 	switch request.Flow {
 	case "":
 	case FlowVision:
-		protocolConn, err = NewVisionConn(conn, request.UUID)
+		protocolConn, err = NewVisionConn(conn, request.UUID, s.logger)
 		if err != nil {
 			return E.Cause(err, "initialize vision")
 		}

+ 74 - 44
transport/vless/vision.go

@@ -16,6 +16,7 @@ import (
 	"github.com/sagernet/sing/common/buf"
 	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"
+	"github.com/sagernet/sing/common/logger"
 	N "github.com/sagernet/sing/common/network"
 )
 
@@ -37,26 +38,27 @@ type VisionConn struct {
 	input    *bytes.Reader
 	rawInput *bytes.Buffer
 	netConn  net.Conn
+	logger   logger.Logger
 
-	userUUID                 [16]byte
-	isTLS                    bool
-	numberOfPacketToFilter   int
-	isTLS12orAbove           bool
-	remainingServerHello     int32
-	cipher                   uint16
-	enableXTLS               bool
-	filterTlsApplicationData bool
-	directWrite              bool
-	writeUUID                bool
-	filterUUID               bool
-	remainingContent         int
-	remainingPadding         int
-	currentCommand           int
-	directRead               bool
-	remainingReader          io.Reader
+	userUUID               [16]byte
+	isTLS                  bool
+	numberOfPacketToFilter int
+	isTLS12orAbove         bool
+	remainingServerHello   int32
+	cipher                 uint16
+	enableXTLS             bool
+	isPadding              bool
+	directWrite            bool
+	writeUUID              bool
+	withinPaddingBuffers   bool
+	remainingContent       int
+	remainingPadding       int
+	currentCommand         int
+	directRead             bool
+	remainingReader        io.Reader
 }
 
-func NewVisionConn(conn net.Conn, userUUID [16]byte) (*VisionConn, error) {
+func NewVisionConn(conn net.Conn, userUUID [16]byte, logger logger.Logger) (*VisionConn, error) {
 	var (
 		loaded         bool
 		reflectType    reflect.Type
@@ -75,19 +77,21 @@ func NewVisionConn(conn net.Conn, userUUID [16]byte) (*VisionConn, error) {
 	input, _ := reflectType.FieldByName("input")
 	rawInput, _ := reflectType.FieldByName("rawInput")
 	return &VisionConn{
-		Conn:                     conn,
-		writer:                   bufio.NewVectorisedWriter(conn),
-		input:                    (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)),
-		rawInput:                 (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)),
-		netConn:                  netConn,
-		userUUID:                 userUUID,
-		numberOfPacketToFilter:   8,
-		remainingServerHello:     -1,
-		filterTlsApplicationData: true,
-		writeUUID:                true,
-		filterUUID:               true,
-		remainingContent:         -1,
-		remainingPadding:         -1,
+		Conn:     conn,
+		writer:   bufio.NewVectorisedWriter(conn),
+		input:    (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)),
+		rawInput: (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)),
+		netConn:  netConn,
+		logger:   logger,
+
+		userUUID:               userUUID,
+		numberOfPacketToFilter: 8,
+		remainingServerHello:   -1,
+		isPadding:              true,
+		writeUUID:              true,
+		withinPaddingBuffers:   true,
+		remainingContent:       -1,
+		remainingPadding:       -1,
 	}, nil
 }
 
@@ -97,6 +101,7 @@ func (c *VisionConn) Read(p []byte) (n int, err error) {
 		if err == io.EOF {
 			c.remainingReader = nil
 			if n > 0 {
+				err = nil
 				return
 			}
 		}
@@ -109,13 +114,15 @@ func (c *VisionConn) Read(p []byte) (n int, err error) {
 		return
 	}
 	buffer := p[:n]
-	if c.filterUUID && (c.isTLS || c.numberOfPacketToFilter > 0) {
+	if c.withinPaddingBuffers || c.numberOfPacketToFilter > 0 {
 		buffers := c.unPadding(buffer)
 		if c.remainingContent == 0 && c.remainingPadding == 0 {
 			if c.currentCommand == 1 {
-				c.filterUUID = false
+				c.withinPaddingBuffers = false
+				c.remainingContent = -1
+				c.remainingPadding = -1
 			} else if c.currentCommand == 2 {
-				c.filterUUID = false
+				c.withinPaddingBuffers = false
 				c.directRead = true
 
 				inputBuffer, err := io.ReadAll(c.input)
@@ -130,9 +137,17 @@ func (c *VisionConn) Read(p []byte) (n int, err error) {
 				}
 
 				buffers = append(buffers, rawInputBuffer)
-			} else if c.currentCommand != 0 {
+
+				c.logger.Trace("XtlsRead readV")
+			} else if c.currentCommand == 0 {
+				c.withinPaddingBuffers = true
+			} else {
 				return 0, E.New("unknown command ", c.currentCommand)
 			}
+		} else if c.remainingContent > 0 || c.remainingPadding > 0 {
+			c.withinPaddingBuffers = true
+		} else {
+			c.withinPaddingBuffers = false
 		}
 		if c.numberOfPacketToFilter > 0 {
 			c.filterTLS(buffers)
@@ -151,27 +166,27 @@ func (c *VisionConn) Write(p []byte) (n int, err error) {
 	if c.numberOfPacketToFilter > 0 {
 		c.filterTLS([][]byte{p})
 	}
-	if c.isTLS && c.filterTlsApplicationData {
+	if c.isPadding {
 		inputLen := len(p)
 		buffers := reshapeBuffer(p)
 		var specIndex int
 		for i, buffer := range buffers {
-			if buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) {
-				var command byte = 1
+			if c.isTLS && buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) {
+				var command byte = commandPaddingEnd
 				if c.enableXTLS {
 					c.directWrite = true
 					specIndex = i
-					command = 2
+					command = commandPaddingDirect
 				}
-				c.filterTlsApplicationData = false
+				c.isPadding = false
 				buffers[i] = c.padding(buffer, command)
 				break
-			} else if !c.isTLS12orAbove && c.numberOfPacketToFilter == 0 {
-				c.filterTlsApplicationData = false
-				buffers[i] = c.padding(buffer, 0x01)
+			} else if !c.isTLS12orAbove && c.numberOfPacketToFilter <= 1 {
+				c.isPadding = false
+				buffers[i] = c.padding(buffer, commandPaddingEnd)
 				break
 			}
-			buffers[i] = c.padding(buffer, 0x00)
+			buffers[i] = c.padding(buffer, commandPaddingContinue)
 		}
 		if c.directWrite {
 			encryptedBuffer := buffers[:specIndex+1]
@@ -181,6 +196,7 @@ func (c *VisionConn) Write(p []byte) (n int, err error) {
 			}
 			buffers = buffers[specIndex+1:]
 			c.writer = bufio.NewVectorisedWriter(c.netConn)
+			c.logger.Trace("XtlsWrite writeV ", specIndex, " ", buf.LenMulti(encryptedBuffer), " ", len(buffers))
 			time.Sleep(5 * time.Millisecond) // wtf
 		}
 		err = c.writer.WriteVectorised(buffers)
@@ -209,10 +225,13 @@ func (c *VisionConn) filterTLS(buffers [][]byte) {
 						sessionIdLen := int32(buffer[43])
 						cipherSuite := buffer[43+sessionIdLen+1 : 43+sessionIdLen+3]
 						c.cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1])
+					} else {
+						c.logger.Trace("XtlsFilterTls short server hello, tls 1.2 or older? ", len(buffer), " ", c.remainingServerHello)
 					}
 				}
 			} else if bytes.Equal(tlsClientHandShakeStart, buffer[:2]) && buffer[5] == 1 {
 				c.isTLS = true
+				c.logger.Trace("XtlsFilterTls found tls client hello! ", len(buffer))
 			}
 		}
 		if c.remainingServerHello > 0 {
@@ -226,13 +245,18 @@ func (c *VisionConn) filterTLS(buffers [][]byte) {
 				if ok && cipher != "TLS_AES_128_CCM_8_SHA256" {
 					c.enableXTLS = true
 				}
+				c.logger.Trace("XtlsFilterTls found tls 1.3! ", len(buffer), " ", c.cipher, " ", c.enableXTLS)
 				c.numberOfPacketToFilter = 0
 				return
 			} else if c.remainingServerHello == 0 {
+				c.logger.Trace("XtlsFilterTls found tls 1.2! ", len(buffer))
 				c.numberOfPacketToFilter = 0
 				return
 			}
 		}
+		if c.numberOfPacketToFilter == 0 {
+			c.logger.Trace("XtlsFilterTls stop filtering ", len(buffer))
+		}
 	}
 }
 
@@ -242,9 +266,12 @@ func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer {
 	if buffer != nil {
 		contentLen = buffer.Len()
 	}
-	if contentLen < 900 {
+	if contentLen < 900 && c.isTLS {
 		l, _ := rand.Int(rand.Reader, big.NewInt(500))
 		paddingLen = int(l.Int64()) + 900 - contentLen
+	} else {
+		l, _ := rand.Int(rand.Reader, big.NewInt(256))
+		paddingLen = int(l.Int64())
 	}
 	newBuffer := buf.New()
 	if c.writeUUID {
@@ -257,6 +284,7 @@ func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer {
 		buffer.Release()
 	}
 	newBuffer.Extend(paddingLen)
+	c.logger.Trace("XtlsPadding ", contentLen, " ", paddingLen, " ", command)
 	return newBuffer
 }
 
@@ -267,6 +295,7 @@ func (c *VisionConn) unPadding(buffer []byte) [][]byte {
 			bufferIndex = 16
 			c.remainingContent = 0
 			c.remainingPadding = 0
+			c.currentCommand = 0
 		}
 	}
 	if c.remainingContent == -1 && c.remainingPadding == -1 {
@@ -284,6 +313,7 @@ func (c *VisionConn) unPadding(buffer []byte) [][]byte {
 				c.remainingContent = int(paddingInfo[1])<<8 | int(paddingInfo[2])
 				c.remainingPadding = int(paddingInfo[3])<<8 | int(paddingInfo[4])
 				bufferIndex += 5
+				c.logger.Trace("Xtls Unpadding new block ", bufferIndex, " ", c.remainingContent, " padding ", c.remainingPadding, " ", c.currentCommand)
 			}
 		} else if c.remainingContent > 0 {
 			end := c.remainingContent