浏览代码

Improve sniffer

世界 6 月之前
父节点
当前提交
af17eaa537
共有 12 个文件被更改,包括 63 次插入48 次删除
  1. 3 2
      common/sniff/bittorrent.go
  2. 4 12
      common/sniff/dns.go
  3. 7 1
      common/sniff/http.go
  4. 1 3
      common/sniff/quic.go
  5. 1 1
      common/sniff/quic_test.go
  6. 10 9
      common/sniff/rdp.go
  7. 13 7
      common/sniff/sniff.go
  8. 11 7
      common/sniff/ssh.go
  9. 7 1
      common/sniff/tls.go
  10. 1 1
      go.mod
  11. 2 2
      go.sum
  12. 3 2
      route/route.go

+ 3 - 2
common/sniff/bittorrent.go

@@ -9,6 +9,7 @@ import (
 
 	"github.com/sagernet/sing-box/adapter"
 	C "github.com/sagernet/sing-box/constant"
+	E "github.com/sagernet/sing/common/exceptions"
 )
 
 const (
@@ -23,7 +24,7 @@ func BitTorrent(_ context.Context, metadata *adapter.InboundContext, reader io.R
 	var first byte
 	err := binary.Read(reader, binary.BigEndian, &first)
 	if err != nil {
-		return err
+		return E.Cause1(ErrNeedMoreData, err)
 	}
 
 	if first != 19 {
@@ -33,7 +34,7 @@ func BitTorrent(_ context.Context, metadata *adapter.InboundContext, reader io.R
 	var protocol [19]byte
 	_, err = reader.Read(protocol[:])
 	if err != nil {
-		return err
+		return E.Cause1(ErrNeedMoreData, err)
 	}
 	if string(protocol[:]) != "BitTorrent protocol" {
 		return os.ErrInvalid

+ 4 - 12
common/sniff/dns.go

@@ -5,13 +5,11 @@ import (
 	"encoding/binary"
 	"io"
 	"os"
-	"time"
 
 	"github.com/sagernet/sing-box/adapter"
 	C "github.com/sagernet/sing-box/constant"
-	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/buf"
-	"github.com/sagernet/sing/common/task"
+	E "github.com/sagernet/sing/common/exceptions"
 
 	mDNS "github.com/miekg/dns"
 )
@@ -20,22 +18,16 @@ func StreamDomainNameQuery(readCtx context.Context, metadata *adapter.InboundCon
 	var length uint16
 	err := binary.Read(reader, binary.BigEndian, &length)
 	if err != nil {
-		return os.ErrInvalid
+		return E.Cause1(ErrNeedMoreData, err)
 	}
 	if length == 0 {
 		return os.ErrInvalid
 	}
 	buffer := buf.NewSize(int(length))
 	defer buffer.Release()
-	readCtx, cancel := context.WithTimeout(readCtx, time.Millisecond*100)
-	var readTask task.Group
-	readTask.Append0(func(ctx context.Context) error {
-		return common.Error(buffer.ReadFullFrom(reader, buffer.FreeLen()))
-	})
-	err = readTask.Run(readCtx)
-	cancel()
+	_, err = buffer.ReadFullFrom(reader, buffer.FreeLen())
 	if err != nil {
-		return err
+		return E.Cause1(ErrNeedMoreData, err)
 	}
 	return DomainNameQuery(readCtx, metadata, buffer.Bytes())
 }

+ 7 - 1
common/sniff/http.go

@@ -3,10 +3,12 @@ package sniff
 import (
 	std_bufio "bufio"
 	"context"
+	"errors"
 	"io"
 
 	"github.com/sagernet/sing-box/adapter"
 	C "github.com/sagernet/sing-box/constant"
+	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	"github.com/sagernet/sing/protocol/http"
 )
@@ -14,7 +16,11 @@ import (
 func HTTPHost(_ context.Context, metadata *adapter.InboundContext, reader io.Reader) error {
 	request, err := http.ReadRequest(std_bufio.NewReader(reader))
 	if err != nil {
-		return err
+		if errors.Is(err, io.ErrUnexpectedEOF) {
+			return E.Cause1(ErrNeedMoreData, err)
+		} else {
+			return err
+		}
 	}
 	metadata.Protocol = C.ProtocolHTTP
 	metadata.Domain = M.ParseSocksaddr(request.Host).AddrString()

+ 1 - 3
common/sniff/quic.go

@@ -20,8 +20,6 @@ import (
 	"golang.org/x/crypto/hkdf"
 )
 
-var ErrClientHelloFragmented = E.New("need more packet for chromium QUIC connection")
-
 func QUICClientHello(ctx context.Context, metadata *adapter.InboundContext, packet []byte) error {
 	reader := bytes.NewReader(packet)
 	typeByte, err := reader.ReadByte()
@@ -308,7 +306,7 @@ find:
 		metadata.Protocol = C.ProtocolQUIC
 		metadata.Client = C.ClientChromium
 		metadata.SniffContext = fragments
-		return ErrClientHelloFragmented
+		return E.Cause1(ErrNeedMoreData, err)
 	}
 	metadata.Domain = fingerprint.ServerName
 	for metadata.Client == "" {

文件差异内容过多而无法显示
+ 1 - 1
common/sniff/quic_test.go


+ 10 - 9
common/sniff/rdp.go

@@ -8,6 +8,7 @@ import (
 
 	"github.com/sagernet/sing-box/adapter"
 	C "github.com/sagernet/sing-box/constant"
+	E "github.com/sagernet/sing/common/exceptions"
 	"github.com/sagernet/sing/common/rw"
 )
 
@@ -15,7 +16,7 @@ func RDP(_ context.Context, metadata *adapter.InboundContext, reader io.Reader)
 	var tpktVersion uint8
 	err := binary.Read(reader, binary.BigEndian, &tpktVersion)
 	if err != nil {
-		return err
+		return E.Cause1(ErrNeedMoreData, err)
 	}
 	if tpktVersion != 0x03 {
 		return os.ErrInvalid
@@ -24,7 +25,7 @@ func RDP(_ context.Context, metadata *adapter.InboundContext, reader io.Reader)
 	var tpktReserved uint8
 	err = binary.Read(reader, binary.BigEndian, &tpktReserved)
 	if err != nil {
-		return err
+		return E.Cause1(ErrNeedMoreData, err)
 	}
 	if tpktReserved != 0x00 {
 		return os.ErrInvalid
@@ -33,7 +34,7 @@ func RDP(_ context.Context, metadata *adapter.InboundContext, reader io.Reader)
 	var tpktLength uint16
 	err = binary.Read(reader, binary.BigEndian, &tpktLength)
 	if err != nil {
-		return err
+		return E.Cause1(ErrNeedMoreData, err)
 	}
 
 	if tpktLength != 19 {
@@ -43,7 +44,7 @@ func RDP(_ context.Context, metadata *adapter.InboundContext, reader io.Reader)
 	var cotpLength uint8
 	err = binary.Read(reader, binary.BigEndian, &cotpLength)
 	if err != nil {
-		return err
+		return E.Cause1(ErrNeedMoreData, err)
 	}
 
 	if cotpLength != 14 {
@@ -53,7 +54,7 @@ func RDP(_ context.Context, metadata *adapter.InboundContext, reader io.Reader)
 	var cotpTpduType uint8
 	err = binary.Read(reader, binary.BigEndian, &cotpTpduType)
 	if err != nil {
-		return err
+		return E.Cause1(ErrNeedMoreData, err)
 	}
 	if cotpTpduType != 0xE0 {
 		return os.ErrInvalid
@@ -61,13 +62,13 @@ func RDP(_ context.Context, metadata *adapter.InboundContext, reader io.Reader)
 
 	err = rw.SkipN(reader, 5)
 	if err != nil {
-		return err
+		return E.Cause1(ErrNeedMoreData, err)
 	}
 
 	var rdpType uint8
 	err = binary.Read(reader, binary.BigEndian, &rdpType)
 	if err != nil {
-		return err
+		return E.Cause1(ErrNeedMoreData, err)
 	}
 	if rdpType != 0x01 {
 		return os.ErrInvalid
@@ -75,12 +76,12 @@ func RDP(_ context.Context, metadata *adapter.InboundContext, reader io.Reader)
 	var rdpFlags uint8
 	err = binary.Read(reader, binary.BigEndian, &rdpFlags)
 	if err != nil {
-		return err
+		return E.Cause1(ErrNeedMoreData, err)
 	}
 	var rdpLength uint8
 	err = binary.Read(reader, binary.BigEndian, &rdpLength)
 	if err != nil {
-		return err
+		return E.Cause1(ErrNeedMoreData, err)
 	}
 	if rdpLength != 8 {
 		return os.ErrInvalid

+ 13 - 7
common/sniff/sniff.go

@@ -3,6 +3,7 @@ package sniff
 import (
 	"bytes"
 	"context"
+	"errors"
 	"io"
 	"net"
 	"time"
@@ -19,6 +20,8 @@ type (
 	PacketSniffer = func(ctx context.Context, metadata *adapter.InboundContext, packet []byte) error
 )
 
+var ErrNeedMoreData = E.New("need more data")
+
 func Skip(metadata *adapter.InboundContext) bool {
 	// skip server first protocols
 	switch metadata.Destination.Port {
@@ -40,7 +43,7 @@ func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net.
 		timeout = C.ReadPayloadTimeout
 	}
 	deadline := time.Now().Add(timeout)
-	var errors []error
+	var sniffError error
 	for i := 0; ; i++ {
 		err := conn.SetReadDeadline(deadline)
 		if err != nil {
@@ -54,7 +57,7 @@ func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net.
 			}
 			return E.Cause(err, "read payload")
 		}
-		errors = nil
+		sniffError = nil
 		for _, sniffer := range sniffers {
 			reader := io.MultiReader(common.Map(append(buffers, buffer), func(it *buf.Buffer) io.Reader {
 				return bytes.NewReader(it.Bytes())
@@ -63,20 +66,23 @@ func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net.
 			if err == nil {
 				return nil
 			}
-			errors = append(errors, err)
+			sniffError = E.Errors(sniffError, err)
+		}
+		if !errors.Is(err, ErrNeedMoreData) {
+			break
 		}
 	}
-	return E.Errors(errors...)
+	return sniffError
 }
 
 func PeekPacket(ctx context.Context, metadata *adapter.InboundContext, packet []byte, sniffers ...PacketSniffer) error {
-	var errors []error
+	var sniffError []error
 	for _, sniffer := range sniffers {
 		err := sniffer(ctx, metadata, packet)
 		if err == nil {
 			return nil
 		}
-		errors = append(errors, err)
+		sniffError = append(sniffError, err)
 	}
-	return E.Errors(errors...)
+	return E.Errors(sniffError...)
 }

+ 11 - 7
common/sniff/ssh.go

@@ -5,22 +5,26 @@ import (
 	"context"
 	"io"
 	"os"
-	"strings"
 
 	"github.com/sagernet/sing-box/adapter"
 	C "github.com/sagernet/sing-box/constant"
+	E "github.com/sagernet/sing/common/exceptions"
 )
 
 func SSH(_ context.Context, metadata *adapter.InboundContext, reader io.Reader) error {
-	scanner := bufio.NewScanner(reader)
-	if !scanner.Scan() {
+	const sshPrefix = "SSH-2.0-"
+	bReader := bufio.NewReader(reader)
+	prefix, err := bReader.Peek(len(sshPrefix))
+	if err != nil {
+		return E.Cause1(ErrNeedMoreData, err)
+	} else if string(prefix) != sshPrefix {
 		return os.ErrInvalid
 	}
-	fistLine := scanner.Text()
-	if !strings.HasPrefix(fistLine, "SSH-2.0-") {
-		return os.ErrInvalid
+	fistLine, _, err := bReader.ReadLine()
+	if err != nil {
+		return err
 	}
 	metadata.Protocol = C.ProtocolSSH
-	metadata.Client = fistLine[8:]
+	metadata.Client = string(fistLine)[8:]
 	return nil
 }

+ 7 - 1
common/sniff/tls.go

@@ -3,11 +3,13 @@ package sniff
 import (
 	"context"
 	"crypto/tls"
+	"errors"
 	"io"
 
 	"github.com/sagernet/sing-box/adapter"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing/common/bufio"
+	E "github.com/sagernet/sing/common/exceptions"
 )
 
 func TLSClientHello(ctx context.Context, metadata *adapter.InboundContext, reader io.Reader) error {
@@ -23,5 +25,9 @@ func TLSClientHello(ctx context.Context, metadata *adapter.InboundContext, reade
 		metadata.Domain = clientHello.ServerName
 		return nil
 	}
-	return err
+	if errors.Is(err, io.ErrUnexpectedEOF) {
+		return E.Cause1(ErrNeedMoreData, err)
+	} else {
+		return err
+	}
 }

+ 1 - 1
go.mod

@@ -26,7 +26,7 @@ require (
 	github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff
 	github.com/sagernet/quic-go v0.49.0-beta.1
 	github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691
-	github.com/sagernet/sing v0.6.5
+	github.com/sagernet/sing v0.6.6-0.20250406082302-d3673bff4af8
 	github.com/sagernet/sing-dns v0.4.1
 	github.com/sagernet/sing-mux v0.3.1
 	github.com/sagernet/sing-quic v0.4.1

+ 2 - 2
go.sum

@@ -119,8 +119,8 @@ github.com/sagernet/quic-go v0.49.0-beta.1/go.mod h1:uesWD1Ihrldq1M3XtjuEvIUqi8W
 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byLGkEnIYp6grlXfo1QYUfiYFGjewIdc=
 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU=
 github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo=
-github.com/sagernet/sing v0.6.5 h1:TBKTK6Ms0/MNTZm+cTC2hhKunE42XrNIdsxcYtWqeUU=
-github.com/sagernet/sing v0.6.5/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
+github.com/sagernet/sing v0.6.6-0.20250406082302-d3673bff4af8 h1:1jHChanwnGF5DJZ5pR/RkVf69VyjQxfDVfOMJx7bPyI=
+github.com/sagernet/sing v0.6.6-0.20250406082302-d3673bff4af8/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
 github.com/sagernet/sing-dns v0.4.1 h1:nozS7iqpxZ7aV73oHbkD/8haOvf3XXDCgT//8NdYirk=
 github.com/sagernet/sing-dns v0.4.1/go.mod h1:dweQs54ng2YGzoJfz+F9dGuDNdP5pJ3PLeggnK5VWc8=
 github.com/sagernet/sing-mux v0.3.1 h1:kvCc8HyGAskDHDQ0yQvoTi/7J4cZPB/VJMsAM3MmdQI=

+ 3 - 2
route/route.go

@@ -549,7 +549,7 @@ func (r *Router) actionSniff(
 			sniffBuffer.Release()
 		}
 	} else if inputPacketConn != nil {
-		if metadata.PacketSniffError != nil && !errors.Is(metadata.PacketSniffError, sniff.ErrClientHelloFragmented) {
+		if metadata.PacketSniffError != nil && !errors.Is(metadata.PacketSniffError, sniff.ErrNeedMoreData) {
 			r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.PacketSniffError)
 			return
 		}
@@ -618,7 +618,8 @@ func (r *Router) actionSniff(
 				}
 				packetBuffers = append(packetBuffers, packetBuffer)
 				metadata.PacketSniffError = err
-				if errors.Is(err, sniff.ErrClientHelloFragmented) {
+				if errors.Is(err, sniff.ErrNeedMoreData) {
+					// TODO: replace with generic message when there are more multi-packet protocols
 					r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello")
 					continue
 				}

部分文件因为文件数量过多而无法显示