Explorar o código

Improve rule actions

世界 hai 11 meses
pai
achega
313be3d7a4

+ 3 - 1
adapter/inbound.go

@@ -57,7 +57,9 @@ type InboundContext struct {
 	// Deprecated
 	InboundOptions            option.InboundOptions
 	UDPDisableDomainUnmapping bool
-	DNSServer                 string
+	UDPConnect                bool
+
+	DNSServer string
 
 	DestinationAddresses []netip.Addr
 	SourceGeoIPCode      string

+ 7 - 6
constant/rule.go

@@ -25,12 +25,13 @@ const (
 )
 
 const (
-	RuleActionTypeRoute     = "route"
-	RuleActionTypeReturn    = "return"
-	RuleActionTypeReject    = "reject"
-	RuleActionTypeHijackDNS = "hijack-dns"
-	RuleActionTypeSniff     = "sniff"
-	RuleActionTypeResolve   = "resolve"
+	RuleActionTypeRoute        = "route"
+	RuleActionTypeRouteOptions = "route-options"
+	RuleActionTypeDirect       = "direct"
+	RuleActionTypeReject       = "reject"
+	RuleActionTypeHijackDNS    = "hijack-dns"
+	RuleActionTypeSniff        = "sniff"
+	RuleActionTypeResolve      = "resolve"
 )
 
 const (

+ 7 - 7
option/rule.go

@@ -109,7 +109,7 @@ type DefaultRule struct {
 	RuleAction
 }
 
-func (r *DefaultRule) MarshalJSON() ([]byte, error) {
+func (r DefaultRule) MarshalJSON() ([]byte, error) {
 	return badjson.MarshallObjects(r.RawDefaultRule, r.RuleAction)
 }
 
@@ -128,27 +128,27 @@ func (r *DefaultRule) IsValid() bool {
 	return !reflect.DeepEqual(r, defaultValue)
 }
 
-type _LogicalRule struct {
+type RawLogicalRule struct {
 	Mode   string `json:"mode"`
 	Rules  []Rule `json:"rules,omitempty"`
 	Invert bool   `json:"invert,omitempty"`
 }
 
 type LogicalRule struct {
-	_LogicalRule
+	RawLogicalRule
 	RuleAction
 }
 
-func (r *LogicalRule) MarshalJSON() ([]byte, error) {
-	return badjson.MarshallObjects(r._LogicalRule, r.RuleAction)
+func (r LogicalRule) MarshalJSON() ([]byte, error) {
+	return badjson.MarshallObjects(r.RawLogicalRule, r.RuleAction)
 }
 
 func (r *LogicalRule) UnmarshalJSON(data []byte) error {
-	err := json.Unmarshal(data, &r._LogicalRule)
+	err := json.Unmarshal(data, &r.RawLogicalRule)
 	if err != nil {
 		return err
 	}
-	return badjson.UnmarshallExcluded(data, &r._LogicalRule, &r.RuleAction)
+	return badjson.UnmarshallExcluded(data, &r.RawLogicalRule, &r.RuleAction)
 }
 
 func (r *LogicalRule) IsValid() bool {

+ 157 - 25
option/rule_action.go

@@ -1,30 +1,41 @@
 package option
 
 import (
+	"fmt"
+	"time"
+
 	C "github.com/sagernet/sing-box/constant"
+	dns "github.com/sagernet/sing-dns"
 	E "github.com/sagernet/sing/common/exceptions"
 	"github.com/sagernet/sing/common/json"
 	"github.com/sagernet/sing/common/json/badjson"
 )
 
 type _RuleAction struct {
-	Action         string              `json:"action,omitempty"`
-	RouteOptions   RouteActionOptions  `json:"-"`
-	RejectOptions  RejectActionOptions `json:"-"`
-	SniffOptions   RouteActionSniff    `json:"-"`
-	ResolveOptions RouteActionResolve  `json:"-"`
+	Action              string                    `json:"action,omitempty"`
+	RouteOptions        RouteActionOptions        `json:"-"`
+	RouteOptionsOptions RouteOptionsActionOptions `json:"-"`
+	DirectOptions       DirectActionOptions       `json:"-"`
+	RejectOptions       RejectActionOptions       `json:"-"`
+	SniffOptions        RouteActionSniff          `json:"-"`
+	ResolveOptions      RouteActionResolve        `json:"-"`
 }
 
 type RuleAction _RuleAction
 
 func (r RuleAction) MarshalJSON() ([]byte, error) {
+	if r.Action == "" {
+		return json.Marshal(struct{}{})
+	}
 	var v any
 	switch r.Action {
 	case C.RuleActionTypeRoute:
 		r.Action = ""
 		v = r.RouteOptions
-	case C.RuleActionTypeReturn:
-		v = nil
+	case C.RuleActionTypeRouteOptions:
+		v = r.RouteOptionsOptions
+	case C.RuleActionTypeDirect:
+		v = r.DirectOptions
 	case C.RuleActionTypeReject:
 		v = r.RejectOptions
 	case C.RuleActionTypeHijackDNS:
@@ -52,8 +63,10 @@ func (r *RuleAction) UnmarshalJSON(data []byte) error {
 	case "", C.RuleActionTypeRoute:
 		r.Action = C.RuleActionTypeRoute
 		v = &r.RouteOptions
-	case C.RuleActionTypeReturn:
-		v = nil
+	case C.RuleActionTypeRouteOptions:
+		v = &r.RouteOptionsOptions
+	case C.RuleActionTypeDirect:
+		v = &r.DirectOptions
 	case C.RuleActionTypeReject:
 		v = &r.RejectOptions
 	case C.RuleActionTypeHijackDNS:
@@ -73,29 +86,30 @@ func (r *RuleAction) UnmarshalJSON(data []byte) error {
 }
 
 type _DNSRuleAction struct {
-	Action        string                `json:"action,omitempty"`
-	RouteOptions  DNSRouteActionOptions `json:"-"`
-	RejectOptions RejectActionOptions   `json:"-"`
+	Action              string                       `json:"action,omitempty"`
+	RouteOptions        DNSRouteActionOptions        `json:"-"`
+	RouteOptionsOptions DNSRouteOptionsActionOptions `json:"-"`
+	RejectOptions       RejectActionOptions          `json:"-"`
 }
 
 type DNSRuleAction _DNSRuleAction
 
 func (r DNSRuleAction) MarshalJSON() ([]byte, error) {
+	if r.Action == "" {
+		return json.Marshal(struct{}{})
+	}
 	var v any
 	switch r.Action {
 	case C.RuleActionTypeRoute:
 		r.Action = ""
 		v = r.RouteOptions
-	case C.RuleActionTypeReturn:
-		v = nil
+	case C.RuleActionTypeRouteOptions:
+		v = r.RouteOptionsOptions
 	case C.RuleActionTypeReject:
 		v = r.RejectOptions
 	default:
 		return nil, E.New("unknown DNS rule action: " + r.Action)
 	}
-	if v == nil {
-		return badjson.MarshallObjects((_DNSRuleAction)(r))
-	}
 	return badjson.MarshallObjects((_DNSRuleAction)(r), v)
 }
 
@@ -109,8 +123,8 @@ func (r *DNSRuleAction) UnmarshalJSON(data []byte) error {
 	case "", C.RuleActionTypeRoute:
 		r.Action = C.RuleActionTypeRoute
 		v = &r.RouteOptions
-	case C.RuleActionTypeReturn:
-		v = nil
+	case C.RuleActionTypeRouteOptions:
+		v = &r.RouteOptionsOptions
 	case C.RuleActionTypeReject:
 		v = &r.RejectOptions
 	default:
@@ -123,18 +137,136 @@ func (r *DNSRuleAction) UnmarshalJSON(data []byte) error {
 	return badjson.UnmarshallExcluded(data, (*_DNSRuleAction)(r), v)
 }
 
-type RouteActionOptions struct {
-	Outbound                  string `json:"outbound"`
-	UDPDisableDomainUnmapping bool   `json:"udp_disable_domain_unmapping,omitempty"`
+type _RouteActionOptions struct {
+	Outbound string `json:"outbound,omitempty"`
+}
+
+type RouteActionOptions _RouteActionOptions
+
+func (r *RouteActionOptions) UnmarshalJSON(data []byte) error {
+	err := json.Unmarshal(data, (*_RouteActionOptions)(r))
+	if err != nil {
+		return err
+	}
+	if r.Outbound == "" {
+		return E.New("missing outbound")
+	}
+	return nil
+}
+
+type _RouteOptionsActionOptions struct {
+	UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"`
+	UDPConnect                bool `json:"udp_connect,omitempty"`
+}
+
+type RouteOptionsActionOptions _RouteOptionsActionOptions
+
+func (r *RouteOptionsActionOptions) UnmarshalJSON(data []byte) error {
+	err := json.Unmarshal(data, (*_RouteOptionsActionOptions)(r))
+	if err != nil {
+		return err
+	}
+	if *r == (RouteOptionsActionOptions{}) {
+		return E.New("empty route option action")
+	}
+	return nil
+}
+
+type _DNSRouteActionOptions struct {
+	Server string `json:"server,omitempty"`
+	// Deprecated: Use DNSRouteOptionsActionOptions instead.
+	DisableCache bool `json:"disable_cache,omitempty"`
+	// Deprecated: Use DNSRouteOptionsActionOptions instead.
+	RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"`
+	// Deprecated: Use DNSRouteOptionsActionOptions instead.
+	ClientSubnet *AddrPrefix `json:"client_subnet,omitempty"`
+}
+
+type DNSRouteActionOptions _DNSRouteActionOptions
+
+func (r *DNSRouteActionOptions) UnmarshalJSON(data []byte) error {
+	err := json.Unmarshal(data, (*_DNSRouteActionOptions)(r))
+	if err != nil {
+		return err
+	}
+	if r.Server == "" {
+		return E.New("missing server")
+	}
+	return nil
 }
 
-type DNSRouteActionOptions struct {
-	Server       string      `json:"server"`
+type _DNSRouteOptionsActionOptions struct {
 	DisableCache bool        `json:"disable_cache,omitempty"`
 	RewriteTTL   *uint32     `json:"rewrite_ttl,omitempty"`
 	ClientSubnet *AddrPrefix `json:"client_subnet,omitempty"`
 }
 
+type DNSRouteOptionsActionOptions _DNSRouteOptionsActionOptions
+
+func (r *DNSRouteOptionsActionOptions) UnmarshalJSON(data []byte) error {
+	err := json.Unmarshal(data, (*_DNSRouteOptionsActionOptions)(r))
+	if err != nil {
+		return err
+	}
+	if *r == (DNSRouteOptionsActionOptions{}) {
+		return E.New("empty DNS route option action")
+	}
+	return nil
+}
+
+type _DirectActionOptions DialerOptions
+
+type DirectActionOptions _DirectActionOptions
+
+func (d DirectActionOptions) Descriptions() []string {
+	var descriptions []string
+	if d.BindInterface != "" {
+		descriptions = append(descriptions, "bind_interface="+d.BindInterface)
+	}
+	if d.Inet4BindAddress != nil {
+		descriptions = append(descriptions, "inet4_bind_address="+d.Inet4BindAddress.Build().String())
+	}
+	if d.Inet6BindAddress != nil {
+		descriptions = append(descriptions, "inet6_bind_address="+d.Inet6BindAddress.Build().String())
+	}
+	if d.RoutingMark != 0 {
+		descriptions = append(descriptions, "routing_mark="+fmt.Sprintf("0x%x", d.RoutingMark))
+	}
+	if d.ReuseAddr {
+		descriptions = append(descriptions, "reuse_addr")
+	}
+	if d.ConnectTimeout != 0 {
+		descriptions = append(descriptions, "connect_timeout="+time.Duration(d.ConnectTimeout).String())
+	}
+	if d.TCPFastOpen {
+		descriptions = append(descriptions, "tcp_fast_open")
+	}
+	if d.TCPMultiPath {
+		descriptions = append(descriptions, "tcp_multi_path")
+	}
+	if d.UDPFragment != nil {
+		descriptions = append(descriptions, "udp_fragment="+fmt.Sprint(*d.UDPFragment))
+	}
+	if d.DomainStrategy != DomainStrategy(dns.DomainStrategyAsIS) {
+		descriptions = append(descriptions, "domain_strategy="+d.DomainStrategy.String())
+	}
+	if d.FallbackDelay != 0 {
+		descriptions = append(descriptions, "fallback_delay="+time.Duration(d.FallbackDelay).String())
+	}
+	return descriptions
+}
+
+func (d *DirectActionOptions) UnmarshalJSON(data []byte) error {
+	err := json.Unmarshal(data, (*_DirectActionOptions)(d))
+	if err != nil {
+		return err
+	}
+	if d.Detour != "" {
+		return E.New("detour is not available in the current context")
+	}
+	return nil
+}
+
 type _RejectActionOptions struct {
 	Method string `json:"method,omitempty"`
 	NoDrop bool   `json:"no_drop,omitempty"`
@@ -155,7 +287,7 @@ func (r *RejectActionOptions) UnmarshalJSON(bytes []byte) error {
 		return E.New("unknown reject method: " + r.Method)
 	}
 	if r.Method == C.RuleActionRejectMethodDrop && r.NoDrop {
-		return E.New("no_drop is not allowed when method is drop")
+		return E.New("no_drop is not available in current context")
 	}
 	return nil
 }

+ 8 - 8
option/rule_dns.go

@@ -111,7 +111,7 @@ type DefaultDNSRule struct {
 	DNSRuleAction
 }
 
-func (r *DefaultDNSRule) MarshalJSON() ([]byte, error) {
+func (r DefaultDNSRule) MarshalJSON() ([]byte, error) {
 	return badjson.MarshallObjects(r.RawDefaultDNSRule, r.DNSRuleAction)
 }
 
@@ -123,34 +123,34 @@ func (r *DefaultDNSRule) UnmarshalJSON(data []byte) error {
 	return badjson.UnmarshallExcluded(data, &r.RawDefaultDNSRule, &r.DNSRuleAction)
 }
 
-func (r *DefaultDNSRule) IsValid() bool {
+func (r DefaultDNSRule) IsValid() bool {
 	var defaultValue DefaultDNSRule
 	defaultValue.Invert = r.Invert
 	defaultValue.DNSRuleAction = r.DNSRuleAction
 	return !reflect.DeepEqual(r, defaultValue)
 }
 
-type _LogicalDNSRule struct {
+type RawLogicalDNSRule struct {
 	Mode   string    `json:"mode"`
 	Rules  []DNSRule `json:"rules,omitempty"`
 	Invert bool      `json:"invert,omitempty"`
 }
 
 type LogicalDNSRule struct {
-	_LogicalDNSRule
+	RawLogicalDNSRule
 	DNSRuleAction
 }
 
-func (r *LogicalDNSRule) MarshalJSON() ([]byte, error) {
-	return badjson.MarshallObjects(r._LogicalDNSRule, r.DNSRuleAction)
+func (r LogicalDNSRule) MarshalJSON() ([]byte, error) {
+	return badjson.MarshallObjects(r.RawLogicalDNSRule, r.DNSRuleAction)
 }
 
 func (r *LogicalDNSRule) UnmarshalJSON(data []byte) error {
-	err := json.Unmarshal(data, &r._LogicalDNSRule)
+	err := json.Unmarshal(data, &r.RawLogicalDNSRule)
 	if err != nil {
 		return err
 	}
-	return badjson.UnmarshallExcluded(data, &r._LogicalDNSRule, &r.DNSRuleAction)
+	return badjson.UnmarshallExcluded(data, &r.RawLogicalDNSRule, &r.DNSRuleAction)
 }
 
 func (r *LogicalDNSRule) IsValid() bool {

+ 1 - 0
protocol/dns/handle.go

@@ -43,6 +43,7 @@ func HandleStreamDNSRequest(ctx context.Context, router adapter.Router, conn net
 	go func() error {
 		response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
 		if err != nil {
+			conn.Close()
 			return err
 		}
 		responseBuffer := buf.NewPacket()

+ 72 - 58
route/route.go

@@ -87,23 +87,34 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
 	if deadline.NeedAdditionalReadDeadline(conn) {
 		conn = deadline.NewConn(conn)
 	}
-	selectedRule, _, buffers, _, err := r.matchRule(ctx, &metadata, false, conn, nil, -1)
+	selectedRule, _, buffers, _, err := r.matchRule(ctx, &metadata, false, conn, nil)
 	if err != nil {
 		return err
 	}
-	var selectedOutbound adapter.Outbound
-	var selectReturn bool
+	var (
+		// selectedOutbound adapter.Outbound
+		selectedDialer      N.Dialer
+		selectedTag         string
+		selectedDescription string
+	)
 	if selectedRule != nil {
 		switch action := selectedRule.Action().(type) {
 		case *rule.RuleActionRoute:
-			var loaded bool
-			selectedOutbound, loaded = r.Outbound(action.Outbound)
+			selectedOutbound, loaded := r.Outbound(action.Outbound)
 			if !loaded {
 				buf.ReleaseMulti(buffers)
 				return E.New("outbound not found: ", action.Outbound)
 			}
-		case *rule.RuleActionReturn:
-			selectReturn = true
+			if !common.Contains(selectedOutbound.Network(), N.NetworkTCP) {
+				buf.ReleaseMulti(buffers)
+				return E.New("TCP is not supported by outbound: ", selectedOutbound.Tag())
+			}
+			selectedDialer = selectedOutbound
+			selectedTag = selectedOutbound.Tag()
+			selectedDescription = F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
+		case *rule.RuleActionDirect:
+			selectedDialer = action.Dialer
+			selectedDescription = action.String()
 		case *rule.RuleActionReject:
 			buf.ReleaseMulti(buffers)
 			N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx))
@@ -116,17 +127,16 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
 			return nil
 		}
 	}
-	if selectedRule == nil || selectReturn {
+	if selectedRule == nil {
 		if r.defaultOutboundForConnection == nil {
 			buf.ReleaseMulti(buffers)
 			return E.New("missing default outbound with TCP support")
 		}
-		selectedOutbound = r.defaultOutboundForConnection
-	}
-	if !common.Contains(selectedOutbound.Network(), N.NetworkTCP) {
-		buf.ReleaseMulti(buffers)
-		return E.New("TCP is not supported by outbound: ", selectedOutbound.Tag())
+		selectedDialer = r.defaultOutboundForConnection
+		selectedTag = r.defaultOutboundForConnection.Tag()
+		selectedDescription = F.ToString("outbound/", r.defaultOutboundForConnection.Type(), "[", r.defaultOutboundForConnection.Tag(), "]")
 	}
+
 	for _, buffer := range buffers {
 		conn = bufio.NewCachedConn(conn, buffer)
 	}
@@ -137,10 +147,10 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
 	}
 	if r.v2rayServer != nil {
 		if statsService := r.v2rayServer.StatsService(); statsService != nil {
-			conn = statsService.RoutedConnection(metadata.Inbound, selectedOutbound.Tag(), metadata.User, conn)
+			conn = statsService.RoutedConnection(metadata.Inbound, selectedTag, metadata.User, conn)
 		}
 	}
-	legacyOutbound, isLegacy := selectedOutbound.(adapter.ConnectionHandler)
+	legacyOutbound, isLegacy := selectedDialer.(adapter.ConnectionHandler)
 	if isLegacy {
 		err = legacyOutbound.NewConnection(ctx, conn, metadata)
 		if err != nil {
@@ -148,7 +158,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
 			if onClose != nil {
 				onClose(err)
 			}
-			return E.Cause(err, "outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
+			return E.Cause(err, selectedDescription)
 		} else {
 			if onClose != nil {
 				onClose(nil)
@@ -157,13 +167,13 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
 		return nil
 	}
 	// TODO
-	err = outbound.NewConnection(ctx, selectedOutbound, conn, metadata)
+	err = outbound.NewConnection(ctx, selectedDialer, conn, metadata)
 	if err != nil {
 		conn.Close()
 		if onClose != nil {
 			onClose(err)
 		}
-		return E.Cause(err, "outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
+		return E.Cause(err, selectedDescription)
 	} else {
 		if onClose != nil {
 			onClose(nil)
@@ -231,24 +241,34 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
 		conn = deadline.NewPacketConn(bufio.NewNetPacketConn(conn))
 	}*/
 
-	selectedRule, _, _, packetBuffers, err := r.matchRule(ctx, &metadata, false, nil, conn, -1)
+	selectedRule, _, _, packetBuffers, err := r.matchRule(ctx, &metadata, false, nil, conn)
 	if err != nil {
 		return err
 	}
-	var selectedOutbound adapter.Outbound
+	var (
+		selectedDialer      N.Dialer
+		selectedTag         string
+		selectedDescription string
+	)
 	var selectReturn bool
 	if selectedRule != nil {
 		switch action := selectedRule.Action().(type) {
 		case *rule.RuleActionRoute:
-			var loaded bool
-			selectedOutbound, loaded = r.Outbound(action.Outbound)
+			selectedOutbound, loaded := r.Outbound(action.Outbound)
 			if !loaded {
 				N.ReleaseMultiPacketBuffer(packetBuffers)
 				return E.New("outbound not found: ", action.Outbound)
 			}
-			metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping
-		case *rule.RuleActionReturn:
-			selectReturn = true
+			if !common.Contains(selectedOutbound.Network(), N.NetworkUDP) {
+				N.ReleaseMultiPacketBuffer(packetBuffers)
+				return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag())
+			}
+			selectedDialer = selectedOutbound
+			selectedTag = selectedOutbound.Tag()
+			selectedDescription = F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
+		case *rule.RuleActionDirect:
+			selectedDialer = action.Dialer
+			selectedDescription = action.String()
 		case *rule.RuleActionReject:
 			N.ReleaseMultiPacketBuffer(packetBuffers)
 			N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx))
@@ -263,11 +283,9 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
 			N.ReleaseMultiPacketBuffer(packetBuffers)
 			return E.New("missing default outbound with UDP support")
 		}
-		selectedOutbound = r.defaultOutboundForPacketConnection
-	}
-	if !common.Contains(selectedOutbound.Network(), N.NetworkUDP) {
-		N.ReleaseMultiPacketBuffer(packetBuffers)
-		return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag())
+		selectedDialer = r.defaultOutboundForPacketConnection
+		selectedTag = r.defaultOutboundForPacketConnection.Tag()
+		selectedDescription = F.ToString("outbound/", r.defaultOutboundForPacketConnection.Type(), "[", r.defaultOutboundForPacketConnection.Tag(), "]")
 	}
 	for _, buffer := range packetBuffers {
 		conn = bufio.NewCachedPacketConn(conn, buffer.Buffer, buffer.Destination)
@@ -280,32 +298,32 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
 	}
 	if r.v2rayServer != nil {
 		if statsService := r.v2rayServer.StatsService(); statsService != nil {
-			conn = statsService.RoutedPacketConnection(metadata.Inbound, selectedOutbound.Tag(), metadata.User, conn)
+			conn = statsService.RoutedPacketConnection(metadata.Inbound, selectedTag, metadata.User, conn)
 		}
 	}
 	if metadata.FakeIP {
 		conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination)
 	}
-	legacyOutbound, isLegacy := selectedOutbound.(adapter.PacketConnectionHandler)
+	legacyOutbound, isLegacy := selectedDialer.(adapter.PacketConnectionHandler)
 	if isLegacy {
 		err = legacyOutbound.NewPacketConnection(ctx, conn, metadata)
 		N.CloseOnHandshakeFailure(conn, onClose, err)
 		if err != nil {
-			return E.Cause(err, "outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
+			return E.Cause(err, selectedDescription)
 		}
 		return nil
 	}
 	// TODO
-	err = outbound.NewPacketConnection(ctx, selectedOutbound, conn, metadata)
+	err = outbound.NewPacketConnection(ctx, selectedDialer, conn, metadata)
 	N.CloseOnHandshakeFailure(conn, onClose, err)
 	if err != nil {
-		return E.Cause(err, "outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
+		return E.Cause(err, selectedDescription)
 	}
 	return nil
 }
 
 func (r *Router) PreMatch(metadata adapter.InboundContext) error {
-	selectedRule, _, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil, -1)
+	selectedRule, _, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil)
 	if err != nil {
 		return err
 	}
@@ -321,7 +339,7 @@ func (r *Router) PreMatch(metadata adapter.InboundContext) error {
 
 func (r *Router) matchRule(
 	ctx context.Context, metadata *adapter.InboundContext, preMatch bool,
-	inputConn net.Conn, inputPacketConn N.PacketConn, ruleIndex int,
+	inputConn net.Conn, inputPacketConn N.PacketConn,
 ) (
 	selectedRule adapter.Rule, selectedRuleIndex int,
 	buffers []*buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error,
@@ -416,24 +434,10 @@ func (r *Router) matchRule(
 	}
 
 match:
-	for ruleIndex < len(r.rules) {
-		rules := r.rules
-		if ruleIndex != -1 {
-			rules = rules[ruleIndex+1:]
-		}
-		var (
-			currentRule      adapter.Rule
-			currentRuleIndex int
-			matched          bool
-		)
-		for currentRuleIndex, currentRule = range rules {
-			if currentRule.Match(metadata) {
-				matched = true
-				break
-			}
-		}
-		if !matched {
-			break
+	for currentRuleIndex, currentRule := range r.rules {
+		metadata.ResetRuleCache()
+		if !currentRule.Match(metadata) {
+			continue
 		}
 		if !preMatch {
 			ruleDescription := currentRule.String()
@@ -444,7 +448,7 @@ match:
 			}
 		} else {
 			switch currentRule.Action().Type() {
-			case C.RuleActionTypeReject, C.RuleActionTypeResolve:
+			case C.RuleActionTypeReject:
 				ruleDescription := currentRule.String()
 				if ruleDescription != "" {
 					r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
@@ -454,6 +458,12 @@ match:
 			}
 		}
 		switch action := currentRule.Action().(type) {
+		case *rule.RuleActionRoute:
+			metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping
+			metadata.UDPConnect = action.UDPConnect
+		case *rule.RuleActionRouteOptions:
+			metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping
+			metadata.UDPConnect = action.UDPConnect
 		case *rule.RuleActionSniff:
 			if !preMatch {
 				newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn)
@@ -476,12 +486,16 @@ match:
 			if fatalErr != nil {
 				return
 			}
-		default:
+		}
+		actionType := currentRule.Action().Type()
+		if actionType == C.RuleActionTypeRoute ||
+			actionType == C.RuleActionTypeReject ||
+			actionType == C.RuleActionTypeHijackDNS ||
+			(actionType == C.RuleActionTypeSniff && preMatch) {
 			selectedRule = currentRule
 			selectedRuleIndex = currentRuleIndex
 			break match
 		}
-		ruleIndex = currentRuleIndex
 	}
 	if !preMatch && metadata.Destination.Addr.IsUnspecified() {
 		newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{}, inputConn, inputPacketConn)

+ 82 - 25
route/route_dns.go

@@ -8,8 +8,10 @@ import (
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
+	C "github.com/sagernet/sing-box/constant"
 	R "github.com/sagernet/sing-box/route/rule"
 	"github.com/sagernet/sing-dns"
+	tun "github.com/sagernet/sing-tun"
 	"github.com/sagernet/sing/common/cache"
 	E "github.com/sagernet/sing/common/exceptions"
 	F "github.com/sagernet/sing/common/format"
@@ -48,38 +50,63 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int,
 		if ruleIndex != -1 {
 			dnsRules = dnsRules[ruleIndex+1:]
 		}
-		for currentRuleIndex, rule := range dnsRules {
-			if rule.WithAddressLimit() && !isAddressQuery {
+		for currentRuleIndex, currentRule := range dnsRules {
+			if currentRule.WithAddressLimit() && !isAddressQuery {
 				continue
 			}
 			metadata.ResetRuleCache()
-			if rule.Match(metadata) {
+			if currentRule.Match(metadata) {
 				displayRuleIndex := currentRuleIndex
 				if displayRuleIndex != -1 {
 					displayRuleIndex += displayRuleIndex + 1
 				}
-				if routeAction, isRoute := rule.Action().(*R.RuleActionDNSRoute); isRoute {
-					transport, loaded := r.transportMap[routeAction.Server]
+				ruleDescription := currentRule.String()
+				if ruleDescription != "" {
+					r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] ", currentRule, " => ", currentRule.Action())
+				} else {
+					r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
+				}
+				switch action := currentRule.Action().(type) {
+				case *R.RuleActionDNSRoute:
+					transport, loaded := r.transportMap[action.Server]
 					if !loaded {
-						r.dnsLogger.ErrorContext(ctx, "transport not found: ", routeAction.Server)
+						r.dnsLogger.ErrorContext(ctx, "transport not found: ", action.Server)
 						continue
 					}
 					_, isFakeIP := transport.(adapter.FakeIPTransport)
 					if isFakeIP && !allowFakeIP {
 						continue
 					}
-					options.DisableCache = isFakeIP || routeAction.DisableCache
-					options.RewriteTTL = routeAction.RewriteTTL
-					options.ClientSubnet = routeAction.ClientSubnet
+					if isFakeIP || action.DisableCache {
+						options.DisableCache = true
+					}
+					if action.RewriteTTL != nil {
+						options.RewriteTTL = action.RewriteTTL
+					}
+					if action.ClientSubnet.IsValid() {
+						options.ClientSubnet = action.ClientSubnet
+					}
 					if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded {
 						options.Strategy = domainStrategy
 					} else {
 						options.Strategy = r.defaultDomainStrategy
 					}
-					r.dnsLogger.DebugContext(ctx, "match[", displayRuleIndex, "] ", rule.String(), " => ", rule.Action())
-					return transport, options, rule, currentRuleIndex
-				} else {
-					return nil, options, rule, currentRuleIndex
+					r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
+					return transport, options, currentRule, currentRuleIndex
+				case *R.RuleActionDNSRouteOptions:
+					if action.DisableCache {
+						options.DisableCache = true
+					}
+					if action.RewriteTTL != nil {
+						options.RewriteTTL = action.RewriteTTL
+					}
+					if action.ClientSubnet.IsValid() {
+						options.ClientSubnet = action.ClientSubnet
+					}
+					r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
+				case *R.RuleActionReject:
+					r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
+					return nil, options, currentRule, currentRuleIndex
 				}
 			}
 		}
@@ -93,9 +120,19 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int,
 }
 
 func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
-	if len(message.Question) > 0 {
-		r.dnsLogger.DebugContext(ctx, "exchange ", formatQuestion(message.Question[0].String()))
+	if len(message.Question) != 1 {
+		r.dnsLogger.WarnContext(ctx, "bad question size: ", len(message.Question))
+		responseMessage := mDNS.Msg{
+			MsgHdr: mDNS.MsgHdr{
+				Id:       message.Id,
+				Response: true,
+				Rcode:    mDNS.RcodeFormatError,
+			},
+			Question: message.Question,
+		}
+		return &responseMessage, nil
 	}
+	r.dnsLogger.DebugContext(ctx, "exchange ", formatQuestion(message.Question[0].String()))
 	var (
 		response  *mDNS.Msg
 		cached    bool
@@ -107,16 +144,14 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, er
 		var metadata *adapter.InboundContext
 		ctx, metadata = adapter.ExtendContext(ctx)
 		metadata.Destination = M.Socksaddr{}
-		if len(message.Question) > 0 {
-			metadata.QueryType = message.Question[0].Qtype
-			switch metadata.QueryType {
-			case mDNS.TypeA:
-				metadata.IPVersion = 4
-			case mDNS.TypeAAAA:
-				metadata.IPVersion = 6
-			}
-			metadata.Domain = fqdnToDomain(message.Question[0].Name)
+		metadata.QueryType = message.Question[0].Qtype
+		switch metadata.QueryType {
+		case mDNS.TypeA:
+			metadata.IPVersion = 4
+		case mDNS.TypeAAAA:
+			metadata.IPVersion = 6
 		}
+		metadata.Domain = fqdnToDomain(message.Question[0].Name)
 		var (
 			options   dns.QueryOptions
 			rule      adapter.DNSRule
@@ -127,6 +162,17 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, er
 			dnsCtx := adapter.OverrideContext(ctx)
 			var addressLimit bool
 			transport, options, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message))
+			if rule != nil {
+				switch action := rule.Action().(type) {
+				case *R.RuleActionReject:
+					switch action.Method {
+					case C.RuleActionRejectMethodDefault:
+						return dns.FixedResponse(message.Id, message.Question[0], nil, 0), nil
+					case C.RuleActionRejectMethodDrop:
+						return nil, tun.ErrDrop
+					}
+				}
+			}
 			if rule != nil && rule.WithAddressLimit() {
 				addressLimit = true
 				response, err = r.dnsClient.ExchangeWithResponseCheck(dnsCtx, transport, message, options, func(response *mDNS.Msg) bool {
@@ -164,7 +210,7 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, er
 	if err != nil {
 		return nil, err
 	}
-	if r.dnsReverseMapping != nil && len(message.Question) > 0 && response != nil && len(response.Answer) > 0 {
+	if r.dnsReverseMapping != nil && response != nil && len(response.Answer) > 0 {
 		if _, isFakeIP := transport.(adapter.FakeIPTransport); !isFakeIP {
 			for _, answer := range response.Answer {
 				switch record := answer.(type) {
@@ -238,6 +284,17 @@ func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainS
 			if strategy != dns.DomainStrategyAsIS {
 				options.Strategy = strategy
 			}
+			if rule != nil {
+				switch action := rule.Action().(type) {
+				case *R.RuleActionReject:
+					switch action.Method {
+					case C.RuleActionRejectMethodDefault:
+						return nil, nil
+					case C.RuleActionRejectMethodDrop:
+						return nil, tun.ErrDrop
+					}
+				}
+			}
 			if rule != nil && rule.WithAddressLimit() {
 				addressLimit = true
 				responseAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, options, func(responseAddrs []netip.Addr) bool {

+ 94 - 17
route/rule/rule_action.go

@@ -5,9 +5,11 @@ import (
 	"net/netip"
 	"strings"
 	"sync"
+	"syscall"
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/dialer"
 	"github.com/sagernet/sing-box/common/sniff"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/option"
@@ -17,19 +19,42 @@ import (
 	E "github.com/sagernet/sing/common/exceptions"
 	F "github.com/sagernet/sing/common/format"
 	"github.com/sagernet/sing/common/logger"
-
-	"golang.org/x/sys/unix"
+	N "github.com/sagernet/sing/common/network"
 )
 
-func NewRuleAction(logger logger.ContextLogger, action option.RuleAction) (adapter.RuleAction, error) {
+func NewRuleAction(router adapter.Router, logger logger.ContextLogger, action option.RuleAction) (adapter.RuleAction, error) {
 	switch action.Action {
+	case "":
+		return nil, nil
 	case C.RuleActionTypeRoute:
 		return &RuleActionRoute{
-			Outbound:                  action.RouteOptions.Outbound,
-			UDPDisableDomainUnmapping: action.RouteOptions.UDPDisableDomainUnmapping,
+			Outbound: action.RouteOptions.Outbound,
+		}, nil
+	case C.RuleActionTypeRouteOptions:
+		return &RuleActionRouteOptions{
+			UDPDisableDomainUnmapping: action.RouteOptionsOptions.UDPDisableDomainUnmapping,
+			UDPConnect:                action.RouteOptionsOptions.UDPConnect,
+		}, nil
+	case C.RuleActionTypeDirect:
+		directDialer, err := dialer.New(router, option.DialerOptions(action.DirectOptions))
+		if err != nil {
+			return nil, err
+		}
+		var description string
+		descriptions := action.DirectOptions.Descriptions()
+		switch len(descriptions) {
+		case 0:
+		case 1:
+			description = F.ToString("(", descriptions[0], ")")
+		case 2:
+			description = F.ToString("(", descriptions[0], ",", descriptions[1], ")")
+		default:
+			description = F.ToString("(", descriptions[0], ",", descriptions[1], ",...)")
+		}
+		return &RuleActionDirect{
+			Dialer:      directDialer,
+			description: description,
 		}, nil
-	case C.RuleActionTypeReturn:
-		return &RuleActionReturn{}, nil
 	case C.RuleActionTypeReject:
 		return &RuleActionReject{
 			Method: action.RejectOptions.Method,
@@ -56,6 +81,8 @@ func NewRuleAction(logger logger.ContextLogger, action option.RuleAction) (adapt
 
 func NewDNSRuleAction(logger logger.ContextLogger, action option.DNSRuleAction) adapter.RuleAction {
 	switch action.Action {
+	case "":
+		return nil
 	case C.RuleActionTypeRoute:
 		return &RuleActionDNSRoute{
 			Server:       action.RouteOptions.Server,
@@ -63,8 +90,12 @@ func NewDNSRuleAction(logger logger.ContextLogger, action option.DNSRuleAction)
 			RewriteTTL:   action.RouteOptions.RewriteTTL,
 			ClientSubnet: action.RouteOptions.ClientSubnet.Build(),
 		}
-	case C.RuleActionTypeReturn:
-		return &RuleActionReturn{}
+	case C.RuleActionTypeRouteOptions:
+		return &RuleActionDNSRouteOptions{
+			DisableCache: action.RouteOptionsOptions.DisableCache,
+			RewriteTTL:   action.RouteOptionsOptions.RewriteTTL,
+			ClientSubnet: action.RouteOptionsOptions.ClientSubnet.Build(),
+		}
 	case C.RuleActionTypeReject:
 		return &RuleActionReject{
 			Method: action.RejectOptions.Method,
@@ -77,8 +108,7 @@ func NewDNSRuleAction(logger logger.ContextLogger, action option.DNSRuleAction)
 }
 
 type RuleActionRoute struct {
-	Outbound                  string
-	UDPDisableDomainUnmapping bool
+	Outbound string
 }
 
 func (r *RuleActionRoute) Type() string {
@@ -89,6 +119,26 @@ func (r *RuleActionRoute) String() string {
 	return F.ToString("route(", r.Outbound, ")")
 }
 
+type RuleActionRouteOptions struct {
+	UDPDisableDomainUnmapping bool
+	UDPConnect                bool
+}
+
+func (r *RuleActionRouteOptions) Type() string {
+	return C.RuleActionTypeRouteOptions
+}
+
+func (r *RuleActionRouteOptions) String() string {
+	var descriptions []string
+	if r.UDPDisableDomainUnmapping {
+		descriptions = append(descriptions, "udp-disable-domain-unmapping")
+	}
+	if r.UDPConnect {
+		descriptions = append(descriptions, "udp-connect")
+	}
+	return F.ToString("route-options(", strings.Join(descriptions, ","), ")")
+}
+
 type RuleActionDNSRoute struct {
 	Server       string
 	DisableCache bool
@@ -104,14 +154,41 @@ func (r *RuleActionDNSRoute) String() string {
 	return F.ToString("route(", r.Server, ")")
 }
 
-type RuleActionReturn struct{}
+type RuleActionDNSRouteOptions struct {
+	DisableCache bool
+	RewriteTTL   *uint32
+	ClientSubnet netip.Prefix
+}
+
+func (r *RuleActionDNSRouteOptions) Type() string {
+	return C.RuleActionTypeRouteOptions
+}
+
+func (r *RuleActionDNSRouteOptions) String() string {
+	var descriptions []string
+	if r.DisableCache {
+		descriptions = append(descriptions, "disable-cache")
+	}
+	if r.RewriteTTL != nil {
+		descriptions = append(descriptions, F.ToString("rewrite-ttl(", *r.RewriteTTL, ")"))
+	}
+	if r.ClientSubnet.IsValid() {
+		descriptions = append(descriptions, F.ToString("client-subnet(", r.ClientSubnet, ")"))
+	}
+	return F.ToString("route-options(", strings.Join(descriptions, ","), ")")
+}
+
+type RuleActionDirect struct {
+	Dialer      N.Dialer
+	description string
+}
 
-func (r *RuleActionReturn) Type() string {
-	return C.RuleActionTypeReturn
+func (r *RuleActionDirect) Type() string {
+	return C.RuleActionTypeDirect
 }
 
-func (r *RuleActionReturn) String() string {
-	return "return"
+func (r *RuleActionDirect) String() string {
+	return "direct" + r.description
 }
 
 type RuleActionReject struct {
@@ -137,7 +214,7 @@ func (r *RuleActionReject) Error(ctx context.Context) error {
 	var returnErr error
 	switch r.Method {
 	case C.RuleActionRejectMethodDefault:
-		returnErr = unix.ECONNREFUSED
+		returnErr = syscall.ECONNREFUSED
 	case C.RuleActionRejectMethodDrop:
 		return tun.ErrDrop
 	default:

+ 2 - 2
route/rule/rule_default.go

@@ -52,7 +52,7 @@ type RuleItem interface {
 }
 
 func NewDefaultRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) {
-	action, err := NewRuleAction(logger, options.RuleAction)
+	action, err := NewRuleAction(router, logger, options.RuleAction)
 	if err != nil {
 		return nil, E.Cause(err, "action")
 	}
@@ -254,7 +254,7 @@ type LogicalRule struct {
 }
 
 func NewLogicalRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) {
-	action, err := NewRuleAction(logger, options.RuleAction)
+	action, err := NewRuleAction(router, logger, options.RuleAction)
 	if err != nil {
 		return nil, E.Cause(err, "action")
 	}