瀏覽代碼

Fix sniffer errors override each others

* Fix sniffer errors override each others

* Do not return ErrNeedMoreData if header is not expected
dyhkwong 6 月之前
父節點
當前提交
f2bbf6b2aa
共有 7 個文件被更改,包括 102 次插入10 次删除
  1. 8 3
      common/sniff/bittorrent.go
  2. 21 0
      common/sniff/bittorrent_test.go
  3. 18 4
      common/sniff/dns.go
  4. 30 0
      common/sniff/dns_test.go
  5. 1 1
      common/sniff/sniff.go
  6. 3 2
      common/sniff/ssh.go
  7. 21 0
      common/sniff/ssh_test.go

+ 8 - 3
common/sniff/bittorrent.go

@@ -31,13 +31,18 @@ func BitTorrent(_ context.Context, metadata *adapter.InboundContext, reader io.R
 		return os.ErrInvalid
 	}
 
+	const header = "BitTorrent protocol"
 	var protocol [19]byte
-	_, err = reader.Read(protocol[:])
+	var n int
+	n, err = reader.Read(protocol[:])
+	if string(protocol[:n]) != header[:n] {
+		return os.ErrInvalid
+	}
 	if err != nil {
 		return E.Cause1(ErrNeedMoreData, err)
 	}
-	if string(protocol[:]) != "BitTorrent protocol" {
-		return os.ErrInvalid
+	if n < 19 {
+		return ErrNeedMoreData
 	}
 
 	metadata.Protocol = C.ProtocolBitTorrent

+ 21 - 0
common/sniff/bittorrent_test.go

@@ -32,6 +32,27 @@ func TestSniffBittorrent(t *testing.T) {
 	}
 }
 
+func TestSniffIncompleteBittorrent(t *testing.T) {
+	t.Parallel()
+
+	pkt, err := hex.DecodeString("13426974546f7272656e74")
+	require.NoError(t, err)
+	var metadata adapter.InboundContext
+	err = sniff.BitTorrent(context.TODO(), &metadata, bytes.NewReader(pkt))
+	require.ErrorIs(t, err, sniff.ErrNeedMoreData)
+}
+
+func TestSniffNotBittorrent(t *testing.T) {
+	t.Parallel()
+
+	pkt, err := hex.DecodeString("13426974546f7272656e75")
+	require.NoError(t, err)
+	var metadata adapter.InboundContext
+	err = sniff.BitTorrent(context.TODO(), &metadata, bytes.NewReader(pkt))
+	require.NotEmpty(t, err)
+	require.NotErrorIs(t, err, sniff.ErrNeedMoreData)
+}
+
 func TestSniffUTP(t *testing.T) {
 	t.Parallel()
 

+ 18 - 4
common/sniff/dns.go

@@ -20,22 +20,36 @@ func StreamDomainNameQuery(readCtx context.Context, metadata *adapter.InboundCon
 	if err != nil {
 		return E.Cause1(ErrNeedMoreData, err)
 	}
-	if length == 0 {
+	if length < 12 {
 		return os.ErrInvalid
 	}
 	buffer := buf.NewSize(int(length))
 	defer buffer.Release()
-	_, err = buffer.ReadFullFrom(reader, buffer.FreeLen())
+	var n int
+	n, err = buffer.ReadFullFrom(reader, buffer.FreeLen())
+	packet := buffer.Bytes()
+	if n > 2 && packet[2]&0x80 != 0 { // QR
+		return os.ErrInvalid
+	}
+	if n > 5 && packet[4] == 0 && packet[5] == 0 { // QDCOUNT
+		return os.ErrInvalid
+	}
+	for i := 6; i < 10; i++ {
+		// ANCOUNT, NSCOUNT
+		if n > i && packet[i] != 0 {
+			return os.ErrInvalid
+		}
+	}
 	if err != nil {
 		return E.Cause1(ErrNeedMoreData, err)
 	}
-	return DomainNameQuery(readCtx, metadata, buffer.Bytes())
+	return DomainNameQuery(readCtx, metadata, packet)
 }
 
 func DomainNameQuery(ctx context.Context, metadata *adapter.InboundContext, packet []byte) error {
 	var msg mDNS.Msg
 	err := msg.Unpack(packet)
-	if err != nil {
+	if err != nil || msg.Response || len(msg.Question) == 0 || len(msg.Answer) > 0 || len(msg.Ns) > 0 {
 		return err
 	}
 	metadata.Protocol = C.ProtocolDNS

+ 30 - 0
common/sniff/dns_test.go

@@ -1,6 +1,7 @@
 package sniff_test
 
 import (
+	"bytes"
 	"context"
 	"encoding/hex"
 	"testing"
@@ -21,3 +22,32 @@ func TestSniffDNS(t *testing.T) {
 	require.NoError(t, err)
 	require.Equal(t, C.ProtocolDNS, metadata.Protocol)
 }
+
+func TestSniffStreamDNS(t *testing.T) {
+	t.Parallel()
+	query, err := hex.DecodeString("001e740701000001000000000000012a06676f6f676c6503636f6d0000010001")
+	require.NoError(t, err)
+	var metadata adapter.InboundContext
+	err = sniff.StreamDomainNameQuery(context.TODO(), &metadata, bytes.NewReader(query))
+	require.NoError(t, err)
+	require.Equal(t, C.ProtocolDNS, metadata.Protocol)
+}
+
+func TestSniffIncompleteStreamDNS(t *testing.T) {
+	t.Parallel()
+	query, err := hex.DecodeString("001e740701000001000000000000")
+	require.NoError(t, err)
+	var metadata adapter.InboundContext
+	err = sniff.StreamDomainNameQuery(context.TODO(), &metadata, bytes.NewReader(query))
+	require.ErrorIs(t, err, sniff.ErrNeedMoreData)
+}
+
+func TestSniffNotStreamDNS(t *testing.T) {
+	t.Parallel()
+	query, err := hex.DecodeString("001e740701000000000000000000")
+	require.NoError(t, err)
+	var metadata adapter.InboundContext
+	err = sniff.StreamDomainNameQuery(context.TODO(), &metadata, bytes.NewReader(query))
+	require.NotEmpty(t, err)
+	require.NotErrorIs(t, err, sniff.ErrNeedMoreData)
+}

+ 1 - 1
common/sniff/sniff.go

@@ -68,7 +68,7 @@ func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net.
 			}
 			sniffError = E.Errors(sniffError, err)
 		}
-		if !errors.Is(err, ErrNeedMoreData) {
+		if !errors.Is(sniffError, ErrNeedMoreData) {
 			break
 		}
 	}

+ 3 - 2
common/sniff/ssh.go

@@ -15,10 +15,11 @@ func SSH(_ context.Context, metadata *adapter.InboundContext, reader io.Reader)
 	const sshPrefix = "SSH-2.0-"
 	bReader := bufio.NewReader(reader)
 	prefix, err := bReader.Peek(len(sshPrefix))
+	if string(prefix[:]) != sshPrefix[:len(prefix)] {
+		return os.ErrInvalid
+	}
 	if err != nil {
 		return E.Cause1(ErrNeedMoreData, err)
-	} else if string(prefix) != sshPrefix {
-		return os.ErrInvalid
 	}
 	fistLine, _, err := bReader.ReadLine()
 	if err != nil {

+ 21 - 0
common/sniff/ssh_test.go

@@ -24,3 +24,24 @@ func TestSniffSSH(t *testing.T) {
 	require.Equal(t, C.ProtocolSSH, metadata.Protocol)
 	require.Equal(t, "dropbear", metadata.Client)
 }
+
+func TestSniffIncompleteSSH(t *testing.T) {
+	t.Parallel()
+
+	pkt, err := hex.DecodeString("5353482d322e30")
+	require.NoError(t, err)
+	var metadata adapter.InboundContext
+	err = sniff.SSH(context.TODO(), &metadata, bytes.NewReader(pkt))
+	require.ErrorIs(t, err, sniff.ErrNeedMoreData)
+}
+
+func TestSniffNotSSH(t *testing.T) {
+	t.Parallel()
+
+	pkt, err := hex.DecodeString("5353482d322e31")
+	require.NoError(t, err)
+	var metadata adapter.InboundContext
+	err = sniff.SSH(context.TODO(), &metadata, bytes.NewReader(pkt))
+	require.NotEmpty(t, err)
+	require.NotErrorIs(t, err, sniff.ErrNeedMoreData)
+}