Browse Source

QUIC sniffer: Full support for handling multiple initial packets (#4642)

Co-authored-by: RPRX <[email protected]>
Co-authored-by: Vigilans <[email protected]>
Co-authored-by: Shelikhoo <[email protected]>
Co-authored-by: dyhkwong <[email protected]>
j2rong4cn 6 months ago
parent
commit
58c48664e2

+ 24 - 19
app/dispatcher/default.go

@@ -33,23 +33,21 @@ type cachedReader struct {
 	cache  buf.MultiBuffer
 }
 
-func (r *cachedReader) Cache(b *buf.Buffer) {
-	mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100)
+func (r *cachedReader) Cache(b *buf.Buffer, deadline time.Duration) error {
+	mb, err := r.reader.ReadMultiBufferTimeout(deadline)
+	if err != nil {
+		return err
+	}
 	r.Lock()
 	if !mb.IsEmpty() {
 		r.cache, _ = buf.MergeMulti(r.cache, mb)
 	}
-	cacheLen := r.cache.Len()
-	if cacheLen <= b.Cap() {
-		b.Clear()
-	} else {
-		b.Release()
-		*b = *buf.NewWithSize(cacheLen)
-	}
-	rawBytes := b.Extend(cacheLen)
+	b.Clear()
+	rawBytes := b.Extend(b.Cap())
 	n := r.cache.Copy(rawBytes)
 	b.Resize(0, int32(n))
 	r.Unlock()
+	return nil
 }
 
 func (r *cachedReader) readInternal() buf.MultiBuffer {
@@ -355,7 +353,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
 }
 
 func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) {
-	payload := buf.New()
+	payload := buf.NewWithSize(32767)
 	defer payload.Release()
 
 	sniffer := NewSniffer(ctx)
@@ -367,26 +365,33 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, netw
 	}
 
 	contentResult, contentErr := func() (SniffResult, error) {
+		cacheDeadline := 200 * time.Millisecond
 		totalAttempt := 0
 		for {
 			select {
 			case <-ctx.Done():
 				return nil, ctx.Err()
 			default:
-				totalAttempt++
-				if totalAttempt > 2 {
-					return nil, errSniffingTimeout
-				}
+				cachingStartingTimeStamp := time.Now()
+				cacheErr := cReader.Cache(payload, cacheDeadline)
+				cachingTimeElapsed := time.Since(cachingStartingTimeStamp)
+				cacheDeadline -= cachingTimeElapsed
 
-				cReader.Cache(payload)
 				if !payload.IsEmpty() {
 					result, err := sniffer.Sniff(ctx, payload.Bytes(), network)
-					if err != common.ErrNoClue {
+					switch err {
+					case common.ErrNoClue: // No Clue: protocol not matches, and sniffer cannot determine whether there will be a match or not
+						totalAttempt++
+					case protocol.ErrProtoNeedMoreData: // Protocol Need More Data: protocol matches, but need more data to complete sniffing
+						if cacheErr != nil { // Cache error (e.g. timeout) counts for failed attempt
+							totalAttempt++
+						}
+					default:
 						return result, err
 					}
 				}
-				if payload.IsFull() {
-					return nil, errUnknownContent
+				if totalAttempt >= 2 || cacheDeadline <= 0 {
+					return nil, errSniffingTimeout
 				}
 			}
 		}

+ 6 - 2
app/dispatcher/sniffer.go

@@ -6,6 +6,7 @@ import (
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/protocol/bittorrent"
 	"github.com/xtls/xray-core/common/protocol/http"
 	"github.com/xtls/xray-core/common/protocol/quic"
@@ -58,14 +59,17 @@ var errUnknownContent = errors.New("unknown content")
 func (s *Sniffer) Sniff(c context.Context, payload []byte, network net.Network) (SniffResult, error) {
 	var pendingSniffer []protocolSnifferWithMetadata
 	for _, si := range s.sniffer {
-		s := si.protocolSniffer
+		protocolSniffer := si.protocolSniffer
 		if si.metadataSniffer || si.network != network {
 			continue
 		}
-		result, err := s(c, payload)
+		result, err := protocolSniffer(c, payload)
 		if err == common.ErrNoClue {
 			pendingSniffer = append(pendingSniffer, si)
 			continue
+		} else if err == protocol.ErrProtoNeedMoreData { // Sniffer protocol matched, but need more data to complete sniffing
+			s.sniffer = []protocolSnifferWithMetadata{si}
+			return nil, err
 		}
 
 		if err == nil && result != nil {

+ 31 - 16
common/buf/buffer.go

@@ -15,6 +15,15 @@ const (
 
 var pool = bytespool.GetPool(Size)
 
+// ownership represents the data owner of the buffer.
+type ownership uint8
+
+const (
+	managed ownership = iota
+	unmanaged
+	bytespools
+)
+
 // Buffer is a recyclable allocation of a byte array. Buffer.Release() recycles
 // the buffer into an internal buffer pool, in order to recreate a buffer more
 // quickly.
@@ -22,11 +31,11 @@ type Buffer struct {
 	v         []byte
 	start     int32
 	end       int32
-	unmanaged bool
+	ownership ownership
 	UDP       *net.Destination
 }
 
-// New creates a Buffer with 0 length and 8K capacity.
+// New creates a Buffer with 0 length and 8K capacity, managed.
 func New() *Buffer {
 	buf := pool.Get().([]byte)
 	if cap(buf) >= Size {
@@ -40,7 +49,7 @@ func New() *Buffer {
 	}
 }
 
-// NewExisted creates a managed, standard size Buffer with an existed bytearray
+// NewExisted creates a standard size Buffer with an existed bytearray, managed.
 func NewExisted(b []byte) *Buffer {
 	if cap(b) < Size {
 		panic("Invalid buffer")
@@ -57,16 +66,16 @@ func NewExisted(b []byte) *Buffer {
 	}
 }
 
-// FromBytes creates a Buffer with an existed bytearray
+// FromBytes creates a Buffer with an existed bytearray, unmanaged.
 func FromBytes(b []byte) *Buffer {
 	return &Buffer{
 		v:         b,
 		end:       int32(len(b)),
-		unmanaged: true,
+		ownership: unmanaged,
 	}
 }
 
-// StackNew creates a new Buffer object on stack.
+// StackNew creates a new Buffer object on stack, managed.
 // This method is for buffers that is released in the same function.
 func StackNew() Buffer {
 	buf := pool.Get().([]byte)
@@ -81,9 +90,17 @@ func StackNew() Buffer {
 	}
 }
 
+// NewWithSize creates a Buffer with 0 length and capacity with at least the given size, bytespool's.
+func NewWithSize(size int32) *Buffer {
+	return &Buffer{
+		v:         bytespool.Alloc(size),
+		ownership: bytespools,
+	}
+}
+
 // Release recycles the buffer into an internal buffer pool.
 func (b *Buffer) Release() {
-	if b == nil || b.v == nil || b.unmanaged {
+	if b == nil || b.v == nil || b.ownership == unmanaged {
 		return
 	}
 
@@ -91,8 +108,13 @@ func (b *Buffer) Release() {
 	b.v = nil
 	b.Clear()
 
-	if cap(p) == Size {
-		pool.Put(p)
+	switch b.ownership {
+	case managed:
+		if cap(p) == Size {
+			pool.Put(p)
+		}
+	case bytespools:
+		bytespool.Free(p)
 	}
 	b.UDP = nil
 }
@@ -215,13 +237,6 @@ func (b *Buffer) Cap() int32 {
 	return int32(len(b.v))
 }
 
-// NewWithSize creates a Buffer with 0 length and capacity with at least the given size.
-func NewWithSize(size int32) *Buffer {
-	return &Buffer{
-		v: bytespool.Alloc(size),
-	}
-}
-
 // IsEmpty returns true if the buffer is empty.
 func (b *Buffer) IsEmpty() bool {
 	return b.Len() == 0

+ 6 - 0
common/protocol/protocol.go

@@ -1 +1,7 @@
 package protocol // import "github.com/xtls/xray-core/common/protocol"
+
+import (
+	"errors"
+)
+
+var ErrProtoNeedMoreData = errors.New("protocol matches, but need more data to complete sniffing")

+ 42 - 52
common/protocol/quic/sniff.go

@@ -1,7 +1,6 @@
 package quic
 
 import (
-	"context"
 	"crypto"
 	"crypto/aes"
 	"crypto/tls"
@@ -13,6 +12,7 @@ import (
 	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/bytespool"
 	"github.com/xtls/xray-core/common/errors"
+	"github.com/xtls/xray-core/common/protocol"
 	ptls "github.com/xtls/xray-core/common/protocol/tls"
 	"golang.org/x/crypto/hkdf"
 )
@@ -47,22 +47,17 @@ var (
 	errNotQuicInitial = errors.New("not initial packet")
 )
 
-func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
-	// In extremely rare cases, this sniffer may cause slice error
-	// and we set recover() here to prevent crash.
-	// TODO: Thoroughly fix this panic
-	defer func() {
-		if r := recover(); r != nil {
-			errors.LogError(context.Background(), "Failed to sniff QUIC: ", r)
-			resultReturn = nil
-			errorReturn = common.ErrNoClue
-		}
-	}()
+func SniffQUIC(b []byte) (*SniffHeader, error) {
+	if len(b) == 0 {
+		return nil, common.ErrNoClue
+	}
 
 	// Crypto data separated across packets
 	cryptoLen := 0
-	cryptoData := bytespool.Alloc(int32(len(b)))
+	cryptoData := bytespool.Alloc(32767)
 	defer bytespool.Free(cryptoData)
+	cache := buf.New()
+	defer cache.Release()
 
 	// Parse QUIC packets
 	for len(b) > 0 {
@@ -105,13 +100,15 @@ func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
 			return nil, errNotQuic
 		}
 
-		tokenLen, err := quicvarint.Read(buffer)
-		if err != nil || tokenLen > uint64(len(b)) {
-			return nil, errNotQuic
-		}
+		if isQuicInitial { // Only initial packets have token, see https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.2
+			tokenLen, err := quicvarint.Read(buffer)
+			if err != nil || tokenLen > uint64(len(b)) {
+				return nil, errNotQuic
+			}
 
-		if _, err = buffer.ReadBytes(int32(tokenLen)); err != nil {
-			return nil, errNotQuic
+			if _, err = buffer.ReadBytes(int32(tokenLen)); err != nil {
+				return nil, errNotQuic
+			}
 		}
 
 		packetLen, err := quicvarint.Read(buffer)
@@ -130,9 +127,6 @@ func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
 			continue
 		}
 
-		origPNBytes := make([]byte, 4)
-		copy(origPNBytes, b[hdrLen:hdrLen+4])
-
 		var salt []byte
 		if versionNumber == version1 {
 			salt = quicSalt
@@ -147,44 +141,34 @@ func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
 			return nil, err
 		}
 
-		cache := buf.New()
-		defer cache.Release()
-
+		cache.Clear()
 		mask := cache.Extend(int32(block.BlockSize()))
 		block.Encrypt(mask, b[hdrLen+4:hdrLen+4+16])
 		b[0] ^= mask[0] & 0xf
-		for i := range b[hdrLen : hdrLen+4] {
+		packetNumberLength := int(b[0]&0x3 + 1)
+		for i := range packetNumberLength {
 			b[hdrLen+i] ^= mask[i+1]
 		}
-		packetNumberLength := b[0]&0x3 + 1
-		if packetNumberLength != 1 {
-			return nil, errNotQuicInitial
-		}
-		var packetNumber uint32
-		{
-			n, err := buffer.ReadByte()
-			if err != nil {
-				return nil, err
-			}
-			packetNumber = uint32(n)
-		}
-
-		extHdrLen := hdrLen + int(packetNumberLength)
-		copy(b[extHdrLen:hdrLen+4], origPNBytes[packetNumberLength:])
-		data := b[extHdrLen : int(packetLen)+hdrLen]
 
 		key := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16)
 		iv := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic iv", 12)
 		cipher := AEADAESGCMTLS13(key, iv)
+
 		nonce := cache.Extend(int32(cipher.NonceSize()))
-		binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber))
+		_, err = buffer.Read(nonce[len(nonce)-packetNumberLength:])
+		if err != nil {
+			return nil, err
+		}
+
+		extHdrLen := hdrLen + packetNumberLength
+		data := b[extHdrLen : int(packetLen)+hdrLen]
 		decrypted, err := cipher.Open(b[extHdrLen:extHdrLen], nonce, data, b[:extHdrLen])
 		if err != nil {
 			return nil, err
 		}
 		buffer = buf.FromBytes(decrypted)
-		for i := 0; !buffer.IsEmpty(); i++ {
-			frameType := byte(0x0) // Default to PADDING frame
+		for !buffer.IsEmpty() {
+			frameType, _ := buffer.ReadByte()
 			for frameType == 0x0 && !buffer.IsEmpty() {
 				frameType, _ = buffer.ReadByte()
 			}
@@ -234,13 +218,12 @@ func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
 					return nil, io.ErrUnexpectedEOF
 				}
 				if cryptoLen < int(offset+length) {
-					cryptoLen = int(offset + length)
-					if len(cryptoData) < cryptoLen {
-						newCryptoData := bytespool.Alloc(int32(cryptoLen))
-						copy(newCryptoData, cryptoData)
-						bytespool.Free(cryptoData)
-						cryptoData = newCryptoData
+					newCryptoLen := int(offset + length)
+					if len(cryptoData) < newCryptoLen {
+						return nil, io.ErrShortBuffer
 					}
+					wipeBytes(cryptoData[cryptoLen:newCryptoLen])
+					cryptoLen = newCryptoLen
 				}
 				if _, err := buffer.Read(cryptoData[offset : offset+length]); err != nil { // Field: Crypto Data
 					return nil, io.ErrUnexpectedEOF
@@ -276,7 +259,14 @@ func SniffQUIC(b []byte) (resultReturn *SniffHeader, errorReturn error) {
 		}
 		return &SniffHeader{domain: tlsHdr.Domain()}, nil
 	}
-	return nil, common.ErrNoClue
+	// All payload is parsed as valid QUIC packets, but we need more packets for crypto data to read client hello.
+	return nil, protocol.ErrProtoNeedMoreData
+}
+
+func wipeBytes(b []byte) {
+	for i := range len(b) {
+		b[i] = 0x0
+	}
 }
 
 func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {

File diff suppressed because it is too large
+ 83 - 0
common/protocol/quic/sniff_test.go


+ 11 - 6
common/protocol/tls/sniff.go

@@ -3,9 +3,9 @@ package tls
 import (
 	"encoding/binary"
 	"errors"
-	"strings"
 
 	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/protocol"
 )
 
 type SniffHeader struct {
@@ -59,9 +59,6 @@ func ReadClientHello(data []byte, h *SniffHeader) error {
 	}
 	data = data[1+compressionMethodsLen:]
 
-	if len(data) == 0 {
-		return errNotClientHello
-	}
 	if len(data) < 2 {
 		return errNotClientHello
 	}
@@ -104,13 +101,21 @@ func ReadClientHello(data []byte, h *SniffHeader) error {
 					return errNotClientHello
 				}
 				if nameType == 0 {
-					serverName := string(d[:nameLen])
+					// QUIC separated across packets
+					// May cause the serverName to be incomplete
+					b := byte(0)
+					for _, b = range d[:nameLen] {
+						if b <= ' ' {
+							return protocol.ErrProtoNeedMoreData
+						}
+					}
 					// An SNI value may not include a
 					// trailing dot. See
 					// https://tools.ietf.org/html/rfc6066#section-3.
-					if strings.HasSuffix(serverName, ".") {
+					if b == '.' {
 						return errNotClientHello
 					}
+					serverName := string(d[:nameLen])
 					h.domain = serverName
 					return nil
 				}

Some files were not shown because too many files changed in this diff