|
@@ -27,6 +27,8 @@ import (
|
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
|
N "github.com/sagernet/sing/common/network"
|
|
|
"github.com/sagernet/sing/common/uot"
|
|
|
+
|
|
|
+ "golang.org/x/exp/slices"
|
|
|
)
|
|
|
|
|
|
// Deprecated: use RouteConnectionEx instead.
|
|
@@ -345,16 +347,16 @@ func (r *Router) matchRule(
|
|
|
newBuffer, newPackerBuffers, newErr := r.actionSniff(ctx, metadata, &R.RuleActionSniff{
|
|
|
OverrideDestination: metadata.InboundOptions.SniffOverrideDestination,
|
|
|
Timeout: time.Duration(metadata.InboundOptions.SniffTimeout),
|
|
|
- }, inputConn, inputPacketConn, nil)
|
|
|
- if newErr != nil {
|
|
|
- fatalErr = newErr
|
|
|
- return
|
|
|
- }
|
|
|
+ }, inputConn, inputPacketConn, nil, nil)
|
|
|
if newBuffer != nil {
|
|
|
buffers = []*buf.Buffer{newBuffer}
|
|
|
} else if len(newPackerBuffers) > 0 {
|
|
|
packetBuffers = newPackerBuffers
|
|
|
}
|
|
|
+ if newErr != nil {
|
|
|
+ fatalErr = newErr
|
|
|
+ return
|
|
|
+ }
|
|
|
}
|
|
|
if C.DomainStrategy(metadata.InboundOptions.DomainStrategy) != C.DomainStrategyAsIS {
|
|
|
fatalErr = r.actionResolve(ctx, metadata, &R.RuleActionResolve{
|
|
@@ -453,16 +455,16 @@ match:
|
|
|
switch action := currentRule.Action().(type) {
|
|
|
case *R.RuleActionSniff:
|
|
|
if !preMatch {
|
|
|
- newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn, buffers)
|
|
|
- if newErr != nil {
|
|
|
- fatalErr = newErr
|
|
|
- return
|
|
|
- }
|
|
|
+ newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn, buffers, packetBuffers)
|
|
|
if newBuffer != nil {
|
|
|
buffers = append(buffers, newBuffer)
|
|
|
} else if len(newPacketBuffers) > 0 {
|
|
|
packetBuffers = append(packetBuffers, newPacketBuffers...)
|
|
|
}
|
|
|
+ if newErr != nil {
|
|
|
+ fatalErr = newErr
|
|
|
+ return
|
|
|
+ }
|
|
|
} else {
|
|
|
selectedRule = currentRule
|
|
|
selectedRuleIndex = currentRuleIndex
|
|
@@ -489,7 +491,7 @@ match:
|
|
|
|
|
|
func (r *Router) actionSniff(
|
|
|
ctx context.Context, metadata *adapter.InboundContext, action *R.RuleActionSniff,
|
|
|
- inputConn net.Conn, inputPacketConn N.PacketConn, inputBuffers []*buf.Buffer,
|
|
|
+ inputConn net.Conn, inputPacketConn N.PacketConn, inputBuffers []*buf.Buffer, inputPacketBuffers []*N.PacketBuffer,
|
|
|
) (buffer *buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error) {
|
|
|
if sniff.Skip(metadata) {
|
|
|
r.logger.DebugContext(ctx, "sniff skipped due to port considered as server-first")
|
|
@@ -501,7 +503,7 @@ func (r *Router) actionSniff(
|
|
|
if inputConn != nil {
|
|
|
if len(action.StreamSniffers) == 0 && len(action.PacketSniffers) > 0 {
|
|
|
return
|
|
|
- } else if metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) {
|
|
|
+ } else if slices.Equal(metadata.SnifferNames, action.SnifferNames) && metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) {
|
|
|
r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.SniffError)
|
|
|
return
|
|
|
}
|
|
@@ -528,6 +530,7 @@ func (r *Router) actionSniff(
|
|
|
action.Timeout,
|
|
|
streamSniffers...,
|
|
|
)
|
|
|
+ metadata.SnifferNames = action.SnifferNames
|
|
|
metadata.SniffError = err
|
|
|
if err == nil {
|
|
|
//goland:noinspection GoDeprecation
|
|
@@ -553,10 +556,13 @@ func (r *Router) actionSniff(
|
|
|
} else if inputPacketConn != nil {
|
|
|
if len(action.PacketSniffers) == 0 && len(action.StreamSniffers) > 0 {
|
|
|
return
|
|
|
- } else if metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) {
|
|
|
+ } else if slices.Equal(metadata.SnifferNames, action.SnifferNames) && metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) {
|
|
|
r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.SniffError)
|
|
|
return
|
|
|
}
|
|
|
+ quicMoreData := func() bool {
|
|
|
+ return slices.Equal(metadata.SnifferNames, action.SnifferNames) && errors.Is(metadata.SniffError, sniff.ErrNeedMoreData)
|
|
|
+ }
|
|
|
var packetSniffers []sniff.PacketSniffer
|
|
|
if len(action.PacketSniffers) > 0 {
|
|
|
packetSniffers = action.PacketSniffers
|
|
@@ -571,12 +577,37 @@ func (r *Router) actionSniff(
|
|
|
sniff.NTP,
|
|
|
}
|
|
|
}
|
|
|
+ var err error
|
|
|
+ for _, packetBuffer := range inputPacketBuffers {
|
|
|
+ if quicMoreData() {
|
|
|
+ err = sniff.PeekPacket(
|
|
|
+ ctx,
|
|
|
+ metadata,
|
|
|
+ packetBuffer.Buffer.Bytes(),
|
|
|
+ sniff.QUICClientHello,
|
|
|
+ )
|
|
|
+ } else {
|
|
|
+ err = sniff.PeekPacket(
|
|
|
+ ctx, metadata,
|
|
|
+ packetBuffer.Buffer.Bytes(),
|
|
|
+ packetSniffers...,
|
|
|
+ )
|
|
|
+ }
|
|
|
+ metadata.SnifferNames = action.SnifferNames
|
|
|
+ metadata.SniffError = err
|
|
|
+ if errors.Is(err, sniff.ErrNeedMoreData) {
|
|
|
+ // TODO: replace with generic message when there are more multi-packet protocols
|
|
|
+ r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello")
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ goto finally
|
|
|
+ }
|
|
|
+ packetBuffers = inputPacketBuffers
|
|
|
for {
|
|
|
var (
|
|
|
sniffBuffer = buf.NewPacket()
|
|
|
destination M.Socksaddr
|
|
|
done = make(chan struct{})
|
|
|
- err error
|
|
|
)
|
|
|
go func() {
|
|
|
sniffTimeout := C.ReadPayloadTimeout
|
|
@@ -602,7 +633,7 @@ func (r *Router) actionSniff(
|
|
|
return
|
|
|
}
|
|
|
} else {
|
|
|
- if len(packetBuffers) > 0 || metadata.SniffError != nil {
|
|
|
+ if quicMoreData() {
|
|
|
err = sniff.PeekPacket(
|
|
|
ctx,
|
|
|
metadata,
|
|
@@ -622,32 +653,34 @@ func (r *Router) actionSniff(
|
|
|
Destination: destination,
|
|
|
}
|
|
|
packetBuffers = append(packetBuffers, packetBuffer)
|
|
|
+ metadata.SnifferNames = action.SnifferNames
|
|
|
metadata.SniffError = err
|
|
|
if errors.Is(err, sniff.ErrNeedMoreData) {
|
|
|
// TODO: replace with generic message when there are more multi-packet protocols
|
|
|
r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello")
|
|
|
continue
|
|
|
}
|
|
|
- if metadata.Protocol != "" {
|
|
|
- //goland:noinspection GoDeprecation
|
|
|
- if action.OverrideDestination && M.IsDomainName(metadata.Domain) {
|
|
|
- metadata.Destination = M.Socksaddr{
|
|
|
- Fqdn: metadata.Domain,
|
|
|
- Port: metadata.Destination.Port,
|
|
|
- }
|
|
|
- }
|
|
|
- if metadata.Domain != "" && metadata.Client != "" {
|
|
|
- r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain, ", client: ", metadata.Client)
|
|
|
- } else if metadata.Domain != "" {
|
|
|
- r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain)
|
|
|
- } else if metadata.Client != "" {
|
|
|
- r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", client: ", metadata.Client)
|
|
|
- } else {
|
|
|
- r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol)
|
|
|
- }
|
|
|
+ }
|
|
|
+ goto finally
|
|
|
+ }
|
|
|
+ finally:
|
|
|
+ if err == nil {
|
|
|
+ //goland:noinspection GoDeprecation
|
|
|
+ if action.OverrideDestination && M.IsDomainName(metadata.Domain) {
|
|
|
+ metadata.Destination = M.Socksaddr{
|
|
|
+ Fqdn: metadata.Domain,
|
|
|
+ Port: metadata.Destination.Port,
|
|
|
}
|
|
|
}
|
|
|
- break
|
|
|
+ if metadata.Domain != "" && metadata.Client != "" {
|
|
|
+ r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain, ", client: ", metadata.Client)
|
|
|
+ } else if metadata.Domain != "" {
|
|
|
+ r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain)
|
|
|
+ } else if metadata.Client != "" {
|
|
|
+ r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", client: ", metadata.Client)
|
|
|
+ } else {
|
|
|
+ r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol)
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
return
|