Sfoglia il codice sorgente

Fix multiple sniff

世界 1 mese fa
parent
commit
cbf48e9b8c
3 ha cambiato i file con 75 aggiunte e 41 eliminazioni
  1. 1 0
      adapter/inbound.go
  2. 66 33
      route/route.go
  3. 8 8
      route/rule/rule_action.go

+ 1 - 0
adapter/inbound.go

@@ -57,6 +57,7 @@ type InboundContext struct {
 	Domain       string
 	Client       string
 	SniffContext any
+	SnifferNames []string
 	SniffError   error
 
 	// cache

+ 66 - 33
route/route.go

@@ -27,6 +27,8 @@ import (
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 	"github.com/sagernet/sing/common/uot"
+
+	"golang.org/x/exp/slices"
 )
 
 // Deprecated: use RouteConnectionEx instead.
@@ -345,16 +347,16 @@ func (r *Router) matchRule(
 			newBuffer, newPackerBuffers, newErr := r.actionSniff(ctx, metadata, &R.RuleActionSniff{
 				OverrideDestination: metadata.InboundOptions.SniffOverrideDestination,
 				Timeout:             time.Duration(metadata.InboundOptions.SniffTimeout),
-			}, inputConn, inputPacketConn, nil)
-			if newErr != nil {
-				fatalErr = newErr
-				return
-			}
+			}, inputConn, inputPacketConn, nil, nil)
 			if newBuffer != nil {
 				buffers = []*buf.Buffer{newBuffer}
 			} else if len(newPackerBuffers) > 0 {
 				packetBuffers = newPackerBuffers
 			}
+			if newErr != nil {
+				fatalErr = newErr
+				return
+			}
 		}
 		if C.DomainStrategy(metadata.InboundOptions.DomainStrategy) != C.DomainStrategyAsIS {
 			fatalErr = r.actionResolve(ctx, metadata, &R.RuleActionResolve{
@@ -453,16 +455,16 @@ match:
 		switch action := currentRule.Action().(type) {
 		case *R.RuleActionSniff:
 			if !preMatch {
-				newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn, buffers)
-				if newErr != nil {
-					fatalErr = newErr
-					return
-				}
+				newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn, buffers, packetBuffers)
 				if newBuffer != nil {
 					buffers = append(buffers, newBuffer)
 				} else if len(newPacketBuffers) > 0 {
 					packetBuffers = append(packetBuffers, newPacketBuffers...)
 				}
+				if newErr != nil {
+					fatalErr = newErr
+					return
+				}
 			} else {
 				selectedRule = currentRule
 				selectedRuleIndex = currentRuleIndex
@@ -489,7 +491,7 @@ match:
 
 func (r *Router) actionSniff(
 	ctx context.Context, metadata *adapter.InboundContext, action *R.RuleActionSniff,
-	inputConn net.Conn, inputPacketConn N.PacketConn, inputBuffers []*buf.Buffer,
+	inputConn net.Conn, inputPacketConn N.PacketConn, inputBuffers []*buf.Buffer, inputPacketBuffers []*N.PacketBuffer,
 ) (buffer *buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error) {
 	if sniff.Skip(metadata) {
 		r.logger.DebugContext(ctx, "sniff skipped due to port considered as server-first")
@@ -501,7 +503,7 @@ func (r *Router) actionSniff(
 	if inputConn != nil {
 		if len(action.StreamSniffers) == 0 && len(action.PacketSniffers) > 0 {
 			return
-		} else if metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) {
+		} else if slices.Equal(metadata.SnifferNames, action.SnifferNames) && metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) {
 			r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.SniffError)
 			return
 		}
@@ -528,6 +530,7 @@ func (r *Router) actionSniff(
 			action.Timeout,
 			streamSniffers...,
 		)
+		metadata.SnifferNames = action.SnifferNames
 		metadata.SniffError = err
 		if err == nil {
 			//goland:noinspection GoDeprecation
@@ -553,10 +556,13 @@ func (r *Router) actionSniff(
 	} else if inputPacketConn != nil {
 		if len(action.PacketSniffers) == 0 && len(action.StreamSniffers) > 0 {
 			return
-		} else if metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) {
+		} else if slices.Equal(metadata.SnifferNames, action.SnifferNames) && metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) {
 			r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.SniffError)
 			return
 		}
+		quicMoreData := func() bool {
+			return slices.Equal(metadata.SnifferNames, action.SnifferNames) && errors.Is(metadata.SniffError, sniff.ErrNeedMoreData)
+		}
 		var packetSniffers []sniff.PacketSniffer
 		if len(action.PacketSniffers) > 0 {
 			packetSniffers = action.PacketSniffers
@@ -571,12 +577,37 @@ func (r *Router) actionSniff(
 				sniff.NTP,
 			}
 		}
+		var err error
+		for _, packetBuffer := range inputPacketBuffers {
+			if quicMoreData() {
+				err = sniff.PeekPacket(
+					ctx,
+					metadata,
+					packetBuffer.Buffer.Bytes(),
+					sniff.QUICClientHello,
+				)
+			} else {
+				err = sniff.PeekPacket(
+					ctx, metadata,
+					packetBuffer.Buffer.Bytes(),
+					packetSniffers...,
+				)
+			}
+			metadata.SnifferNames = action.SnifferNames
+			metadata.SniffError = err
+			if errors.Is(err, sniff.ErrNeedMoreData) {
+				// TODO: replace with generic message when there are more multi-packet protocols
+				r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello")
+				continue
+			}
+			goto finally
+		}
+		packetBuffers = inputPacketBuffers
 		for {
 			var (
 				sniffBuffer = buf.NewPacket()
 				destination M.Socksaddr
 				done        = make(chan struct{})
-				err         error
 			)
 			go func() {
 				sniffTimeout := C.ReadPayloadTimeout
@@ -602,7 +633,7 @@ func (r *Router) actionSniff(
 					return
 				}
 			} else {
-				if len(packetBuffers) > 0 || metadata.SniffError != nil {
+				if quicMoreData() {
 					err = sniff.PeekPacket(
 						ctx,
 						metadata,
@@ -622,32 +653,34 @@ func (r *Router) actionSniff(
 					Destination: destination,
 				}
 				packetBuffers = append(packetBuffers, packetBuffer)
+				metadata.SnifferNames = action.SnifferNames
 				metadata.SniffError = err
 				if errors.Is(err, sniff.ErrNeedMoreData) {
 					// TODO: replace with generic message when there are more multi-packet protocols
 					r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello")
 					continue
 				}
-				if metadata.Protocol != "" {
-					//goland:noinspection GoDeprecation
-					if action.OverrideDestination && M.IsDomainName(metadata.Domain) {
-						metadata.Destination = M.Socksaddr{
-							Fqdn: metadata.Domain,
-							Port: metadata.Destination.Port,
-						}
-					}
-					if metadata.Domain != "" && metadata.Client != "" {
-						r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain, ", client: ", metadata.Client)
-					} else if metadata.Domain != "" {
-						r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain)
-					} else if metadata.Client != "" {
-						r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", client: ", metadata.Client)
-					} else {
-						r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol)
-					}
+			}
+			goto finally
+		}
+	finally:
+		if err == nil {
+			//goland:noinspection GoDeprecation
+			if action.OverrideDestination && M.IsDomainName(metadata.Domain) {
+				metadata.Destination = M.Socksaddr{
+					Fqdn: metadata.Domain,
+					Port: metadata.Destination.Port,
 				}
 			}
-			break
+			if metadata.Domain != "" && metadata.Client != "" {
+				r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain, ", client: ", metadata.Client)
+			} else if metadata.Domain != "" {
+				r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain)
+			} else if metadata.Client != "" {
+				r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", client: ", metadata.Client)
+			} else {
+				r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol)
+			}
 		}
 	}
 	return

+ 8 - 8
route/rule/rule_action.go

@@ -87,7 +87,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
 		return &RuleActionHijackDNS{}, nil
 	case C.RuleActionTypeSniff:
 		sniffAction := &RuleActionSniff{
-			snifferNames: action.SniffOptions.Sniffer,
+			SnifferNames: action.SniffOptions.Sniffer,
 			Timeout:      time.Duration(action.SniffOptions.Timeout),
 		}
 		return sniffAction, sniffAction.build()
@@ -361,7 +361,7 @@ func (r *RuleActionHijackDNS) String() string {
 }
 
 type RuleActionSniff struct {
-	snifferNames   []string
+	SnifferNames   []string
 	StreamSniffers []sniff.StreamSniffer
 	PacketSniffers []sniff.PacketSniffer
 	Timeout        time.Duration
@@ -374,7 +374,7 @@ func (r *RuleActionSniff) Type() string {
 }
 
 func (r *RuleActionSniff) build() error {
-	for _, name := range r.snifferNames {
+	for _, name := range r.SnifferNames {
 		switch name {
 		case C.ProtocolTLS:
 			r.StreamSniffers = append(r.StreamSniffers, sniff.TLSClientHello)
@@ -407,14 +407,14 @@ func (r *RuleActionSniff) build() error {
 }
 
 func (r *RuleActionSniff) String() string {
-	if len(r.snifferNames) == 0 && r.Timeout == 0 {
+	if len(r.SnifferNames) == 0 && r.Timeout == 0 {
 		return "sniff"
-	} else if len(r.snifferNames) > 0 && r.Timeout == 0 {
-		return F.ToString("sniff(", strings.Join(r.snifferNames, ","), ")")
-	} else if len(r.snifferNames) == 0 && r.Timeout > 0 {
+	} else if len(r.SnifferNames) > 0 && r.Timeout == 0 {
+		return F.ToString("sniff(", strings.Join(r.SnifferNames, ","), ")")
+	} else if len(r.SnifferNames) == 0 && r.Timeout > 0 {
 		return F.ToString("sniff(", r.Timeout.String(), ")")
 	} else {
-		return F.ToString("sniff(", strings.Join(r.snifferNames, ","), ",", r.Timeout.String(), ")")
+		return F.ToString("sniff(", strings.Join(r.SnifferNames, ","), ",", r.Timeout.String(), ")")
 	}
 }