Selaa lähdekoodia

Fix processing multiple sniffs

世界 7 kuukautta sitten
vanhempi
sitoutus
d55d5009c2
3 muutettua tiedostoa jossa 29 lisäystä ja 13 poistoa
  1. 5 4
      adapter/inbound.go
  2. 6 2
      common/sniff/sniff.go
  3. 18 7
      route/route.go

+ 5 - 4
adapter/inbound.go

@@ -53,10 +53,11 @@ type InboundContext struct {
 
 	// sniffer
 
-	Protocol     string
-	Domain       string
-	Client       string
-	SniffContext any
+	Protocol         string
+	Domain           string
+	Client           string
+	SniffContext     any
+	PacketSniffError error
 
 	// cache
 

+ 6 - 2
common/sniff/sniff.go

@@ -9,6 +9,7 @@ import (
 
 	"github.com/sagernet/sing-box/adapter"
 	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/buf"
 	E "github.com/sagernet/sing/common/exceptions"
 )
@@ -34,7 +35,7 @@ func Skip(metadata *adapter.InboundContext) bool {
 	return false
 }
 
-func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net.Conn, buffer *buf.Buffer, timeout time.Duration, sniffers ...StreamSniffer) error {
+func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net.Conn, buffers []*buf.Buffer, buffer *buf.Buffer, timeout time.Duration, sniffers ...StreamSniffer) error {
 	if timeout == 0 {
 		timeout = C.ReadPayloadTimeout
 	}
@@ -55,7 +56,10 @@ func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net.
 		}
 		errors = nil
 		for _, sniffer := range sniffers {
-			err = sniffer(ctx, metadata, bytes.NewReader(buffer.Bytes()))
+			reader := io.MultiReader(common.Map(append(buffers, buffer), func(it *buf.Buffer) io.Reader {
+				return bytes.NewReader(it.Bytes())
+			})...)
+			err = sniffer(ctx, metadata, reader)
 			if err == nil {
 				return nil
 			}

+ 18 - 7
route/route.go

@@ -358,7 +358,7 @@ func (r *Router) matchRule(
 			newBuffer, newPackerBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{
 				OverrideDestination: metadata.InboundOptions.SniffOverrideDestination,
 				Timeout:             time.Duration(metadata.InboundOptions.SniffTimeout),
-			}, inputConn, inputPacketConn)
+			}, inputConn, inputPacketConn, nil)
 			if newErr != nil {
 				fatalErr = newErr
 				return
@@ -458,7 +458,7 @@ match:
 		switch action := currentRule.Action().(type) {
 		case *rule.RuleActionSniff:
 			if !preMatch {
-				newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn)
+				newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn, buffers)
 				if newErr != nil {
 					fatalErr = newErr
 					return
@@ -490,7 +490,7 @@ match:
 		}
 	}
 	if !preMatch && inputPacketConn != nil && (metadata.InboundType == C.TypeSOCKS || metadata.InboundType == C.TypeMixed) && !metadata.Destination.IsFqdn() && !metadata.Destination.Addr.IsGlobalUnicast() {
-		newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{Timeout: C.TCPTimeout}, inputConn, inputPacketConn)
+		newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{Timeout: C.TCPTimeout}, inputConn, inputPacketConn, buffers)
 		if newErr != nil {
 			fatalErr = newErr
 			return
@@ -506,11 +506,16 @@ match:
 
 func (r *Router) actionSniff(
 	ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionSniff,
-	inputConn net.Conn, inputPacketConn N.PacketConn,
+	inputConn net.Conn, inputPacketConn N.PacketConn, inputBuffers []*buf.Buffer,
 ) (buffer *buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error) {
 	if sniff.Skip(metadata) {
+		r.logger.DebugContext(ctx, "sniff skipped due to port considered as server-first")
+		return
+	} else if metadata.Protocol != "" {
+		r.logger.DebugContext(ctx, "duplicate sniff skipped")
 		return
-	} else if inputConn != nil {
+	}
+	if inputConn != nil {
 		sniffBuffer := buf.NewPacket()
 		var streamSniffers []sniff.StreamSniffer
 		if len(action.StreamSniffers) > 0 {
@@ -529,6 +534,7 @@ func (r *Router) actionSniff(
 			ctx,
 			metadata,
 			inputConn,
+			inputBuffers,
 			sniffBuffer,
 			action.Timeout,
 			streamSniffers...,
@@ -555,6 +561,10 @@ func (r *Router) actionSniff(
 			sniffBuffer.Release()
 		}
 	} else if inputPacketConn != nil {
+		if metadata.PacketSniffError != nil && !errors.Is(metadata.PacketSniffError, sniff.ErrClientHelloFragmented) {
+			r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.PacketSniffError)
+			return
+		}
 		for {
 			var (
 				sniffBuffer = buf.NewPacket()
@@ -589,7 +599,7 @@ func (r *Router) actionSniff(
 				if (metadata.InboundType == C.TypeSOCKS || metadata.InboundType == C.TypeMixed) && !metadata.Destination.IsFqdn() && !metadata.Destination.Addr.IsGlobalUnicast() && !metadata.RouteOriginalDestination.IsValid() {
 					metadata.Destination = destination
 				}
-				if len(packetBuffers) > 0 {
+				if len(packetBuffers) > 0 || metadata.PacketSniffError != nil {
 					err = sniff.PeekPacket(
 						ctx,
 						metadata,
@@ -622,7 +632,8 @@ func (r *Router) actionSniff(
 					Destination: destination,
 				}
 				packetBuffers = append(packetBuffers, packetBuffer)
-				if E.IsMulti(err, sniff.ErrClientHelloFragmented) {
+				metadata.PacketSniffError = err
+				if errors.Is(err, sniff.ErrClientHelloFragmented) {
 					r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello")
 					continue
 				}