Browse Source

Explicitly reject detour to empty direct outbounds

世界 7 months ago
parent
commit
ef92ed6795

+ 1 - 1
adapter/dns.go

@@ -45,10 +45,10 @@ type RDRCStore interface {
 }
 
 type DNSTransport interface {
+	Lifecycle
 	Type() string
 	Tag() string
 	Dependencies() []string
-	Reset()
 	Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error)
 }
 

+ 21 - 5
common/dialer/detour.go

@@ -6,14 +6,20 @@ import (
 	"sync"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing/common"
 	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 )
 
+type DirectDialer interface {
+	IsEmpty() bool
+}
+
 type DetourDialer struct {
 	outboundManager adapter.OutboundManager
 	detour          string
+	directResolver  bool
 	dialer          N.Dialer
 	initOnce        sync.Once
 	initErr         error
@@ -23,9 +29,12 @@ func NewDetour(outboundManager adapter.OutboundManager, detour string) N.Dialer
 	return &DetourDialer{outboundManager: outboundManager, detour: detour}
 }
 
-func (d *DetourDialer) Start() error {
-	_, err := d.Dialer()
-	return err
+func InitializeDetour(dialer N.Dialer) error {
+	detourDialer, isDetour := common.Cast[*DetourDialer](dialer)
+	if !isDetour {
+		return nil
+	}
+	return common.Error(detourDialer.Dialer())
 }
 
 func (d *DetourDialer) Dialer() (N.Dialer, error) {
@@ -34,11 +43,18 @@ func (d *DetourDialer) Dialer() (N.Dialer, error) {
 }
 
 func (d *DetourDialer) init() {
-	var loaded bool
-	d.dialer, loaded = d.outboundManager.Outbound(d.detour)
+	dialer, loaded := d.outboundManager.Outbound(d.detour)
 	if !loaded {
 		d.initErr = E.New("outbound detour not found: ", d.detour)
+		return
+	}
+	if directDialer, isDirect := dialer.(DirectDialer); isDirect {
+		if directDialer.IsEmpty() {
+			d.initErr = E.New("detour to an empty direct outbound makes no sense")
+			return
+		}
 	}
+	d.dialer = dialer
 }
 
 func (d *DetourDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {

+ 1 - 1
dns/router.go

@@ -449,6 +449,6 @@ func (r *Router) LookupReverseMapping(ip netip.Addr) (string, bool) {
 func (r *Router) ResetNetwork() {
 	r.ClearCache()
 	for _, transport := range r.transport.Transports() {
-		transport.Reset()
+		transport.Close()
 	}
 }

+ 2 - 8
dns/transport/dhcp/dhcp.go

@@ -81,7 +81,7 @@ func (t *Transport) Start(stage adapter.StartStage) error {
 
 func (t *Transport) Close() error {
 	for _, transport := range t.transports {
-		transport.Reset()
+		transport.Close()
 	}
 	if t.interfaceCallback != nil {
 		t.networkManager.InterfaceMonitor().UnregisterCallback(t.interfaceCallback)
@@ -89,12 +89,6 @@ func (t *Transport) Close() error {
 	return nil
 }
 
-func (t *Transport) Reset() {
-	for _, transport := range t.transports {
-		transport.Reset()
-	}
-}
-
 func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
 	err := t.fetchServers()
 	if err != nil {
@@ -252,7 +246,7 @@ func (t *Transport) recreateServers(iface *control.Interface, serverAddrs []M.So
 		transports = append(transports, transport.NewUDPRaw(t.logger, t.TransportAdapter, serverDialer, serverAddr))
 	}
 	for _, transport := range t.transports {
-		transport.Reset()
+		transport.Close()
 	}
 	t.transports = transports
 	return nil

+ 6 - 1
dns/transport/hosts/hosts.go

@@ -51,7 +51,12 @@ func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, opt
 	}, nil
 }
 
-func (t *Transport) Reset() {
+func (t *Transport) Start(stage adapter.StartStage) error {
+	return nil
+}
+
+func (t *Transport) Close() error {
+	return nil
 }
 
 func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {

+ 10 - 1
dns/transport/https.go

@@ -10,6 +10,7 @@ import (
 	"strconv"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/dialer"
 	"github.com/sagernet/sing-box/common/tls"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/dns"
@@ -149,9 +150,17 @@ func NewHTTPSRaw(
 	}
 }
 
-func (t *HTTPSTransport) Reset() {
+func (t *HTTPSTransport) Start(stage adapter.StartStage) error {
+	if stage != adapter.StartStateStart {
+		return nil
+	}
+	return dialer.InitializeDetour(t.dialer)
+}
+
+func (t *HTTPSTransport) Close() error {
 	t.transport.CloseIdleConnections()
 	t.transport = t.transport.Clone()
+	return nil
 }
 
 func (t *HTTPSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {

+ 6 - 1
dns/transport/local/local.go

@@ -40,7 +40,12 @@ func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, opt
 	}, nil
 }
 
-func (t *Transport) Reset() {
+func (t *Transport) Start(stage adapter.StartStage) error {
+	return nil
+}
+
+func (t *Transport) Close() error {
+	return nil
 }
 
 func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {

+ 6 - 2
dns/transport/quic/http3.go

@@ -111,8 +111,12 @@ func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options
 	}, nil
 }
 
-func (t *HTTP3Transport) Reset() {
-	t.transport.Close()
+func (t *HTTP3Transport) Start(stage adapter.StartStage) error {
+	return nil
+}
+
+func (t *HTTP3Transport) Close() error {
+	return t.transport.Close()
 }
 
 func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {

+ 6 - 1
dns/transport/quic/quic.go

@@ -68,13 +68,18 @@ func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options
 	}, nil
 }
 
-func (t *Transport) Reset() {
+func (t *Transport) Start(stage adapter.StartStage) error {
+	return nil
+}
+
+func (t *Transport) Close() error {
 	t.access.Lock()
 	defer t.access.Unlock()
 	connection := t.connection
 	if connection != nil {
 		connection.CloseWithError(0, "")
 	}
+	return nil
 }
 
 func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {

+ 10 - 1
dns/transport/tcp.go

@@ -6,6 +6,7 @@ import (
 	"io"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/dialer"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/dns"
 	"github.com/sagernet/sing-box/log"
@@ -46,7 +47,15 @@ func NewTCP(ctx context.Context, logger log.ContextLogger, tag string, options o
 	}, nil
 }
 
-func (t *TCPTransport) Reset() {
+func (t *TCPTransport) Start(stage adapter.StartStage) error {
+	if stage != adapter.StartStateStart {
+		return nil
+	}
+	return dialer.InitializeDetour(t.dialer)
+}
+
+func (t *TCPTransport) Close() error {
+	return nil
 }
 
 func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {

+ 10 - 1
dns/transport/tls.go

@@ -5,6 +5,7 @@ import (
 	"sync"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/dialer"
 	"github.com/sagernet/sing-box/common/tls"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/dns"
@@ -65,13 +66,21 @@ func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options o
 	}, nil
 }
 
-func (t *TLSTransport) Reset() {
+func (t *TLSTransport) Start(stage adapter.StartStage) error {
+	if stage != adapter.StartStateStart {
+		return nil
+	}
+	return dialer.InitializeDetour(t.dialer)
+}
+
+func (t *TLSTransport) Close() error {
 	t.access.Lock()
 	defer t.access.Unlock()
 	for connection := t.connections.Front(); connection != nil; connection = connection.Next() {
 		connection.Value.Close()
 	}
 	t.connections.Init()
+	return nil
 }
 
 func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {

+ 10 - 1
dns/transport/udp.go

@@ -7,6 +7,7 @@ import (
 	"sync"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/dialer"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/dns"
 	"github.com/sagernet/sing-box/log"
@@ -64,11 +65,19 @@ func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer
 	}
 }
 
-func (t *UDPTransport) Reset() {
+func (t *UDPTransport) Start(stage adapter.StartStage) error {
+	if stage != adapter.StartStateStart {
+		return nil
+	}
+	return dialer.InitializeDetour(t.dialer)
+}
+
+func (t *UDPTransport) Close() error {
 	t.access.Lock()
 	defer t.access.Unlock()
 	close(t.done)
 	t.done = make(chan struct{})
+	return nil
 }
 
 func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {

+ 1 - 1
dns/transport_manager.go

@@ -225,7 +225,7 @@ func (m *TransportManager) Remove(tag string) error {
 		}
 	}
 	if started {
-		transport.Reset()
+		transport.Close()
 	}
 	return nil
 }

+ 6 - 1
experimental/libbox/dns.go

@@ -38,7 +38,12 @@ func newPlatformTransport(iif LocalDNSTransport, tag string, options option.Loca
 	}
 }
 
-func (p *platformTransport) Reset() {
+func (p *platformTransport) Start(stage adapter.StartStage) error {
+	return nil
+}
+
+func (p *platformTransport) Close() error {
+	return nil
 }
 
 func (p *platformTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {

+ 1 - 2
option/rule.go

@@ -125,10 +125,9 @@ func (r *DefaultRule) UnmarshalJSON(data []byte) error {
 	return badjson.UnmarshallExcluded(data, &r.RawDefaultRule, &r.RuleAction)
 }
 
-func (r *DefaultRule) IsValid() bool {
+func (r DefaultRule) IsValid() bool {
 	var defaultValue DefaultRule
 	defaultValue.Invert = r.Invert
-	defaultValue.Action = r.Action
 	return !reflect.DeepEqual(r, defaultValue)
 }
 

+ 0 - 1
option/rule_dns.go

@@ -132,7 +132,6 @@ func (r *DefaultDNSRule) UnmarshalJSONContext(ctx context.Context, data []byte)
 func (r DefaultDNSRule) IsValid() bool {
 	var defaultValue DefaultDNSRule
 	defaultValue.Invert = r.Invert
-	defaultValue.DNSRuleAction = r.DNSRuleAction
 	return !reflect.DeepEqual(r, defaultValue)
 }
 

+ 9 - 0
protocol/direct/outbound.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"net"
 	"net/netip"
+	"reflect"
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
@@ -27,6 +28,7 @@ func RegisterOutbound(registry *outbound.Registry) {
 var (
 	_ N.ParallelDialer             = (*Outbound)(nil)
 	_ dialer.ParallelNetworkDialer = (*Outbound)(nil)
+	_ dialer.DirectDialer          = (*Outbound)(nil)
 )
 
 type Outbound struct {
@@ -37,6 +39,7 @@ type Outbound struct {
 	fallbackDelay       time.Duration
 	overrideOption      int
 	overrideDestination M.Socksaddr
+	isEmpty             bool
 	// loopBack *loopBackDetector
 }
 
@@ -56,6 +59,8 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
 		domainStrategy: C.DomainStrategy(options.DomainStrategy),
 		fallbackDelay:  time.Duration(options.FallbackDelay),
 		dialer:         outboundDialer.(dialer.ParallelInterfaceDialer),
+		//nolint:staticcheck
+		isEmpty: reflect.DeepEqual(options.DialerOptions, option.DialerOptions{UDPFragmentDefault: true}) && options.OverrideAddress == "" && options.OverridePort == 0,
 		// loopBack:       newLoopBackDetector(router),
 	}
 	//nolint:staticcheck
@@ -242,6 +247,10 @@ func (h *Outbound) ListenSerialNetworkPacket(ctx context.Context, destination M.
 	return conn, newDestination, nil
 }
 
+func (h *Outbound) IsEmpty() bool {
+	return h.isEmpty
+}
+
 /*func (h *Outbound) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
 	if h.loopBack.CheckConn(metadata.Source.AddrPort(), M.AddrPortFromNet(conn.LocalAddr())) {
 		return E.New("reject loopback connection to ", metadata.Destination)