Browse Source

Quic sniffer (#1074)

* Add quic sniffer

* Fix quic sniffer

* Add uTP sniffer

* rename buf pool membership status to unmanaged

* rename buf type adaptor into FromBytes

Co-authored-by: 世界 <[email protected]>
Co-authored-by: Shelikhoo <[email protected]>
yuhan6665 3 years ago
parent
commit
3f64f3206c

+ 12 - 44
app/dispatcher/default.go

@@ -250,6 +250,9 @@ func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sn
 
 func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool {
 	domain := result.Domain()
+	if domain == "" {
+		return false
+	}
 	for _, d := range request.ExcludeForDomain {
 		if strings.ToLower(domain) == d {
 			return false
@@ -295,33 +298,15 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
 
 	sniffingRequest := content.SniffingRequest
 	inbound, outbound := d.getLink(ctx, destination.Network, sniffingRequest)
-	switch {
-	case !sniffingRequest.Enabled:
-		go d.routedDispatch(ctx, outbound, destination)
-	case destination.Network != net.Network_TCP:
-		// Only metadata sniff will be used for non tcp connection
-		result, err := sniffer(ctx, nil, true)
-		if err == nil {
-			content.Protocol = result.Protocol()
-			if d.shouldOverride(ctx, result, sniffingRequest, destination) {
-				domain := result.Domain()
-				newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
-				destination.Address = net.ParseAddress(domain)
-				if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
-					ob.RouteTarget = destination
-				} else {
-					ob.Target = destination
-				}
-			}
-		}
+	if !sniffingRequest.Enabled {
 		go d.routedDispatch(ctx, outbound, destination)
-	default:
+	} else {
 		go func() {
 			cReader := &cachedReader{
 				reader: outbound.Reader.(*pipe.Reader),
 			}
 			outbound.Reader = cReader
-			result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly)
+			result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
 			if err == nil {
 				content.Protocol = result.Protocol()
 			}
@@ -356,33 +341,15 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
 		ctx = session.ContextWithContent(ctx, content)
 	}
 	sniffingRequest := content.SniffingRequest
-	switch {
-	case !sniffingRequest.Enabled:
+	if !sniffingRequest.Enabled {
 		go d.routedDispatch(ctx, outbound, destination)
-	case destination.Network != net.Network_TCP:
-		// Only metadata sniff will be used for non tcp connection
-		result, err := sniffer(ctx, nil, true)
-		if err == nil {
-			content.Protocol = result.Protocol()
-			if d.shouldOverride(ctx, result, sniffingRequest, destination) {
-				domain := result.Domain()
-				newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
-				destination.Address = net.ParseAddress(domain)
-				if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
-					ob.RouteTarget = destination
-				} else {
-					ob.Target = destination
-				}
-			}
-		}
-		go d.routedDispatch(ctx, outbound, destination)
-	default:
+	} else {
 		go func() {
 			cReader := &cachedReader{
 				reader: outbound.Reader.(*pipe.Reader),
 			}
 			outbound.Reader = cReader
-			result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly)
+			result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
 			if err == nil {
 				content.Protocol = result.Protocol()
 			}
@@ -399,10 +366,11 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
 			d.routedDispatch(ctx, outbound, destination)
 		}()
 	}
+
 	return nil
 }
 
-func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (SniffResult, error) {
+func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) {
 	payload := buf.New()
 	defer payload.Release()
 
@@ -428,7 +396,7 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (Sni
 
 				cReader.Cache(payload)
 				if !payload.IsEmpty() {
-					result, err := sniffer.Sniff(ctx, payload.Bytes())
+					result, err := sniffer.Sniff(ctx, payload.Bytes(), network)
 					if err != common.ErrNoClue {
 						return result, err
 					}

+ 10 - 5
app/dispatcher/sniffer.go

@@ -2,6 +2,8 @@ package dispatcher
 
 import (
 	"context"
+	"github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/common/protocol/quic"
 
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/protocol/bittorrent"
@@ -22,6 +24,7 @@ type protocolSnifferWithMetadata struct {
 	// for both TCP and UDP connections
 	// It will not be shown as a traffic type for routing unless there is no other successful sniffing.
 	metadataSniffer bool
+	network         net.Network
 }
 
 type Sniffer struct {
@@ -31,9 +34,11 @@ type Sniffer struct {
 func NewSniffer(ctx context.Context) *Sniffer {
 	ret := &Sniffer{
 		sniffer: []protocolSnifferWithMetadata{
-			{func(c context.Context, b []byte) (SniffResult, error) { return http.SniffHTTP(b) }, false},
-			{func(c context.Context, b []byte) (SniffResult, error) { return tls.SniffTLS(b) }, false},
-			{func(c context.Context, b []byte) (SniffResult, error) { return bittorrent.SniffBittorrent(b) }, false},
+			{func(c context.Context, b []byte) (SniffResult, error) { return http.SniffHTTP(b) }, false, net.Network_TCP},
+			{func(c context.Context, b []byte) (SniffResult, error) { return tls.SniffTLS(b) }, false, net.Network_TCP},
+			{func(c context.Context, b []byte) (SniffResult, error) { return bittorrent.SniffBittorrent(b) }, false, net.Network_TCP},
+			{func(c context.Context, b []byte) (SniffResult, error) { return quic.SniffQUIC(b) }, false, net.Network_UDP},
+			{func(c context.Context, b []byte) (SniffResult, error) { return bittorrent.SniffUTP(b) }, false, net.Network_UDP},
 		},
 	}
 	if sniffer, err := newFakeDNSSniffer(ctx); err == nil {
@@ -49,11 +54,11 @@ func NewSniffer(ctx context.Context) *Sniffer {
 
 var errUnknownContent = newError("unknown content")
 
-func (s *Sniffer) Sniff(c context.Context, payload []byte) (SniffResult, error) {
+func (s *Sniffer) Sniff(c context.Context, payload []byte, network net.Network) (SniffResult, error) {
 	var pendingSniffer []protocolSnifferWithMetadata
 	for _, si := range s.sniffer {
 		s := si.protocolSniffer
-		if si.metadataSniffer {
+		if si.metadataSniffer || si.network != network {
 			continue
 		}
 		result, err := s(c, payload)

+ 38 - 5
common/buf/buffer.go

@@ -18,10 +18,11 @@ var pool = bytespool.GetPool(Size)
 // the buffer into an internal buffer pool, in order to recreate a buffer more
 // quickly.
 type Buffer struct {
-	v     []byte
-	start int32
-	end   int32
-	UDP   *net.Destination
+	v         []byte
+	start     int32
+	end       int32
+	unmanaged bool
+	UDP       *net.Destination
 }
 
 // New creates a Buffer with 0 length and 8K capacity.
@@ -38,6 +39,7 @@ func New() *Buffer {
 	}
 }
 
+// NewExisted creates a managed, standard size Buffer with an existed bytearray
 func NewExisted(b []byte) *Buffer {
 	if cap(b) < Size {
 		panic("Invalid buffer")
@@ -54,6 +56,15 @@ func NewExisted(b []byte) *Buffer {
 	}
 }
 
+// FromBytes creates a Buffer with an existed bytearray
+func FromBytes(b []byte) *Buffer {
+	return &Buffer{
+		v:         b,
+		end:       int32(len(b)),
+		unmanaged: true,
+	}
+}
+
 // StackNew creates a new Buffer object on stack.
 // This method is for buffers that is released in the same function.
 func StackNew() Buffer {
@@ -71,7 +82,7 @@ func StackNew() Buffer {
 
 // Release recycles the buffer into an internal buffer pool.
 func (b *Buffer) Release() {
-	if b == nil || b.v == nil {
+	if b == nil || b.v == nil || b.unmanaged {
 		return
 	}
 
@@ -212,6 +223,28 @@ func (b *Buffer) WriteString(s string) (int, error) {
 	return b.Write([]byte(s))
 }
 
+// ReadByte implements io.ByteReader
+func (b *Buffer) ReadByte() (byte, error) {
+	if b.start == b.end {
+		return 0, io.EOF
+	}
+
+	nb := b.v[b.start]
+	b.start++
+	return nb, nil
+}
+
+// ReadBytes implements bufio.Reader.ReadBytes
+func (b *Buffer) ReadBytes(length int32) ([]byte, error) {
+	if b.end-b.start < length {
+		return nil, io.EOF
+	}
+
+	nb := b.v[b.start : b.start+length]
+	b.start += length
+	return nb, nil
+}
+
 // Read implements io.Reader.Read().
 func (b *Buffer) Read(data []byte) (int, error) {
 	if b.Len() == 0 {

+ 59 - 0
common/protocol/bittorrent/bittorrent.go

@@ -1,8 +1,12 @@
 package bittorrent
 
 import (
+	"encoding/binary"
 	"errors"
+	"math"
+	"time"
 
+	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common"
 )
 
@@ -29,3 +33,58 @@ func SniffBittorrent(b []byte) (*SniffHeader, error) {
 
 	return nil, errNotBittorrent
 }
+
+func SniffUTP(b []byte) (*SniffHeader, error) {
+	if len(b) < 20 {
+		return nil, common.ErrNoClue
+	}
+
+	buffer := buf.FromBytes(b)
+
+	var typeAndVersion uint8
+
+	if binary.Read(buffer, binary.BigEndian, &typeAndVersion) != nil {
+		return nil, common.ErrNoClue
+	} else if b[0]>>4&0xF > 4 || b[0]&0xF != 1 {
+		return nil, errNotBittorrent
+	}
+
+	var extension uint8
+
+	if binary.Read(buffer, binary.BigEndian, &extension) != nil {
+		return nil, common.ErrNoClue
+	} else if extension != 0 && extension != 1 {
+		return nil, errNotBittorrent
+	}
+
+	for extension != 0 {
+		if extension != 1 {
+			return nil, errNotBittorrent
+		}
+		if binary.Read(buffer, binary.BigEndian, &extension) != nil {
+			return nil, common.ErrNoClue
+		}
+
+		var length uint8
+		if err := binary.Read(buffer, binary.BigEndian, &length); err != nil {
+			return nil, common.ErrNoClue
+		}
+		if common.Error2(buffer.ReadBytes(int32(length))) != nil {
+			return nil, common.ErrNoClue
+		}
+	}
+
+	if common.Error2(buffer.ReadBytes(2)) != nil {
+		return nil, common.ErrNoClue
+	}
+
+	var timestamp uint32
+	if err := binary.Read(buffer, binary.BigEndian, &timestamp); err != nil {
+		return nil, common.ErrNoClue
+	}
+	if math.Abs(float64(time.Now().UnixMicro()-int64(timestamp))) > float64(24*time.Hour) {
+		return nil, errNotBittorrent
+	}
+
+	return &SniffHeader{}, nil
+}

+ 19 - 0
common/protocol/quic/qtls_go116.go

@@ -0,0 +1,19 @@
+//go:build go1.16 && !go1.17
+// +build go1.16,!go1.17
+
+package quic
+
+import (
+	"crypto/cipher"
+
+	"github.com/marten-seemann/qtls-go1-16"
+)
+
+type (
+	// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
+	CipherSuiteTLS13 = qtls.CipherSuiteTLS13
+)
+
+func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
+	return qtls.AEADAESGCMTLS13(key, fixedNonce)
+}

+ 19 - 0
common/protocol/quic/qtls_go117.go

@@ -0,0 +1,19 @@
+//go:build go1.17 && !go1.18
+// +build go1.17,!go1.18
+
+package quic
+
+import (
+	"crypto/cipher"
+
+	"github.com/marten-seemann/qtls-go1-17"
+)
+
+type (
+	// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
+	CipherSuiteTLS13 = qtls.CipherSuiteTLS13
+)
+
+func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
+	return qtls.AEADAESGCMTLS13(key, fixedNonce)
+}

+ 19 - 0
common/protocol/quic/qtls_go118.go

@@ -0,0 +1,19 @@
+//go:build go1.18
+// +build go1.18
+
+package quic
+
+import (
+	"crypto/cipher"
+
+	"github.com/marten-seemann/qtls-go1-18"
+)
+
+type (
+	// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
+	CipherSuiteTLS13 = qtls.CipherSuiteTLS13
+)
+
+func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
+	return qtls.AEADAESGCMTLS13(key, fixedNonce)
+}

+ 208 - 0
common/protocol/quic/sniff.go

@@ -0,0 +1,208 @@
+package quic
+
+import (
+	"crypto"
+	"crypto/aes"
+	"crypto/tls"
+	"encoding/binary"
+	"io"
+
+	"github.com/lucas-clemente/quic-go/quicvarint"
+	"golang.org/x/crypto/hkdf"
+
+	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/buf"
+	"github.com/xtls/xray-core/common/errors"
+	ptls "github.com/xtls/xray-core/common/protocol/tls"
+)
+
+type SniffHeader struct {
+	domain string
+}
+
+func (s SniffHeader) Protocol() string {
+	return "quic"
+}
+
+func (s SniffHeader) Domain() string {
+	return s.domain
+}
+
+const (
+	versionDraft29 uint32 = 0xff00001d
+	version1       uint32 = 0x1
+)
+
+var (
+	quicSaltOld  = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99}
+	quicSalt     = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}
+	initialSuite = &CipherSuiteTLS13{
+		ID:     tls.TLS_AES_128_GCM_SHA256,
+		KeyLen: 16,
+		AEAD:   AEADAESGCMTLS13,
+		Hash:   crypto.SHA256,
+	}
+	errNotQuic        = errors.New("not quic")
+	errNotQuicInitial = errors.New("not initial packet")
+)
+
+func SniffQUIC(b []byte) (*SniffHeader, error) {
+	buffer := buf.FromBytes(b)
+	typeByte, err := buffer.ReadByte()
+	if err != nil {
+		return nil, errNotQuic
+	}
+	isLongHeader := typeByte&0x80 > 0
+	if !isLongHeader || typeByte&0x40 == 0 {
+		return nil, errNotQuicInitial
+	}
+
+	vb, err := buffer.ReadBytes(4)
+	if err != nil {
+		return nil, errNotQuic
+	}
+
+	versionNumber := binary.BigEndian.Uint32(vb)
+
+	if versionNumber != 0 && typeByte&0x40 == 0 {
+		return nil, errNotQuic
+	} else if versionNumber != versionDraft29 && versionNumber != version1 {
+		return nil, errNotQuic
+	}
+
+	if (typeByte&0x30)>>4 != 0x0 {
+		return nil, errNotQuicInitial
+	}
+
+	var destConnID []byte
+	if l, err := buffer.ReadByte(); err != nil {
+		return nil, errNotQuic
+	} else if destConnID, err = buffer.ReadBytes(int32(l)); err != nil {
+		return nil, errNotQuic
+	}
+
+	if l, err := buffer.ReadByte(); err != nil {
+		return nil, errNotQuic
+	} else if common.Error2(buffer.ReadBytes(int32(l))) != nil {
+		return nil, errNotQuic
+	}
+
+	tokenLen, err := quicvarint.Read(buffer)
+	if err != nil || tokenLen > uint64(len(b)) {
+		return nil, errNotQuic
+	}
+
+	if _, err = buffer.ReadBytes(int32(tokenLen)); err != nil {
+		return nil, errNotQuic
+	}
+
+	packetLen, err := quicvarint.Read(buffer)
+	if err != nil {
+		return nil, errNotQuic
+	}
+
+	hdrLen := len(b) - int(buffer.Len())
+
+	origPNBytes := make([]byte, 4)
+	copy(origPNBytes, b[hdrLen:hdrLen+4])
+
+	var salt []byte
+	if versionNumber == version1 {
+		salt = quicSalt
+	} else {
+		salt = quicSaltOld
+	}
+	initialSecret := hkdf.Extract(crypto.SHA256.New, destConnID, salt)
+	secret := hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
+	hpKey := hkdfExpandLabel(initialSuite.Hash, secret, []byte{}, "quic hp", initialSuite.KeyLen)
+	block, err := aes.NewCipher(hpKey)
+	if err != nil {
+		return nil, err
+	}
+
+	cache := buf.New()
+	defer cache.Release()
+
+	mask := cache.Extend(int32(block.BlockSize()))
+	block.Encrypt(mask, b[hdrLen+4:hdrLen+4+16])
+	b[0] ^= mask[0] & 0xf
+	for i := range b[hdrLen : hdrLen+4] {
+		b[hdrLen+i] ^= mask[i+1]
+	}
+	packetNumberLength := b[0]&0x3 + 1
+	if packetNumberLength != 1 {
+		return nil, errNotQuicInitial
+	}
+	var packetNumber uint32
+	{
+		n, err := buffer.ReadByte()
+		if err != nil {
+			return nil, err
+		}
+		packetNumber = uint32(n)
+	}
+
+	if packetNumber != 0 {
+		return nil, errNotQuicInitial
+	}
+
+	extHdrLen := hdrLen + int(packetNumberLength)
+	copy(b[extHdrLen:hdrLen+4], origPNBytes[packetNumberLength:])
+	data := b[extHdrLen : int(packetLen)+hdrLen]
+
+	key := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16)
+	iv := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic iv", 12)
+	cipher := AEADAESGCMTLS13(key, iv)
+	nonce := cache.Extend(int32(cipher.NonceSize()))
+	binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber))
+	decrypted, err := cipher.Open(b[extHdrLen:extHdrLen], nonce, data, b[:extHdrLen])
+	if err != nil {
+		return nil, err
+	}
+	buffer = buf.FromBytes(decrypted)
+	frameType, err := buffer.ReadByte()
+	if err != nil {
+		return nil, io.ErrUnexpectedEOF
+	}
+	if frameType != 0x6 {
+		// not crypto frame
+		return &SniffHeader{domain: ""}, nil
+	}
+	if common.Error2(quicvarint.Read(buffer)) != nil {
+		return nil, io.ErrUnexpectedEOF
+	}
+	dataLen, err := quicvarint.Read(buffer)
+	if err != nil {
+		return nil, io.ErrUnexpectedEOF
+	}
+	if dataLen > uint64(buffer.Len()) {
+		return nil, io.ErrUnexpectedEOF
+	}
+	frameData, err := buffer.ReadBytes(int32(dataLen))
+	common.Must(err)
+	tlsHdr := &ptls.SniffHeader{}
+	err = ptls.ReadClientHello(frameData, tlsHdr)
+	if err != nil {
+		return nil, err
+	}
+
+	return &SniffHeader{domain: tlsHdr.Domain()}, nil
+}
+
+func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {
+	b := make([]byte, 3, 3+6+len(label)+1+len(context))
+	binary.BigEndian.PutUint16(b, uint16(length))
+	b[2] = uint8(6 + len(label))
+	b = append(b, []byte("tls13 ")...)
+	b = append(b, []byte(label)...)
+	b = b[:3+6+len(label)+1]
+	b[3+6+len(label)] = uint8(len(context))
+	b = append(b, context...)
+
+	out := make([]byte, length)
+	n, err := hkdf.Expand(hash.New, secret, b).Read(out)
+	if err != nil || n != length {
+		panic("quic: HKDF-Expand-Label invocation failed unexpectedly")
+	}
+	return out
+}

+ 18 - 0
common/protocol/quic/sniff_test.go

@@ -0,0 +1,18 @@
+package quic_test
+
+import (
+	"encoding/hex"
+	"testing"
+
+	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/protocol/quic"
+)
+
+func TestSniffQUIC(t *testing.T) {
+	pkt, err := hex.DecodeString("cd0000000108f1fb7bcc78aa5e7203a8f86400421531fe825b19541876db6c55c38890cd73149d267a084afee6087304095417a3033df6a81bbb71d8512e7a3e16df1e277cae5df3182cb214b8fe982ba3fdffbaa9ffec474547d55945f0fddbeadfb0b5243890b2fa3da45169e2bd34ec04b2e29382f48d612b28432a559757504d158e9e505407a77dd34f4b60b8d3b555ee85aacd6648686802f4de25e7216b19e54c5f78e8a5963380c742d861306db4c16e4f7fc94957aa50b9578a0b61f1e406b2ad5f0cd3cd271c4d99476409797b0c3cb3efec256118912d4b7e4fd79d9cb9016b6e5eaa4f5e57b637b217755daf8968a4092bed0ed5413f5d04904b3a61e4064f9211b2629e5b52a89c7b19f37a713e41e27743ea6dfa736dfa1bb0a4b2bc8c8dc632c6ce963493a20c550e6fdb2475213665e9a85cfc394da9cec0cf41f0c8abed3fc83be5245b2b5aa5e825d29349f721d30774ef5bf965b540f3d8d98febe20956b1fc8fa047e10e7d2f921c9c6622389e02322e80621a1cf5264e245b7276966eb02932584e3f7038bd36aa908766ad3fb98344025dec18670d6db43a1c5daac00937fce7b7c7d61ff4e6efd01a2bdee0ee183108b926393df4f3d74bbcbb015f240e7e346b7d01c41111a401225ce3b095ab4623a5836169bf9599eeca79d1d2e9b2202b5960a09211e978058d6fc0484eff3e91ce4649a5e3ba15b906d334cf66e28d9ff575406e1ae1ac2febafd72870b6f5d58fc5fb949cb1f40feb7c1d9ce5e71b")
+	common.Must(err)
+	quicHdr, err := quic.SniffQUIC(pkt)
+	if err != nil || quicHdr.Domain() != "www.google.com" {
+		t.Error("failed")
+	}
+}

+ 2 - 0
infra/conf/xray.go

@@ -78,6 +78,8 @@ func (c *SniffingConfig) Build() (*proxyman.SniffingConfig, error) {
 				p = append(p, "http")
 			case "tls", "https", "ssl":
 				p = append(p, "tls")
+			case "quic":
+				p = append(p, "quic")
 			case "fakedns":
 				p = append(p, "fakedns")
 			case "fakedns+others":