Просмотр исходного кода

Fix domain resolver for DNS server

世界 8 месяцев назад
Родитель
Сommit
244243f206
7 измененных файлов с 180 добавлено и 64 удалено
  1. 2 2
      box.go
  2. 9 7
      common/dialer/detour.go
  3. 54 0
      common/dialer/dialer.go
  4. 55 12
      common/dialer/resolve.go
  5. 10 2
      dns/transport_adapter.go
  6. 20 27
      dns/transport_dialer.go
  7. 30 14
      option/dns.go

+ 2 - 2
box.go

@@ -202,7 +202,7 @@ func New(options Options) (*Box, error) {
 			transportOptions.Options,
 		)
 		if err != nil {
-			return nil, E.Cause(err, "initialize inbound[", i, "]")
+			return nil, E.Cause(err, "initialize DNS server[", i, "]")
 		}
 	}
 	err = dnsRouter.Initialize(dnsOptions.Rules)
@@ -225,7 +225,7 @@ func New(options Options) (*Box, error) {
 			endpointOptions.Options,
 		)
 		if err != nil {
-			return nil, E.Cause(err, "initialize inbound[", i, "]")
+			return nil, E.Cause(err, "initialize endpoint[", i, "]")
 		}
 	}
 	for i, inboundOptions := range options.Inbounds {

+ 9 - 7
common/dialer/detour.go

@@ -29,16 +29,18 @@ func (d *DetourDialer) Start() error {
 }
 
 func (d *DetourDialer) Dialer() (N.Dialer, error) {
-	d.initOnce.Do(func() {
-		var loaded bool
-		d.dialer, loaded = d.outboundManager.Outbound(d.detour)
-		if !loaded {
-			d.initErr = E.New("outbound detour not found: ", d.detour)
-		}
-	})
+	d.initOnce.Do(d.init)
 	return d.dialer, d.initErr
 }
 
+func (d *DetourDialer) init() {
+	var loaded bool
+	d.dialer, loaded = d.outboundManager.Outbound(d.detour)
+	if !loaded {
+		d.initErr = E.New("outbound detour not found: ", d.detour)
+	}
+}
+
 func (d *DetourDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
 	dialer, err := d.Dialer()
 	if err != nil {

+ 54 - 0
common/dialer/dialer.go

@@ -84,6 +84,60 @@ func New(ctx context.Context, options option.DialerOptions, remoteIsDomain bool)
 			ctx,
 			dialer,
 			options.Detour == "" && !options.TCPFastOpen,
+			"",
+			dnsQueryOptions,
+			resolveFallbackDelay,
+		)
+	}
+	return dialer, nil
+}
+
+func NewDNS(ctx context.Context, options option.DialerOptions, remoteIsDomain bool) (N.Dialer, error) {
+	var (
+		dialer N.Dialer
+		err    error
+	)
+	if options.Detour != "" {
+		outboundManager := service.FromContext[adapter.OutboundManager](ctx)
+		if outboundManager == nil {
+			return nil, E.New("missing outbound manager")
+		}
+		dialer = NewDetour(outboundManager, options.Detour)
+	} else {
+		dialer, err = NewDefault(ctx, options)
+		if err != nil {
+			return nil, err
+		}
+	}
+	if remoteIsDomain {
+		var (
+			dnsQueryOptions      adapter.DNSQueryOptions
+			resolveFallbackDelay time.Duration
+		)
+		if options.DomainResolver == nil || options.DomainResolver.Server == "" {
+			return nil, E.New("missing domain resolver for domain server address")
+		}
+		var strategy C.DomainStrategy
+		if options.DomainResolver.Strategy != option.DomainStrategy(C.DomainStrategyAsIS) {
+			strategy = C.DomainStrategy(options.DomainResolver.Strategy)
+		} else if
+		//nolint:staticcheck
+		options.DomainStrategy != option.DomainStrategy(C.DomainStrategyAsIS) {
+			//nolint:staticcheck
+			strategy = C.DomainStrategy(options.DomainStrategy)
+		}
+		dnsQueryOptions = adapter.DNSQueryOptions{
+			Strategy:     strategy,
+			DisableCache: options.DomainResolver.DisableCache,
+			RewriteTTL:   options.DomainResolver.RewriteTTL,
+			ClientSubnet: options.DomainResolver.ClientSubnet.Build(netip.Prefix{}),
+		}
+		resolveFallbackDelay = time.Duration(options.FallbackDelay)
+		dialer = NewResolveDialer(
+			ctx,
+			dialer,
+			options.Detour == "" && !options.TCPFastOpen,
+			options.DomainResolver.Server,
 			dnsQueryOptions,
 			resolveFallbackDelay,
 		)

+ 55 - 12
common/dialer/resolve.go

@@ -3,12 +3,14 @@ package dialer
 import (
 	"context"
 	"net"
+	"sync"
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing/common/bufio"
+	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 	"github.com/sagernet/sing/service"
@@ -30,20 +32,26 @@ type ParallelInterfaceResolveDialer interface {
 }
 
 type resolveDialer struct {
+	transport     adapter.DNSTransportManager
 	router        adapter.DNSRouter
 	dialer        N.Dialer
 	parallel      bool
+	server        string
+	initOnce      sync.Once
+	initErr       error
 	queryOptions  adapter.DNSQueryOptions
 	fallbackDelay time.Duration
 }
 
-func NewResolveDialer(ctx context.Context, dialer N.Dialer, parallel bool, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ResolveDialer {
+func NewResolveDialer(ctx context.Context, dialer N.Dialer, parallel bool, server string, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ResolveDialer {
 	return &resolveDialer{
-		service.FromContext[adapter.DNSRouter](ctx),
-		dialer,
-		parallel,
-		queryOptions,
-		fallbackDelay,
+		transport:     service.FromContext[adapter.DNSTransportManager](ctx),
+		router:        service.FromContext[adapter.DNSRouter](ctx),
+		dialer:        dialer,
+		parallel:      parallel,
+		server:        server,
+		queryOptions:  queryOptions,
+		fallbackDelay: fallbackDelay,
 	}
 }
 
@@ -52,20 +60,43 @@ type resolveParallelNetworkDialer struct {
 	dialer ParallelInterfaceDialer
 }
 
-func NewResolveParallelInterfaceDialer(ctx context.Context, dialer ParallelInterfaceDialer, parallel bool, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ParallelInterfaceResolveDialer {
+func NewResolveParallelInterfaceDialer(ctx context.Context, dialer ParallelInterfaceDialer, parallel bool, server string, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ParallelInterfaceResolveDialer {
 	return &resolveParallelNetworkDialer{
 		resolveDialer{
-			service.FromContext[adapter.DNSRouter](ctx),
-			dialer,
-			parallel,
-			queryOptions,
-			fallbackDelay,
+			transport:     service.FromContext[adapter.DNSTransportManager](ctx),
+			router:        service.FromContext[adapter.DNSRouter](ctx),
+			dialer:        dialer,
+			parallel:      parallel,
+			server:        server,
+			queryOptions:  queryOptions,
+			fallbackDelay: fallbackDelay,
 		},
 		dialer,
 	}
 }
 
+func (d *resolveDialer) initialize() error {
+	d.initOnce.Do(d.initServer)
+	return d.initErr
+}
+
+func (d *resolveDialer) initServer() {
+	if d.server == "" {
+		return
+	}
+	transport, loaded := d.transport.Transport(d.server)
+	if !loaded {
+		d.initErr = E.New("domain resolver not found: " + d.server)
+		return
+	}
+	d.queryOptions.Transport = transport
+}
+
 func (d *resolveDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
+	err := d.initialize()
+	if err != nil {
+		return nil, err
+	}
 	if !destination.IsFqdn() {
 		return d.dialer.DialContext(ctx, network, destination)
 	}
@@ -82,6 +113,10 @@ func (d *resolveDialer) DialContext(ctx context.Context, network string, destina
 }
 
 func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
+	err := d.initialize()
+	if err != nil {
+		return nil, err
+	}
 	if !destination.IsFqdn() {
 		return d.dialer.ListenPacket(ctx, destination)
 	}
@@ -106,6 +141,10 @@ func (d *resolveDialer) Upstream() any {
 }
 
 func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
+	err := d.initialize()
+	if err != nil {
+		return nil, err
+	}
 	if !destination.IsFqdn() {
 		return d.dialer.DialContext(ctx, network, destination)
 	}
@@ -125,6 +164,10 @@ func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context
 }
 
 func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) {
+	err := d.initialize()
+	if err != nil {
+		return nil, err
+	}
 	if !destination.IsFqdn() {
 		return d.dialer.ListenPacket(ctx, destination)
 	}

+ 10 - 2
dns/transport_adapter.go

@@ -27,9 +27,14 @@ func NewTransportAdapter(transportType string, transportTag string, dependencies
 }
 
 func NewTransportAdapterWithLocalOptions(transportType string, transportTag string, localOptions option.LocalDNSServerOptions) TransportAdapter {
+	var dependencies []string
+	if localOptions.DomainResolver != nil && localOptions.DomainResolver.Server != "" {
+		dependencies = append(dependencies, localOptions.DomainResolver.Server)
+	}
 	return TransportAdapter{
 		transportType: transportType,
 		transportTag:  transportTag,
+		dependencies:  dependencies,
 		strategy:      C.DomainStrategy(localOptions.LegacyStrategy),
 		clientSubnet:  localOptions.LegacyClientSubnet,
 	}
@@ -37,8 +42,11 @@ func NewTransportAdapterWithLocalOptions(transportType string, transportTag stri
 
 func NewTransportAdapterWithRemoteOptions(transportType string, transportTag string, remoteOptions option.RemoteDNSServerOptions) TransportAdapter {
 	var dependencies []string
-	if remoteOptions.AddressResolver != "" {
-		dependencies = []string{remoteOptions.AddressResolver}
+	if remoteOptions.DomainResolver != nil && remoteOptions.DomainResolver.Server != "" {
+		dependencies = append(dependencies, remoteOptions.DomainResolver.Server)
+	}
+	if remoteOptions.LegacyAddressResolver != "" {
+		dependencies = append(dependencies, remoteOptions.LegacyAddressResolver)
 	}
 	return TransportAdapter{
 		transportType: transportType,

+ 20 - 27
dns/transport_dialer.go

@@ -19,37 +19,30 @@ func NewLocalDialer(ctx context.Context, options option.LocalDNSServerOptions) (
 	if options.LegacyDefaultDialer {
 		return dialer.NewDefaultOutbound(ctx), nil
 	} else {
-		return dialer.New(ctx, options.DialerOptions, false)
+		return dialer.NewDNS(ctx, options.DialerOptions, false)
 	}
 }
 
 func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions) (N.Dialer, error) {
-	var (
-		transportDialer N.Dialer
-		err             error
-	)
 	if options.LegacyDefaultDialer {
-		transportDialer = dialer.NewDefaultOutbound(ctx)
-	} else {
-		transportDialer, err = dialer.New(ctx, options.DialerOptions, options.ServerIsDomain())
-	}
-	if err != nil {
-		return nil, err
-	}
-	if options.AddressResolver != "" {
-		transport := service.FromContext[adapter.DNSTransportManager](ctx)
-		resolverTransport, loaded := transport.Transport(options.AddressResolver)
-		if !loaded {
-			return nil, E.New("address resolver not found: ", options.AddressResolver)
+		transportDialer := dialer.NewDefaultOutbound(ctx)
+		if options.LegacyAddressResolver != "" {
+			transport := service.FromContext[adapter.DNSTransportManager](ctx)
+			resolverTransport, loaded := transport.Transport(options.LegacyAddressResolver)
+			if !loaded {
+				return nil, E.New("address resolver not found: ", options.LegacyAddressResolver)
+			}
+			transportDialer = newTransportDialer(transportDialer, service.FromContext[adapter.DNSRouter](ctx), resolverTransport, C.DomainStrategy(options.LegacyAddressStrategy), time.Duration(options.LegacyAddressFallbackDelay))
+		} else if options.ServerIsDomain() {
+			return nil, E.New("missing address resolver for server: ", options.Server)
 		}
-		transportDialer = NewTransportDialer(transportDialer, service.FromContext[adapter.DNSRouter](ctx), resolverTransport, C.DomainStrategy(options.AddressStrategy), time.Duration(options.AddressFallbackDelay))
-	} else if options.ServerIsDomain() {
-		return nil, E.New("missing address resolver for server: ", options.Server)
+		return transportDialer, nil
+	} else {
+		return dialer.NewDNS(ctx, options.DialerOptions, options.ServerIsDomain())
 	}
-	return transportDialer, nil
 }
 
-type TransportDialer struct {
+type legacyTransportDialer struct {
 	dialer        N.Dialer
 	dnsRouter     adapter.DNSRouter
 	transport     adapter.DNSTransport
@@ -57,8 +50,8 @@ type TransportDialer struct {
 	fallbackDelay time.Duration
 }
 
-func NewTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport adapter.DNSTransport, strategy C.DomainStrategy, fallbackDelay time.Duration) *TransportDialer {
-	return &TransportDialer{
+func newTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport adapter.DNSTransport, strategy C.DomainStrategy, fallbackDelay time.Duration) *legacyTransportDialer {
+	return &legacyTransportDialer{
 		dialer,
 		dnsRouter,
 		transport,
@@ -67,7 +60,7 @@ func NewTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport
 	}
 }
 
-func (d *TransportDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
+func (d *legacyTransportDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
 	if destination.IsIP() {
 		return d.dialer.DialContext(ctx, network, destination)
 	}
@@ -81,7 +74,7 @@ func (d *TransportDialer) DialContext(ctx context.Context, network string, desti
 	return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay)
 }
 
-func (d *TransportDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
+func (d *legacyTransportDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
 	if destination.IsIP() {
 		return d.dialer.ListenPacket(ctx, destination)
 	}
@@ -96,6 +89,6 @@ func (d *TransportDialer) ListenPacket(ctx context.Context, destination M.Socksa
 	return conn, err
 }
 
-func (d *TransportDialer) Upstream() any {
+func (d *legacyTransportDialer) Upstream() any {
 	return d.dialer
 }

+ 30 - 14
option/dns.go

@@ -128,18 +128,34 @@ func (o *NewDNSServerOptions) Upgrade(ctx context.Context) error {
 	} else {
 		serverType = C.DNSTypeUDP
 	}
-	remoteOptions := RemoteDNSServerOptions{
-		LocalDNSServerOptions: LocalDNSServerOptions{
-			DialerOptions: DialerOptions{
-				Detour: options.Detour,
+	var remoteOptions RemoteDNSServerOptions
+	if options.Detour == "" {
+		remoteOptions = RemoteDNSServerOptions{
+			LocalDNSServerOptions: LocalDNSServerOptions{
+				LegacyStrategy:      options.Strategy,
+				LegacyDefaultDialer: options.Detour == "",
+				LegacyClientSubnet:  options.ClientSubnet.Build(netip.Prefix{}),
 			},
-			LegacyStrategy:      options.Strategy,
-			LegacyDefaultDialer: options.Detour == "",
-			LegacyClientSubnet:  options.ClientSubnet.Build(netip.Prefix{}),
-		},
-		AddressResolver:      options.AddressResolver,
-		AddressStrategy:      options.AddressStrategy,
-		AddressFallbackDelay: options.AddressFallbackDelay,
+			LegacyAddressResolver:      options.AddressResolver,
+			LegacyAddressStrategy:      options.AddressStrategy,
+			LegacyAddressFallbackDelay: options.AddressFallbackDelay,
+		}
+	} else {
+		remoteOptions = RemoteDNSServerOptions{
+			LocalDNSServerOptions: LocalDNSServerOptions{
+				DialerOptions: DialerOptions{
+					Detour: options.Detour,
+					DomainResolver: &DomainResolveOptions{
+						Server:   options.AddressResolver,
+						Strategy: options.AddressStrategy,
+					},
+					FallbackDelay: options.AddressFallbackDelay,
+				},
+				LegacyStrategy:      options.Strategy,
+				LegacyDefaultDialer: options.Detour == "",
+				LegacyClientSubnet:  options.ClientSubnet.Build(netip.Prefix{}),
+			},
+		}
 	}
 	switch serverType {
 	case C.DNSTypeUDP:
@@ -274,9 +290,9 @@ type LocalDNSServerOptions struct {
 type RemoteDNSServerOptions struct {
 	LocalDNSServerOptions
 	ServerOptions
-	AddressResolver      string             `json:"address_resolver,omitempty"`
-	AddressStrategy      DomainStrategy     `json:"address_strategy,omitempty"`
-	AddressFallbackDelay badoption.Duration `json:"address_fallback_delay,omitempty"`
+	LegacyAddressResolver      string             `json:"-"`
+	LegacyAddressStrategy      DomainStrategy     `json:"-"`
+	LegacyAddressFallbackDelay badoption.Duration `json:"-"`
 }
 
 type RemoteTLSDNSServerOptions struct {