浏览代码

Fix WireGuard panic

世界 8 月之前
父节点
当前提交
a0dc394c8f
共有 5 个文件被更改,包括 77 次插入91 次删除
  1. 0 7
      common/dialer/default.go
  2. 50 80
      common/dialer/dialer.go
  3. 11 2
      dns/transport_dialer.go
  4. 8 1
      protocol/wireguard/endpoint.go
  5. 8 1
      protocol/wireguard/outbound.go

+ 0 - 7
common/dialer/default.go

@@ -35,7 +35,6 @@ type DefaultDialer struct {
 	udpListener            net.ListenConfig
 	udpAddr4               string
 	udpAddr6               string
-	isWireGuardListener    bool
 	networkManager         adapter.NetworkManager
 	networkStrategy        *C.NetworkStrategy
 	defaultNetworkStrategy bool
@@ -183,11 +182,6 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
 		}
 		setMultiPathTCP(&dialer4)
 	}
-	if options.IsWireGuardListener {
-		for _, controlFn := range WgControlFns {
-			listener.Control = control.Append(listener.Control, controlFn)
-		}
-	}
 	tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen)
 	if err != nil {
 		return nil, err
@@ -204,7 +198,6 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
 		udpListener:            listener,
 		udpAddr4:               udpAddr4,
 		udpAddr6:               udpAddr6,
-		isWireGuardListener:    options.IsWireGuardListener,
 		networkManager:         networkManager,
 		networkStrategy:        networkStrategy,
 		defaultNetworkStrategy: defaultNetworkStrategy,

+ 50 - 80
common/dialer/dialer.go

@@ -16,59 +16,82 @@ import (
 	"github.com/sagernet/sing/service"
 )
 
+type Options struct {
+	Context          context.Context
+	Options          option.DialerOptions
+	RemoteIsDomain   bool
+	DirectResolver   bool
+	ResolverOnDetour bool
+}
+
+// TODO: merge with NewWithOptions
 func New(ctx context.Context, options option.DialerOptions, remoteIsDomain bool) (N.Dialer, error) {
-	if options.IsWireGuardListener {
-		return NewDefault(ctx, options)
-	}
+	return NewWithOptions(Options{
+		Context:        ctx,
+		Options:        options,
+		RemoteIsDomain: remoteIsDomain,
+	})
+}
+
+func NewWithOptions(options Options) (N.Dialer, error) {
+	dialOptions := options.Options
 	var (
 		dialer N.Dialer
 		err    error
 	)
-	if options.Detour != "" {
-		outboundManager := service.FromContext[adapter.OutboundManager](ctx)
+	if dialOptions.Detour != "" {
+		outboundManager := service.FromContext[adapter.OutboundManager](options.Context)
 		if outboundManager == nil {
 			return nil, E.New("missing outbound manager")
 		}
-		dialer = NewDetour(outboundManager, options.Detour)
+		dialer = NewDetour(outboundManager, dialOptions.Detour)
 	} else {
-		dialer, err = NewDefault(ctx, options)
+		dialer, err = NewDefault(options.Context, dialOptions)
 		if err != nil {
 			return nil, err
 		}
 	}
-	if remoteIsDomain && options.Detour == "" {
-		networkManager := service.FromContext[adapter.NetworkManager](ctx)
-		dnsTransport := service.FromContext[adapter.DNSTransportManager](ctx)
+	if options.RemoteIsDomain && (dialOptions.Detour == "" || options.ResolverOnDetour) {
+		networkManager := service.FromContext[adapter.NetworkManager](options.Context)
+		dnsTransport := service.FromContext[adapter.DNSTransportManager](options.Context)
 		var defaultOptions adapter.NetworkOptions
 		if networkManager != nil {
 			defaultOptions = networkManager.DefaultOptions()
 		}
 		var (
+			server               string
 			dnsQueryOptions      adapter.DNSQueryOptions
 			resolveFallbackDelay time.Duration
 		)
-		if options.DomainResolver != nil && options.DomainResolver.Server != "" {
-			transport, loaded := dnsTransport.Transport(options.DomainResolver.Server)
-			if !loaded {
-				return nil, E.New("domain resolver not found: " + options.DomainResolver.Server)
+		if dialOptions.DomainResolver != nil && dialOptions.DomainResolver.Server != "" {
+			var transport adapter.DNSTransport
+			if !options.DirectResolver {
+				var loaded bool
+				transport, loaded = dnsTransport.Transport(dialOptions.DomainResolver.Server)
+				if !loaded {
+					return nil, E.New("domain resolver not found: " + dialOptions.DomainResolver.Server)
+				}
 			}
 			var strategy C.DomainStrategy
-			if options.DomainResolver.Strategy != option.DomainStrategy(C.DomainStrategyAsIS) {
-				strategy = C.DomainStrategy(options.DomainResolver.Strategy)
+			if dialOptions.DomainResolver.Strategy != option.DomainStrategy(C.DomainStrategyAsIS) {
+				strategy = C.DomainStrategy(dialOptions.DomainResolver.Strategy)
 			} else if
 			//nolint:staticcheck
-			options.DomainStrategy != option.DomainStrategy(C.DomainStrategyAsIS) {
+			dialOptions.DomainStrategy != option.DomainStrategy(C.DomainStrategyAsIS) {
 				//nolint:staticcheck
-				strategy = C.DomainStrategy(options.DomainStrategy)
+				strategy = C.DomainStrategy(dialOptions.DomainStrategy)
 			}
+			server = dialOptions.DomainResolver.Server
 			dnsQueryOptions = adapter.DNSQueryOptions{
 				Transport:    transport,
 				Strategy:     strategy,
-				DisableCache: options.DomainResolver.DisableCache,
-				RewriteTTL:   options.DomainResolver.RewriteTTL,
-				ClientSubnet: options.DomainResolver.ClientSubnet.Build(netip.Prefix{}),
+				DisableCache: dialOptions.DomainResolver.DisableCache,
+				RewriteTTL:   dialOptions.DomainResolver.RewriteTTL,
+				ClientSubnet: dialOptions.DomainResolver.ClientSubnet.Build(netip.Prefix{}),
 			}
-			resolveFallbackDelay = time.Duration(options.FallbackDelay)
+			resolveFallbackDelay = time.Duration(dialOptions.FallbackDelay)
+		} else if options.DirectResolver {
+			return nil, E.New("missing domain resolver for domain server address")
 		} else if defaultOptions.DomainResolver != "" {
 			dnsQueryOptions = defaultOptions.DomainResolveOptions
 			transport, loaded := dnsTransport.Transport(defaultOptions.DomainResolver)
@@ -76,68 +99,15 @@ func New(ctx context.Context, options option.DialerOptions, remoteIsDomain bool)
 				return nil, E.New("default domain resolver not found: " + defaultOptions.DomainResolver)
 			}
 			dnsQueryOptions.Transport = transport
-			resolveFallbackDelay = time.Duration(options.FallbackDelay)
+			resolveFallbackDelay = time.Duration(dialOptions.FallbackDelay)
 		} else {
-			deprecated.Report(ctx, deprecated.OptionMissingDomainResolver)
-		}
-		dialer = NewResolveDialer(
-			ctx,
-			dialer,
-			options.Detour == "" && !options.TCPFastOpen,
-			"",
-			dnsQueryOptions,
-			resolveFallbackDelay,
-		)
-	}
-	return dialer, nil
-}
-
-func NewDNS(ctx context.Context, options option.DialerOptions, remoteIsDomain bool) (N.Dialer, error) {
-	var (
-		dialer N.Dialer
-		err    error
-	)
-	if options.Detour != "" {
-		outboundManager := service.FromContext[adapter.OutboundManager](ctx)
-		if outboundManager == nil {
-			return nil, E.New("missing outbound manager")
-		}
-		dialer = NewDetour(outboundManager, options.Detour)
-	} else {
-		dialer, err = NewDefault(ctx, options)
-		if err != nil {
-			return nil, err
-		}
-	}
-	if remoteIsDomain {
-		var (
-			dnsQueryOptions      adapter.DNSQueryOptions
-			resolveFallbackDelay time.Duration
-		)
-		if options.DomainResolver == nil || options.DomainResolver.Server == "" {
-			return nil, E.New("missing domain resolver for domain server address")
-		}
-		var strategy C.DomainStrategy
-		if options.DomainResolver.Strategy != option.DomainStrategy(C.DomainStrategyAsIS) {
-			strategy = C.DomainStrategy(options.DomainResolver.Strategy)
-		} else if
-		//nolint:staticcheck
-		options.DomainStrategy != option.DomainStrategy(C.DomainStrategyAsIS) {
-			//nolint:staticcheck
-			strategy = C.DomainStrategy(options.DomainStrategy)
-		}
-		dnsQueryOptions = adapter.DNSQueryOptions{
-			Strategy:     strategy,
-			DisableCache: options.DomainResolver.DisableCache,
-			RewriteTTL:   options.DomainResolver.RewriteTTL,
-			ClientSubnet: options.DomainResolver.ClientSubnet.Build(netip.Prefix{}),
+			deprecated.Report(options.Context, deprecated.OptionMissingDomainResolver)
 		}
-		resolveFallbackDelay = time.Duration(options.FallbackDelay)
 		dialer = NewResolveDialer(
-			ctx,
+			options.Context,
 			dialer,
-			options.Detour == "" && !options.TCPFastOpen,
-			options.DomainResolver.Server,
+			dialOptions.Detour == "" && !dialOptions.TCPFastOpen,
+			server,
 			dnsQueryOptions,
 			resolveFallbackDelay,
 		)

+ 11 - 2
dns/transport_dialer.go

@@ -19,7 +19,11 @@ func NewLocalDialer(ctx context.Context, options option.LocalDNSServerOptions) (
 	if options.LegacyDefaultDialer {
 		return dialer.NewDefaultOutbound(ctx), nil
 	} else {
-		return dialer.NewDNS(ctx, options.DialerOptions, false)
+		return dialer.NewWithOptions(dialer.Options{
+			Context:        ctx,
+			Options:        options.DialerOptions,
+			DirectResolver: true,
+		})
 	}
 }
 
@@ -38,7 +42,12 @@ func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions)
 		}
 		return transportDialer, nil
 	} else {
-		return dialer.NewDNS(ctx, options.DialerOptions, options.ServerIsDomain())
+		return dialer.NewWithOptions(dialer.Options{
+			Context:        ctx,
+			Options:        options.DialerOptions,
+			RemoteIsDomain: options.ServerIsDomain(),
+			DirectResolver: true,
+		})
 	}
 }
 

+ 8 - 1
protocol/wireguard/endpoint.go

@@ -53,7 +53,14 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL
 	if options.Detour == "" {
 		options.IsWireGuardListener = true
 	}
-	outboundDialer, err := dialer.New(ctx, options.DialerOptions, false)
+	outboundDialer, err := dialer.NewWithOptions(dialer.Options{
+		Context: ctx,
+		Options: options.DialerOptions,
+		RemoteIsDomain: common.Any(options.Peers, func(it option.WireGuardPeer) bool {
+			return !M.ParseAddr(it.Address).IsValid()
+		}),
+		ResolverOnDetour: true,
+	})
 	if err != nil {
 		return nil, err
 	}

+ 8 - 1
protocol/wireguard/outbound.go

@@ -56,7 +56,14 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
 	} else if options.GSO {
 		return nil, E.New("gso is conflict with detour")
 	}
-	outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain())
+	outboundDialer, err := dialer.NewWithOptions(dialer.Options{
+		Context: ctx,
+		Options: options.DialerOptions,
+		RemoteIsDomain: options.ServerIsDomain() || common.Any(options.Peers, func(it option.LegacyWireGuardPeer) bool {
+			return it.ServerIsDomain()
+		}),
+		ResolverOnDetour: true,
+	})
 	if err != nil {
 		return nil, err
 	}