浏览代码

Add `preferred_by` route rule item

世界 2 月之前
父节点
当前提交
3ff6df244c

+ 6 - 0
adapter/outbound.go

@@ -2,6 +2,7 @@ package adapter
 
 import (
 	"context"
+	"net/netip"
 
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
@@ -18,6 +19,11 @@ type Outbound interface {
 	N.Dialer
 }
 
+type OutboundWithPreferredRoutes interface {
+	PreferredDomain(domain string) bool
+	PreferredAddress(address netip.Addr) bool
+}
+
 type OutboundRegistry interface {
 	option.OutboundOptionsRegistry
 	CreateOutbound(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) (Outbound, error)

+ 1 - 0
option/rule.go

@@ -103,6 +103,7 @@ type RawDefaultRule struct {
 	InterfaceAddress         *badjson.TypedMap[string, badoption.Listable[badoption.Prefixable]]        `json:"interface_address,omitempty"`
 	NetworkInterfaceAddress  *badjson.TypedMap[InterfaceType, badoption.Listable[badoption.Prefixable]] `json:"network_interface_address,omitempty"`
 	DefaultInterfaceAddress  badoption.Listable[badoption.Prefixable]                                   `json:"default_interface_address,omitempty"`
+	PreferredBy              badoption.Listable[string]                                                 `json:"preferred_by,omitempty"`
 	RuleSet                  badoption.Listable[string]                                                 `json:"rule_set,omitempty"`
 	RuleSetIPCIDRMatchSource bool                                                                       `json:"rule_set_ip_cidr_match_source,omitempty"`
 	Invert                   bool                                                                       `json:"invert,omitempty"`

+ 2 - 13
protocol/tailscale/dns_transport.go

@@ -7,7 +7,6 @@ import (
 	"net/netip"
 	"net/url"
 	"os"
-	"reflect"
 	"strings"
 	"sync"
 
@@ -47,8 +46,6 @@ type DNSTransport struct {
 	acceptDefaultResolvers bool
 	dnsRouter              adapter.DNSRouter
 	endpointManager        adapter.EndpointManager
-	cfg                    *wgcfg.Config
-	dnsCfg                 *nDNS.Config
 	endpoint               *Endpoint
 	routePrefixes          []netip.Prefix
 	routes                 map[string][]adapter.DNSTransport
@@ -83,10 +80,10 @@ func (t *DNSTransport) Start(stage adapter.StartStage) error {
 	if !isTailscale {
 		return E.New("endpoint is not Tailscale: ", t.endpointTag)
 	}
-	if ep.onReconfig != nil {
+	if ep.onReconfigHook != nil {
 		return E.New("only one Tailscale DNS server is allowed for single endpoint")
 	}
-	ep.onReconfig = t.onReconfig
+	ep.onReconfigHook = t.onReconfig
 	t.endpoint = ep
 	return nil
 }
@@ -95,14 +92,6 @@ func (t *DNSTransport) Reset() {
 }
 
 func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *nDNS.Config) {
-	if cfg == nil || dnsCfg == nil {
-		return
-	}
-	if (t.cfg != nil && reflect.DeepEqual(t.cfg, cfg)) && (t.dnsCfg != nil && reflect.DeepEqual(t.dnsCfg, dnsCfg)) {
-		return
-	}
-	t.cfg = cfg
-	t.dnsCfg = dnsCfg
 	err := t.updateDNSServers(routerCfg, dnsCfg)
 	if err != nil {
 		t.logger.Error(E.Cause(err, "update DNS servers"))

+ 62 - 5
protocol/tailscale/endpoint.go

@@ -10,6 +10,7 @@ import (
 	"net/url"
 	"os"
 	"path/filepath"
+	"reflect"
 	"runtime"
 	"strings"
 	"sync/atomic"
@@ -49,8 +50,14 @@ import (
 	"github.com/sagernet/tailscale/version"
 	"github.com/sagernet/tailscale/wgengine"
 	"github.com/sagernet/tailscale/wgengine/filter"
+	"github.com/sagernet/tailscale/wgengine/router"
+	"github.com/sagernet/tailscale/wgengine/wgcfg"
+
+	"go4.org/netipx"
 )
 
+var _ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil)
+
 func init() {
 	version.SetVersion("sing-box " + C.Version)
 }
@@ -70,7 +77,12 @@ type Endpoint struct {
 	server            *tsnet.Server
 	stack             *stack.Stack
 	filter            *atomic.Pointer[filter.Filter]
-	onReconfig        wgengine.ReconfigListener
+	onReconfigHook    wgengine.ReconfigListener
+
+	cfg           *wgcfg.Config
+	dnsCfg        *tsDNS.Config
+	routeDomains  common.TypedValue[map[string]bool]
+	routePrefixes atomic.Pointer[netipx.IPSet]
 
 	acceptRoutes           bool
 	exitNode               string
@@ -216,9 +228,7 @@ func (t *Endpoint) Start(stage adapter.StartStage) error {
 	if err != nil {
 		return err
 	}
-	if t.onReconfig != nil {
-		t.server.ExportLocalBackend().ExportEngine().(wgengine.ExportedUserspaceEngine).SetOnReconfigListener(t.onReconfig)
-	}
+	t.server.ExportLocalBackend().ExportEngine().(wgengine.ExportedUserspaceEngine).SetOnReconfigListener(t.onReconfig)
 
 	ipStack := t.server.ExportNetstack().ExportIPStack()
 	gErr := ipStack.SetSpoofing(tun.DefaultNIC, true)
@@ -254,7 +264,6 @@ func (t *Endpoint) Start(stage adapter.StartStage) error {
 		return E.Cause(err, "update prefs")
 	}
 	t.filter = localBackend.ExportFilter()
-
 	go t.watchState()
 	return nil
 }
@@ -473,10 +482,58 @@ func (t *Endpoint) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
 	t.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
 }
 
+func (t *Endpoint) PreferredDomain(domain string) bool {
+	routeDomains := t.routeDomains.Load()
+	if routeDomains == nil {
+		return false
+	}
+	return routeDomains[strings.ToLower(domain)]
+}
+
+func (t *Endpoint) PreferredAddress(address netip.Addr) bool {
+	routePrefixes := t.routePrefixes.Load()
+	if routePrefixes == nil {
+		return false
+	}
+	return routePrefixes.Contains(address)
+}
+
 func (t *Endpoint) Server() *tsnet.Server {
 	return t.server
 }
 
+func (t *Endpoint) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *tsDNS.Config) {
+	if cfg == nil || dnsCfg == nil {
+		return
+	}
+	if (t.cfg != nil && reflect.DeepEqual(t.cfg, cfg)) && (t.dnsCfg != nil && reflect.DeepEqual(t.dnsCfg, dnsCfg)) {
+		return
+	}
+	t.cfg = cfg
+	t.dnsCfg = dnsCfg
+
+	routeDomains := make(map[string]bool)
+	for fqdn := range dnsCfg.Routes {
+		routeDomains[fqdn.WithoutTrailingDot()] = true
+	}
+	for _, fqdn := range dnsCfg.SearchDomains {
+		routeDomains[fqdn.WithoutTrailingDot()] = true
+	}
+	t.routeDomains.Store(routeDomains)
+
+	var builder netipx.IPSetBuilder
+	for _, peer := range cfg.Peers {
+		for _, allowedIP := range peer.AllowedIPs {
+			builder.AddPrefix(allowedIP)
+		}
+	}
+	t.routePrefixes.Store(common.Must1(builder.IPSet()))
+
+	if t.onReconfigHook != nil {
+		t.onReconfigHook(cfg, routerCfg, dnsCfg)
+	}
+}
+
 func addressFromAddr(destination netip.Addr) tcpip.Address {
 	if destination.Is6() {
 		return tcpip.AddrFrom16(destination.As16())

+ 10 - 0
protocol/wireguard/endpoint.go

@@ -22,6 +22,8 @@ import (
 	"github.com/sagernet/sing/service"
 )
 
+var _ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil)
+
 func RegisterEndpoint(registry *endpoint.Registry) {
 	endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, NewEndpoint)
 }
@@ -210,3 +212,11 @@ func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
 	}
 	return w.endpoint.ListenPacket(ctx, destination)
 }
+
+func (w *Endpoint) PreferredDomain(domain string) bool {
+	return false
+}
+
+func (w *Endpoint) PreferredAddress(address netip.Addr) bool {
+	return w.endpoint.Lookup(address) != nil
+}

+ 10 - 0
protocol/wireguard/outbound.go

@@ -21,6 +21,8 @@ import (
 	"github.com/sagernet/sing/service"
 )
 
+var _ adapter.OutboundWithPreferredRoutes = (*Outbound)(nil)
+
 func RegisterOutbound(registry *outbound.Registry) {
 	outbound.Register[option.LegacyWireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound)
 }
@@ -158,3 +160,11 @@ func (o *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
 	}
 	return o.endpoint.ListenPacket(ctx, destination)
 }
+
+func (o *Outbound) PreferredDomain(domain string) bool {
+	return false
+}
+
+func (o *Outbound) PreferredAddress(address netip.Addr) bool {
+	return o.endpoint.Lookup(address) != nil
+}

+ 6 - 1
route/rule/rule_default.go

@@ -117,7 +117,7 @@ func NewDefaultRule(ctx context.Context, logger log.ContextLogger, options optio
 	if len(options.DomainRegex) > 0 {
 		item, err := NewDomainRegexItem(options.DomainRegex)
 		if err != nil {
-			return nil, E.Cause(err, "domain_regex")
+			return nil, err
 		}
 		rule.destinationAddressItems = append(rule.destinationAddressItems, item)
 		rule.allItems = append(rule.allItems, item)
@@ -261,6 +261,11 @@ func NewDefaultRule(ctx context.Context, logger log.ContextLogger, options optio
 		rule.items = append(rule.items, item)
 		rule.allItems = append(rule.allItems, item)
 	}
+	if len(options.PreferredBy) > 0 {
+		item := NewPreferredByItem(ctx, options.PreferredBy)
+		rule.items = append(rule.items, item)
+		rule.allItems = append(rule.allItems, item)
+	}
 	if len(options.RuleSet) > 0 {
 		var matchSource bool
 		if options.RuleSetIPCIDRMatchSource {

+ 86 - 0
route/rule/rule_item_preferred_by.go

@@ -0,0 +1,86 @@
+package rule
+
+import (
+	"context"
+	"strings"
+
+	"github.com/sagernet/sing-box/adapter"
+	E "github.com/sagernet/sing/common/exceptions"
+	F "github.com/sagernet/sing/common/format"
+	"github.com/sagernet/sing/service"
+)
+
+var _ RuleItem = (*PreferredByItem)(nil)
+
+type PreferredByItem struct {
+	ctx          context.Context
+	outboundTags []string
+	outbounds    []adapter.OutboundWithPreferredRoutes
+}
+
+func NewPreferredByItem(ctx context.Context, outboundTags []string) *PreferredByItem {
+	return &PreferredByItem{
+		ctx:          ctx,
+		outboundTags: outboundTags,
+	}
+}
+
+func (r *PreferredByItem) Start() error {
+	outboundManager := service.FromContext[adapter.OutboundManager](r.ctx)
+	for _, outboundTag := range r.outboundTags {
+		rawOutbound, loaded := outboundManager.Outbound(outboundTag)
+		if !loaded {
+			return E.New("outbound not found: ", outboundTag)
+		}
+		outboundWithPreferredRoutes, withRoutes := rawOutbound.(adapter.OutboundWithPreferredRoutes)
+		if !withRoutes {
+			return E.New("outbound type does not support preferred routes: ", rawOutbound.Type())
+		}
+		r.outbounds = append(r.outbounds, outboundWithPreferredRoutes)
+	}
+	return nil
+}
+
+func (r *PreferredByItem) Match(metadata *adapter.InboundContext) bool {
+	var domainHost string
+	if metadata.Domain != "" {
+		domainHost = metadata.Domain
+	} else {
+		domainHost = metadata.Destination.Fqdn
+	}
+	if domainHost != "" {
+		for _, outbound := range r.outbounds {
+			if outbound.PreferredDomain(domainHost) {
+				return true
+			}
+		}
+	}
+	if metadata.Destination.IsIP() {
+		for _, outbound := range r.outbounds {
+			if outbound.PreferredAddress(metadata.Destination.Addr) {
+				return true
+			}
+		}
+	}
+	if len(metadata.DestinationAddresses) > 0 {
+		for _, address := range metadata.DestinationAddresses {
+			for _, outbound := range r.outbounds {
+				if outbound.PreferredAddress(address) {
+					return true
+				}
+			}
+		}
+	}
+	return false
+}
+
+func (r *PreferredByItem) String() string {
+	description := "preferred_by="
+	pLen := len(r.outboundTags)
+	if pLen == 1 {
+		description += F.ToString(r.outboundTags[0])
+	} else {
+		description += "[" + strings.Join(F.MapToString(r.outboundTags), " ") + "]"
+	}
+	return description
+}

+ 11 - 0
transport/wireguard/endpoint.go

@@ -8,7 +8,9 @@ import (
 	"net"
 	"net/netip"
 	"os"
+	"reflect"
 	"strings"
+	"unsafe"
 
 	"github.com/sagernet/sing/common"
 	E "github.com/sagernet/sing/common/exceptions"
@@ -30,6 +32,7 @@ type Endpoint struct {
 	allowedAddress []netip.Prefix
 	tunDevice      Device
 	device         *device.Device
+	allowedIPs     *device.AllowedIPs
 	pause          pause.Manager
 	pauseCallback  *list.Element[pause.Callback]
 }
@@ -191,6 +194,7 @@ func (e *Endpoint) Start(resolve bool) error {
 	if e.pause != nil {
 		e.pauseCallback = e.pause.RegisterCallback(e.onPauseUpdated)
 	}
+	e.allowedIPs = (*device.AllowedIPs)(unsafe.Pointer(reflect.Indirect(reflect.ValueOf(wgDevice)).FieldByName("allowedips").UnsafeAddr()))
 	return nil
 }
 
@@ -218,6 +222,13 @@ func (e *Endpoint) Close() error {
 	return nil
 }
 
+func (e *Endpoint) Lookup(address netip.Addr) *device.Peer {
+	if e.allowedIPs == nil {
+		return nil
+	}
+	return e.allowedIPs.Lookup(address.AsSlice())
+}
+
 func (e *Endpoint) onPauseUpdated(event int) {
 	switch event {
 	case pause.EventDevicePaused, pause.EventNetworkPause: