瀏覽代碼

Add loopback check

世界 2 年之前
父節點
當前提交
6843970536
共有 9 個文件被更改,包括 73 次插入35 次删除
  1. 2 1
      box.go
  2. 1 1
      go.mod
  3. 2 2
      go.sum
  4. 1 1
      include/dhcp_stub.go
  5. 1 1
      include/quic_stub.go
  6. 21 20
      outbound/builder.go
  7. 14 0
      outbound/lookback.go
  8. 23 7
      route/router.go
  9. 8 2
      transport/dhcp/server.go

+ 2 - 1
box.go

@@ -117,6 +117,7 @@ func New(options Options) (*Box, error) {
 			ctx,
 			router,
 			logFactory.NewLogger(F.ToString("outbound/", outboundOptions.Type, "[", tag, "]")),
+			tag,
 			outboundOptions)
 		if err != nil {
 			return nil, E.Cause(err, "parse outbound[", i, "]")
@@ -124,7 +125,7 @@ func New(options Options) (*Box, error) {
 		outbounds = append(outbounds, out)
 	}
 	err = router.Initialize(inbounds, outbounds, func() adapter.Outbound {
-		out, oErr := outbound.New(ctx, router, logFactory.NewLogger("outbound/direct"), option.Outbound{Type: "direct", Tag: "default"})
+		out, oErr := outbound.New(ctx, router, logFactory.NewLogger("outbound/direct"), "direct", option.Outbound{Type: "direct", Tag: "default"})
 		common.Must(oErr)
 		outbounds = append(outbounds, out)
 		return out

+ 1 - 1
go.mod

@@ -26,7 +26,7 @@ require (
 	github.com/sagernet/quic-go v0.0.0-20230202071646-a8c8afb18b32
 	github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691
 	github.com/sagernet/sing v0.2.2-0.20230407053809-308e421e33c2
-	github.com/sagernet/sing-dns v0.1.5-0.20230407055526-2a27418e7855
+	github.com/sagernet/sing-dns v0.1.5-0.20230408004833-5adaf486d440
 	github.com/sagernet/sing-shadowsocks v0.2.0
 	github.com/sagernet/sing-shadowtls v0.1.0
 	github.com/sagernet/sing-tun v0.1.4-0.20230326080954-8848c0e4cbab

+ 2 - 2
go.sum

@@ -113,8 +113,8 @@ github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2
 github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
 github.com/sagernet/sing v0.2.2-0.20230407053809-308e421e33c2 h1:VjeHDxEgpB2fqK5G16yBvtLacibvg3h2MsIjal0UXH0=
 github.com/sagernet/sing v0.2.2-0.20230407053809-308e421e33c2/go.mod h1:9uHswk2hITw8leDbiLS/xn0t9nzBcbePxzm9PJhwdlw=
-github.com/sagernet/sing-dns v0.1.5-0.20230407055526-2a27418e7855 h1:a3W2X1n5C/oYGp/Dd26eoymME3iXN8TJq7LZtO2MSUY=
-github.com/sagernet/sing-dns v0.1.5-0.20230407055526-2a27418e7855/go.mod h1:69PNSHyEmXdjf6C+bXBOdr2GQnPeEyWjIzo/MV8gmz8=
+github.com/sagernet/sing-dns v0.1.5-0.20230408004833-5adaf486d440 h1:VH8/BcOVuApHtS+vKP+khxlGRcXH7KKhgkTDtNynqSQ=
+github.com/sagernet/sing-dns v0.1.5-0.20230408004833-5adaf486d440/go.mod h1:69PNSHyEmXdjf6C+bXBOdr2GQnPeEyWjIzo/MV8gmz8=
 github.com/sagernet/sing-shadowsocks v0.2.0 h1:ILDWL7pwWfkPLEbviE/MyCgfjaBmJY/JVVY+5jhSb58=
 github.com/sagernet/sing-shadowsocks v0.2.0/go.mod h1:ysYzszRLpNzJSorvlWRMuzU6Vchsp7sd52q+JNY4axw=
 github.com/sagernet/sing-shadowtls v0.1.0 h1:05MYce8aR5xfKIn+y7xRFsdKhKt44QZTSEQW+lG5IWQ=

+ 1 - 1
include/dhcp_stub.go

@@ -12,7 +12,7 @@ import (
 )
 
 func init() {
-	dns.RegisterTransport([]string{"dhcp"}, func(ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) {
+	dns.RegisterTransport([]string{"dhcp"}, func(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) {
 		return nil, E.New(`DHCP is not included in this build, rebuild with -tags with_dhcp`)
 	})
 }

+ 1 - 1
include/quic_stub.go

@@ -19,7 +19,7 @@ import (
 const WithQUIC = false
 
 func init() {
-	dns.RegisterTransport([]string{"quic", "h3"}, func(ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) {
+	dns.RegisterTransport([]string{"quic", "h3"}, func(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) {
 		return nil, C.ErrQUICNotIncluded
 	})
 	v2ray.RegisterQUICConstructor(

+ 21 - 20
outbound/builder.go

@@ -10,50 +10,51 @@ import (
 	E "github.com/sagernet/sing/common/exceptions"
 )
 
-func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.Outbound) (adapter.Outbound, error) {
+func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Outbound) (adapter.Outbound, error) {
 	var metadata *adapter.InboundContext
-	if options.Tag != "" {
+	if tag != "" {
 		ctx, metadata = adapter.AppendContext(ctx)
-		metadata.Outbound = options.Tag
+		metadata.Outbound = tag
 	}
 	if options.Type == "" {
 		return nil, E.New("missing outbound type")
 	}
+	ctx = ContextWithTag(ctx, tag)
 	switch options.Type {
 	case C.TypeDirect:
-		return NewDirect(router, logger, options.Tag, options.DirectOptions)
+		return NewDirect(router, logger, tag, options.DirectOptions)
 	case C.TypeBlock:
-		return NewBlock(logger, options.Tag), nil
+		return NewBlock(logger, tag), nil
 	case C.TypeDNS:
-		return NewDNS(router, options.Tag), nil
+		return NewDNS(router, tag), nil
 	case C.TypeSocks:
-		return NewSocks(router, logger, options.Tag, options.SocksOptions)
+		return NewSocks(router, logger, tag, options.SocksOptions)
 	case C.TypeHTTP:
-		return NewHTTP(router, logger, options.Tag, options.HTTPOptions)
+		return NewHTTP(router, logger, tag, options.HTTPOptions)
 	case C.TypeShadowsocks:
-		return NewShadowsocks(ctx, router, logger, options.Tag, options.ShadowsocksOptions)
+		return NewShadowsocks(ctx, router, logger, tag, options.ShadowsocksOptions)
 	case C.TypeVMess:
-		return NewVMess(ctx, router, logger, options.Tag, options.VMessOptions)
+		return NewVMess(ctx, router, logger, tag, options.VMessOptions)
 	case C.TypeTrojan:
-		return NewTrojan(ctx, router, logger, options.Tag, options.TrojanOptions)
+		return NewTrojan(ctx, router, logger, tag, options.TrojanOptions)
 	case C.TypeWireGuard:
-		return NewWireGuard(ctx, router, logger, options.Tag, options.WireGuardOptions)
+		return NewWireGuard(ctx, router, logger, tag, options.WireGuardOptions)
 	case C.TypeHysteria:
-		return NewHysteria(ctx, router, logger, options.Tag, options.HysteriaOptions)
+		return NewHysteria(ctx, router, logger, tag, options.HysteriaOptions)
 	case C.TypeTor:
-		return NewTor(ctx, router, logger, options.Tag, options.TorOptions)
+		return NewTor(ctx, router, logger, tag, options.TorOptions)
 	case C.TypeSSH:
-		return NewSSH(ctx, router, logger, options.Tag, options.SSHOptions)
+		return NewSSH(ctx, router, logger, tag, options.SSHOptions)
 	case C.TypeShadowTLS:
-		return NewShadowTLS(ctx, router, logger, options.Tag, options.ShadowTLSOptions)
+		return NewShadowTLS(ctx, router, logger, tag, options.ShadowTLSOptions)
 	case C.TypeShadowsocksR:
-		return NewShadowsocksR(ctx, router, logger, options.Tag, options.ShadowsocksROptions)
+		return NewShadowsocksR(ctx, router, logger, tag, options.ShadowsocksROptions)
 	case C.TypeVLESS:
-		return NewVLESS(ctx, router, logger, options.Tag, options.VLESSOptions)
+		return NewVLESS(ctx, router, logger, tag, options.VLESSOptions)
 	case C.TypeSelector:
-		return NewSelector(router, logger, options.Tag, options.SelectorOptions)
+		return NewSelector(router, logger, tag, options.SelectorOptions)
 	case C.TypeURLTest:
-		return NewURLTest(router, logger, options.Tag, options.URLTestOptions)
+		return NewURLTest(router, logger, tag, options.URLTestOptions)
 	default:
 		return nil, E.New("unknown outbound type: ", options.Type)
 	}

+ 14 - 0
outbound/lookback.go

@@ -0,0 +1,14 @@
+package outbound
+
+import "context"
+
+type outboundTagKey struct{}
+
+func ContextWithTag(ctx context.Context, outboundTag string) context.Context {
+	return context.WithValue(ctx, outboundTagKey{}, outboundTag)
+}
+
+func TagFromContext(ctx context.Context) (string, bool) {
+	value, loaded := ctx.Value(outboundTagKey{}).(string)
+	return value, loaded
+}

+ 23 - 7
route/router.go

@@ -26,6 +26,7 @@ import (
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/ntp"
 	"github.com/sagernet/sing-box/option"
+	"github.com/sagernet/sing-box/outbound"
 	"github.com/sagernet/sing-dns"
 	"github.com/sagernet/sing-tun"
 	"github.com/sagernet/sing-vmess"
@@ -218,7 +219,7 @@ func NewRouter(
 					}
 				}
 			}
-			transport, err := dns.CreateTransport(ctx, logFactory.NewLogger(F.ToString("dns/transport[", tag, "]")), detour, server.Address)
+			transport, err := dns.CreateTransport(tag, ctx, logFactory.NewLogger(F.ToString("dns/transport[", tag, "]")), detour, server.Address)
 			if err != nil {
 				return nil, E.Cause(err, "parse dns server[", tag, "]")
 			}
@@ -258,7 +259,7 @@ func NewRouter(
 	}
 	if defaultTransport == nil {
 		if len(transports) == 0 {
-			transports = append(transports, dns.NewLocalTransport(N.SystemDialer))
+			transports = append(transports, dns.NewLocalTransport("local", N.SystemDialer))
 		}
 		defaultTransport = transports[0]
 	}
@@ -660,9 +661,11 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
 		metadata.DestinationAddresses = addresses
 		r.dnsLogger.DebugContext(ctx, "resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]")
 	}
-	matchedRule, detour := r.match(ctx, &metadata, r.defaultOutboundForConnection)
+	ctx, matchedRule, detour, err := r.match(ctx, &metadata, r.defaultOutboundForConnection)
+	if err != nil {
+		return err
+	}
 	if !common.Contains(detour.Network(), N.NetworkTCP) {
-		conn.Close()
 		return E.New("missing supported outbound, closing connection")
 	}
 	if r.clashServer != nil {
@@ -738,9 +741,11 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
 		metadata.DestinationAddresses = addresses
 		r.dnsLogger.DebugContext(ctx, "resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]")
 	}
-	matchedRule, detour := r.match(ctx, &metadata, r.defaultOutboundForPacketConnection)
+	ctx, matchedRule, detour, err := r.match(ctx, &metadata, r.defaultOutboundForPacketConnection)
+	if err != nil {
+		return err
+	}
 	if !common.Contains(detour.Network(), N.NetworkUDP) {
-		conn.Close()
 		return E.New("missing supported outbound, closing packet connection")
 	}
 	if r.clashServer != nil {
@@ -756,7 +761,18 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
 	return detour.NewPacketConnection(ctx, conn, metadata)
 }
 
-func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) {
+func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (context.Context, adapter.Rule, adapter.Outbound, error) {
+	matchRule, matchOutbound := r.match0(ctx, metadata, defaultOutbound)
+	if contextOutbound, loaded := outbound.TagFromContext(ctx); loaded {
+		if contextOutbound == matchOutbound.Tag() {
+			return nil, nil, nil, E.New("connection loopback in outbound/", matchOutbound.Type(), "[", matchOutbound.Tag(), "]")
+		}
+	}
+	ctx = outbound.ContextWithTag(ctx, matchOutbound.Tag())
+	return ctx, matchRule, matchOutbound, nil
+}
+
+func (r *Router) match0(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) {
 	if r.processSearcher != nil {
 		var originDestination netip.AddrPort
 		if metadata.OriginDestination.IsValid() {

+ 8 - 2
transport/dhcp/server.go

@@ -35,6 +35,7 @@ func init() {
 }
 
 type Transport struct {
+	name              string
 	ctx               context.Context
 	router            adapter.Router
 	logger            logger.Logger
@@ -46,7 +47,7 @@ type Transport struct {
 	updatedAt         time.Time
 }
 
-func NewTransport(ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) {
+func NewTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) {
 	linkURL, err := url.Parse(link)
 	if err != nil {
 		return nil, err
@@ -59,6 +60,7 @@ func NewTransport(ctx context.Context, logger logger.ContextLogger, dialer N.Dia
 		return nil, E.New("missing router in context")
 	}
 	transport := &Transport{
+		name:          name,
 		ctx:           ctx,
 		router:        router,
 		logger:        logger,
@@ -68,6 +70,10 @@ func NewTransport(ctx context.Context, logger logger.ContextLogger, dialer N.Dia
 	return transport, nil
 }
 
+func (t *Transport) Name() string {
+	return t.name
+}
+
 func (t *Transport) Start() error {
 	err := t.fetchServers()
 	if err != nil {
@@ -247,7 +253,7 @@ func (t *Transport) recreateServers(iface *net.Interface, serverAddrs []netip.Ad
 	})
 	var transports []dns.Transport
 	for _, serverAddr := range serverAddrs {
-		serverTransport, err := dns.NewUDPTransport(t.ctx, serverDialer, M.Socksaddr{Addr: serverAddr, Port: 53})
+		serverTransport, err := dns.NewUDPTransport(t.name, t.ctx, serverDialer, M.Socksaddr{Addr: serverAddr, Port: 53})
 		if err != nil {
 			return err
 		}