瀏覽代碼

Vision padding upgrade (#1646)

* Vision server allow multiple blocks of padding

* Fix Vision client to support multiple possible padding blocks

* Vision padding upgrade

- Now we have two types of padding: long (pad to 900-1400) and traditional (0-256)
- Long padding is applied to tls handshakes and first (empty) packet
- Traditional padding is applied to all beginning (7) packets of the connection (counted two-way)
- Since receiver changed its way to unpad buffer in fd6973b3c67a6e5a982734a8c288b56845b69cb9, we can freely extend padding packet length easily in the future
- Simplify code

* Adjust receiver withinPaddingBuffers

Now default withinPaddingBuffers = true to give it a chance to do unpadding

* Fix magic numbers for Vision
Thanks @H1JK

Thanks @RPRX for guidance
yuhan6665 2 年之前
父節點
當前提交
2d898480be
共有 3 個文件被更改,包括 63 次插入44 次删除
  1. 54 32
      proxy/vless/encoding/encoding.go
  2. 4 6
      proxy/vless/inbound/inbound.go
  3. 5 6
      proxy/vless/outbound/outbound.go

+ 54 - 32
proxy/vless/encoding/encoding.go

@@ -36,6 +36,23 @@ var (
 	tlsClientHandShakeStart = []byte{0x16, 0x03}
 	tlsServerHandShakeStart = []byte{0x16, 0x03, 0x03}
 	tlsApplicationDataStart = []byte{0x17, 0x03, 0x03}
+
+	Tls13CipherSuiteDic = map[uint16]string{
+		0x1301: "TLS_AES_128_GCM_SHA256",
+		0x1302: "TLS_AES_256_GCM_SHA384",
+		0x1303: "TLS_CHACHA20_POLY1305_SHA256",
+		0x1304: "TLS_AES_128_CCM_SHA256",
+		0x1305: "TLS_AES_128_CCM_8_SHA256",
+	}
+)
+
+const (
+	tlsHandshakeTypeClientHello byte = 0x01
+	tlsHandshakeTypeServerHello byte = 0x02
+
+	CommandPaddingContinue byte = 0x00
+	CommandPaddingEnd      byte = 0x01
+	CommandPaddingDirect   byte = 0x02
 )
 
 var addrParser = protocol.NewAddressParser(
@@ -256,7 +273,7 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater
 ) error {
 	err := func() error {
 		var ct stats.Counter
-		filterUUID := true
+		withinPaddingBuffers := true
 		shouldSwitchToDirectCopy := false
 		var remainingContent int32 = -1
 		var remainingPadding int32 = -1
@@ -294,13 +311,15 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater
 			}
 			buffer, err := reader.ReadMultiBuffer()
 			if !buffer.IsEmpty() {
-				if filterUUID && (*isTLS || *numberOfPacketToFilter > 0) {
+				if withinPaddingBuffers || *numberOfPacketToFilter > 0 {
 					buffer = XtlsUnpadding(ctx, buffer, userUUID, &remainingContent, &remainingPadding, &currentCommand)
 					if remainingContent == 0 && remainingPadding == 0 {
 						if currentCommand == 1 {
-							filterUUID = false
+							withinPaddingBuffers = false
+							remainingContent = -1
+							remainingPadding = -1 // set to initial state to parse the next padding
 						} else if currentCommand == 2 {
-							filterUUID = false
+							withinPaddingBuffers = false
 							shouldSwitchToDirectCopy = true
 							// XTLS Vision processes struct TLS Conn's input and rawInput
 							if inputBuffer, err := buf.ReadFrom(input); err == nil {
@@ -313,9 +332,15 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater
 									buffer, _ = buf.MergeMulti(buffer, rawInputBuffer)
 								}
 							}
-						} else if currentCommand != 0 {
+						} else if currentCommand == 0 {
+							withinPaddingBuffers = true
+						} else {
 							newError("XtlsRead unknown command ", currentCommand, buffer.Len()).WriteToLog(session.ExportIDToError(ctx))
 						}
+					} else if remainingContent > 0 || remainingPadding > 0 {
+						withinPaddingBuffers = true
+					} else {
+						withinPaddingBuffers = false
 					}
 				}
 				if *numberOfPacketToFilter > 0 {
@@ -342,12 +367,12 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater
 
 // XtlsWrite filter and write xtls protocol
 func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, counter stats.Counter,
-	ctx context.Context, userUUID *[]byte, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool,
+	ctx context.Context, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool,
 	cipher *uint16, remainingServerHello *int32,
 ) error {
 	err := func() error {
 		var ct stats.Counter
-		filterTlsApplicationData := true
+		isPadding := true
 		shouldSwitchToDirectCopy := false
 		for {
 			buffer, err := reader.ReadMultiBuffer()
@@ -355,27 +380,26 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate
 				if *numberOfPacketToFilter > 0 {
 					XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, cipher, remainingServerHello, ctx)
 				}
-				if filterTlsApplicationData && *isTLS {
+				if isPadding {
 					buffer = ReshapeMultiBuffer(ctx, buffer)
 					var xtlsSpecIndex int
 					for i, b := range buffer {
-						if b.Len() >= 6 && bytes.Equal(tlsApplicationDataStart, b.BytesTo(3)) {
-							var command byte = 0x01
+						if *isTLS && b.Len() >= 6 && bytes.Equal(tlsApplicationDataStart, b.BytesTo(3)) {
+							var command byte = CommandPaddingEnd
 							if *enableXtls {
 								shouldSwitchToDirectCopy = true
 								xtlsSpecIndex = i
-								command = 0x02
+								command = CommandPaddingDirect
 							}
-							filterTlsApplicationData = false
-							buffer[i] = XtlsPadding(b, command, userUUID, ctx)
+							isPadding = false
+							buffer[i] = XtlsPadding(b, command, nil, *isTLS, ctx)
 							break
-						} else if !*isTLS12orAbove && *numberOfPacketToFilter <= 0 {
-							// maybe tls 1.1 or 1.0
-							filterTlsApplicationData = false
-							buffer[i] = XtlsPadding(b, 0x01, userUUID, ctx)
+						} else if !*isTLS12orAbove && *numberOfPacketToFilter <= 1 { // For compatibility with earlier vision receiver, we finish padding 1 packet early
+							isPadding = false
+							buffer[i] = XtlsPadding(b, CommandPaddingEnd, nil, *isTLS, ctx)
 							break
 						}
-						buffer[i] = XtlsPadding(b, 0x00, userUUID, ctx)
+						buffer[i] = XtlsPadding(b, CommandPaddingContinue, nil, *isTLS, ctx)
 					}
 					if shouldSwitchToDirectCopy {
 						encryptBuffer, directBuffer := buf.SplitMulti(buffer, xtlsSpecIndex+1)
@@ -422,7 +446,7 @@ func XtlsFilterTls(buffer buf.MultiBuffer, numberOfPacketToFilter *int, enableXt
 		*numberOfPacketToFilter--
 		if b.Len() >= 6 {
 			startsBytes := b.BytesTo(6)
-			if bytes.Equal(tlsServerHandShakeStart, startsBytes[:3]) && startsBytes[5] == 0x02 {
+			if bytes.Equal(tlsServerHandShakeStart, startsBytes[:3]) && startsBytes[5] == tlsHandshakeTypeServerHello {
 				*remainingServerHello = (int32(startsBytes[3])<<8 | int32(startsBytes[4])) + 5
 				*isTLS12orAbove = true
 				*isTLS = true
@@ -433,7 +457,7 @@ func XtlsFilterTls(buffer buf.MultiBuffer, numberOfPacketToFilter *int, enableXt
 				} else {
 					newError("XtlsFilterTls short server hello, tls 1.2 or older? ", b.Len(), " ", *remainingServerHello).WriteToLog(session.ExportIDToError(ctx))
 				}
-			} else if bytes.Equal(tlsClientHandShakeStart, startsBytes[:2]) && startsBytes[5] == 0x01 {
+			} else if bytes.Equal(tlsClientHandShakeStart, startsBytes[:2]) && startsBytes[5] == tlsHandshakeTypeClientHello {
 				*isTLS = true
 				newError("XtlsFilterTls found tls client hello! ", buffer.Len()).WriteToLog(session.ExportIDToError(ctx))
 			}
@@ -483,7 +507,7 @@ func ReshapeMultiBuffer(ctx context.Context, buffer buf.MultiBuffer) buf.MultiBu
 	for i, buffer1 := range buffer {
 		if buffer1.Len() >= buf.Size-21 {
 			index := int32(bytes.LastIndex(buffer1.Bytes(), tlsApplicationDataStart))
-			if index <= 0 {
+			if index <= 0 || index > buf.Size-21 {
 				index = buf.Size / 2
 			}
 			buffer2 := buf.New()
@@ -503,23 +527,28 @@ func ReshapeMultiBuffer(ctx context.Context, buffer buf.MultiBuffer) buf.MultiBu
 }
 
 // XtlsPadding add padding to eliminate length siganature during tls handshake
-func XtlsPadding(b *buf.Buffer, command byte, userUUID *[]byte, ctx context.Context) *buf.Buffer {
+func XtlsPadding(b *buf.Buffer, command byte, userUUID *[]byte, longPadding bool, ctx context.Context) *buf.Buffer {
 	var contantLen int32 = 0
 	var paddingLen int32 = 0
 	if b != nil {
 		contantLen = b.Len()
 	}
-	if contantLen < 900 {
+	if contantLen < 900 && longPadding {
 		l, err := rand.Int(rand.Reader, big.NewInt(500))
 		if err != nil {
 			newError("failed to generate padding").Base(err).WriteToLog(session.ExportIDToError(ctx))
 		}
 		paddingLen = int32(l.Int64()) + 900 - contantLen
+	} else {
+		l, err := rand.Int(rand.Reader, big.NewInt(256))
+		if err != nil {
+			newError("failed to generate padding").Base(err).WriteToLog(session.ExportIDToError(ctx))
+		}
+		paddingLen = int32(l.Int64())
 	}
 	newbuffer := buf.New()
 	if userUUID != nil {
 		newbuffer.Write(*userUUID)
-		*userUUID = nil
 	}
 	newbuffer.Write([]byte{command, byte(contantLen >> 8), byte(contantLen), byte(paddingLen >> 8), byte(paddingLen)})
 	if b != nil {
@@ -543,6 +572,7 @@ func XtlsUnpadding(ctx context.Context, buffer buf.MultiBuffer, userUUID []byte,
 				posByte = 16
 				*remainingContent = 0
 				*remainingPadding = 0
+				*currentCommand = 0
 				break
 			}
 		}
@@ -601,11 +631,3 @@ func XtlsUnpadding(ctx context.Context, buffer buf.MultiBuffer, userUUID []byte,
 	buf.ReleaseMulti(buffer)
 	return mb2
 }
-
-var Tls13CipherSuiteDic = map[uint16]string{
-	0x1301: "TLS_AES_128_GCM_SHA256",
-	0x1302: "TLS_AES_256_GCM_SHA384",
-	0x1303: "TLS_CHACHA20_POLY1305_SHA256",
-	0x1304: "TLS_AES_128_CCM_SHA256",
-	0x1305: "TLS_AES_128_CCM_8_SHA256",
-}

+ 4 - 6
proxy/vless/inbound/inbound.go

@@ -624,11 +624,9 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
 		}
 		if requestAddons.Flow == vless.XRV {
 			encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello, ctx)
-			if isTLS {
-				multiBuffer = encoding.ReshapeMultiBuffer(ctx, multiBuffer)
-				for i, b := range multiBuffer {
-					multiBuffer[i] = encoding.XtlsPadding(b, 0x00, &userUUID, ctx)
-				}
+			multiBuffer = encoding.ReshapeMultiBuffer(ctx, multiBuffer)
+			for i, b := range multiBuffer {
+				multiBuffer[i] = encoding.XtlsPadding(b, encoding.CommandPaddingContinue, &userUUID, isTLS, ctx)
 			}
 		}
 		if err := clientWriter.WriteMultiBuffer(multiBuffer); err != nil {
@@ -645,7 +643,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
 			if statConn != nil {
 				counter = statConn.WriteCounter
 			}
-			err = encoding.XtlsWrite(serverReader, clientWriter, timer, netConn, counter, ctx, &userUUID, &numberOfPacketToFilter,
+			err = encoding.XtlsWrite(serverReader, clientWriter, timer, netConn, counter, ctx, &numberOfPacketToFilter,
 				&enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
 		} else {
 			// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer

+ 5 - 6
proxy/vless/outbound/outbound.go

@@ -243,10 +243,9 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 			if err1 == nil {
 				if requestAddons.Flow == vless.XRV {
 					encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello, ctx)
-					if isTLS {
-						for i, b := range multiBuffer {
-							multiBuffer[i] = encoding.XtlsPadding(b, 0x00, &userUUID, ctx)
-						}
+					multiBuffer = encoding.ReshapeMultiBuffer(ctx, multiBuffer)
+					for i, b := range multiBuffer {
+						multiBuffer[i] = encoding.XtlsPadding(b, encoding.CommandPaddingContinue, &userUUID, isTLS, ctx)
 					}
 				}
 				if err := serverWriter.WriteMultiBuffer(multiBuffer); err != nil {
@@ -256,7 +255,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 				return err1
 			} else if requestAddons.Flow == vless.XRV {
 				mb := make(buf.MultiBuffer, 1)
-				mb[0] = encoding.XtlsPadding(nil, 0x01, &userUUID, ctx) // it must not be tls so padding finish with it (command 1)
+				mb[0] = encoding.XtlsPadding(nil, encoding.CommandPaddingContinue, &userUUID, true, ctx) // we do a long padding to hide vless header
 				newError("Insert padding with empty content to camouflage VLESS header ", mb.Len()).WriteToLog(session.ExportIDToError(ctx))
 				if err := serverWriter.WriteMultiBuffer(mb); err != nil {
 					return err
@@ -285,7 +284,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 			if statConn != nil {
 				counter = statConn.WriteCounter
 			}
-			err = encoding.XtlsWrite(clientReader, serverWriter, timer, netConn, counter, ctx, &userUUID, &numberOfPacketToFilter,
+			err = encoding.XtlsWrite(clientReader, serverWriter, timer, netConn, counter, ctx, &numberOfPacketToFilter,
 				&enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
 		} else {
 			// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer