Browse Source

Fix sniff fragmented quic client hello

世界 3 years ago
parent
commit
3f1fe814ef
4 changed files with 80 additions and 42 deletions
  1. 63 28
      common/sniff/quic.go
  2. 2 0
      common/sniff/quic_test.go
  3. 11 10
      common/sniff/sniff.go
  4. 4 4
      route/router.go

+ 63 - 28
common/sniff/quic.go

@@ -24,8 +24,7 @@ func QUICClientHello(ctx context.Context, packet []byte) (*adapter.InboundContex
 	if err != nil {
 		return nil, err
 	}
-
-	if typeByte&0x80 == 0 || typeByte&0x40 == 0 {
+	if typeByte&0x40 == 0 {
 		return nil, E.New("bad type byte")
 	}
 	var versionNumber uint32
@@ -145,9 +144,6 @@ func QUICClientHello(ctx context.Context, packet []byte) (*adapter.InboundContex
 	default:
 		return nil, E.New("bad packet number length")
 	}
-	if packetNumber != 0 {
-		return nil, E.New("bad packet number: ", packetNumber)
-	}
 	extHdrLen := hdrLen + int(packetNumberLength)
 	copy(newPacket[extHdrLen:hdrLen+4], packet[extHdrLen:])
 	data := newPacket[extHdrLen : int(packetLen)+hdrLen]
@@ -172,37 +168,76 @@ func QUICClientHello(ctx context.Context, packet []byte) (*adapter.InboundContex
 	if err != nil {
 		return nil, err
 	}
-	decryptedReader := bytes.NewReader(decrypted)
-	frameType, err := decryptedReader.ReadByte()
-	if err != nil {
-		return nil, err
+	var frameType byte
+	var frameLen uint64
+	var fragments []struct {
+		offset  uint64
+		length  uint64
+		payload []byte
 	}
-	for frameType == 0x0 {
-		// skip padding
+	decryptedReader := bytes.NewReader(decrypted)
+	for {
 		frameType, err = decryptedReader.ReadByte()
-		if err != nil {
-			return nil, err
+		if err == io.EOF {
+			break
+		}
+		switch frameType {
+		case 0x0:
+			continue
+		case 0x1:
+			continue
+		case 0x6:
+			var offset uint64
+			offset, err = qtls.ReadUvarint(decryptedReader)
+			if err != nil {
+				return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err
+			}
+			var length uint64
+			length, err = qtls.ReadUvarint(decryptedReader)
+			if err != nil {
+				return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err
+			}
+			index := len(decrypted) - decryptedReader.Len()
+			fragments = append(fragments, struct {
+				offset  uint64
+				length  uint64
+				payload []byte
+			}{offset, length, decrypted[index : index+int(length)]})
+			frameLen += length
+			_, err = decryptedReader.Seek(int64(length), io.SeekCurrent)
+			if err != nil {
+				return nil, err
+			}
+		default:
+			// ignore unknown frame type
 		}
-	}
-	if frameType != 0x6 {
-		// not crypto frame
-		return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, nil
-	}
-	_, err = qtls.ReadUvarint(decryptedReader)
-	if err != nil {
-		return nil, err
-	}
-	_, err = qtls.ReadUvarint(decryptedReader)
-	if err != nil {
-		return nil, err
 	}
 	tlsHdr := make([]byte, 5)
 	tlsHdr[0] = 0x16
 	binary.BigEndian.PutUint16(tlsHdr[1:], uint16(0x0303))
-	binary.BigEndian.PutUint16(tlsHdr[3:], uint16(decryptedReader.Len()))
-	metadata, err := TLSClientHello(ctx, io.MultiReader(bytes.NewReader(tlsHdr), decryptedReader))
+	binary.BigEndian.PutUint16(tlsHdr[3:], uint16(frameLen))
+	var index uint64
+	var length int
+	var readers []io.Reader
+	readers = append(readers, bytes.NewReader(tlsHdr))
+find:
+	for {
+		for _, fragment := range fragments {
+			if fragment.offset == index {
+				readers = append(readers, bytes.NewReader(fragment.payload))
+				index = fragment.offset + fragment.length
+				length++
+				continue find
+			}
+		}
+		if length == len(fragments) {
+			break
+		}
+		return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, E.New("bad fragments")
+	}
+	metadata, err := TLSClientHello(ctx, io.MultiReader(readers...))
 	if err != nil {
-		return nil, err
+		return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err
 	}
 	metadata.Protocol = C.ProtocolQUIC
 	return metadata, nil

File diff suppressed because it is too large
+ 2 - 0
common/sniff/quic_test.go


+ 11 - 10
common/sniff/sniff.go

@@ -5,7 +5,6 @@ import (
 	"context"
 	"io"
 	"net"
-	"os"
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
@@ -33,23 +32,25 @@ func PeekStream(ctx context.Context, conn net.Conn, buffer *buf.Buffer, timeout
 		return nil, err
 	}
 	var metadata *adapter.InboundContext
+	var errors []error
 	for _, sniffer := range sniffers {
 		metadata, err = sniffer(ctx, bytes.NewReader(buffer.Bytes()))
-		if err != nil {
-			continue
+		if metadata != nil {
+			return metadata, nil
 		}
-		return metadata, nil
+		errors = append(errors, err)
 	}
-	return nil, os.ErrInvalid
+	return nil, E.Errors(errors...)
 }
 
 func PeekPacket(ctx context.Context, packet []byte, sniffers ...PacketSniffer) (*adapter.InboundContext, error) {
+	var errors []error
 	for _, sniffer := range sniffers {
-		sniffMetadata, err := sniffer(ctx, packet)
-		if err != nil {
-			continue
+		metadata, err := sniffer(ctx, packet)
+		if metadata != nil {
+			return metadata, nil
 		}
-		return sniffMetadata, nil
+		errors = append(errors, err)
 	}
-	return nil, os.ErrInvalid
+	return nil, E.Errors(errors...)
 }

+ 4 - 4
route/router.go

@@ -554,8 +554,8 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
 	if metadata.InboundOptions.SniffEnabled {
 		buffer := buf.NewPacket()
 		buffer.FullReset()
-		sniffMetadata, err := sniff.PeekStream(ctx, conn, buffer, time.Duration(metadata.InboundOptions.SniffTimeout), sniff.StreamDomainNameQuery, sniff.TLSClientHello, sniff.HTTPHost)
-		if err == nil {
+		sniffMetadata, _ := sniff.PeekStream(ctx, conn, buffer, time.Duration(metadata.InboundOptions.SniffTimeout), sniff.StreamDomainNameQuery, sniff.TLSClientHello, sniff.HTTPHost)
+		if sniffMetadata != nil {
 			metadata.Protocol = sniffMetadata.Protocol
 			metadata.Domain = sniffMetadata.Domain
 			if metadata.InboundOptions.SniffOverrideDestination && M.IsDomainName(metadata.Domain) {
@@ -636,8 +636,8 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
 			buffer.Release()
 			return err
 		}
-		sniffMetadata, err := sniff.PeekPacket(ctx, buffer.Bytes(), sniff.DomainNameQuery, sniff.QUICClientHello, sniff.STUNMessage)
-		if err == nil {
+		sniffMetadata, _ := sniff.PeekPacket(ctx, buffer.Bytes(), sniff.DomainNameQuery, sniff.QUICClientHello, sniff.STUNMessage)
+		if sniffMetadata != nil {
 			metadata.Protocol = sniffMetadata.Protocol
 			metadata.Domain = sniffMetadata.Domain
 			if metadata.InboundOptions.SniffOverrideDestination && M.IsDomainName(metadata.Domain) {

Some files were not shown because too many files changed in this diff