Browse Source

Improve loopback detector

世界 1 year ago
parent
commit
cd0fcd5ddc

+ 1 - 2
experimental/libbox/config.go

@@ -9,7 +9,6 @@ import (
 	"github.com/sagernet/sing-box"
 	"github.com/sagernet/sing-box"
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/common/process"
 	"github.com/sagernet/sing-box/common/process"
-	"github.com/sagernet/sing-box/experimental/libbox/platform"
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing-tun"
 	"github.com/sagernet/sing-tun"
 	"github.com/sagernet/sing/common/control"
 	"github.com/sagernet/sing/common/control"
@@ -75,7 +74,7 @@ func (s *platformInterfaceStub) UsePlatformInterfaceGetter() bool {
 	return true
 	return true
 }
 }
 
 
-func (s *platformInterfaceStub) Interfaces() ([]platform.NetworkInterface, error) {
+func (s *platformInterfaceStub) Interfaces() ([]control.Interface, error) {
 	return nil, os.ErrInvalid
 	return nil, os.ErrInvalid
 }
 }
 
 

+ 1 - 9
experimental/libbox/platform/interface.go

@@ -2,7 +2,6 @@ package platform
 
 
 import (
 import (
 	"context"
 	"context"
-	"net/netip"
 
 
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/common/process"
 	"github.com/sagernet/sing-box/common/process"
@@ -20,16 +19,9 @@ type Interface interface {
 	UsePlatformDefaultInterfaceMonitor() bool
 	UsePlatformDefaultInterfaceMonitor() bool
 	CreateDefaultInterfaceMonitor(logger logger.Logger) tun.DefaultInterfaceMonitor
 	CreateDefaultInterfaceMonitor(logger logger.Logger) tun.DefaultInterfaceMonitor
 	UsePlatformInterfaceGetter() bool
 	UsePlatformInterfaceGetter() bool
-	Interfaces() ([]NetworkInterface, error)
+	Interfaces() ([]control.Interface, error)
 	UnderNetworkExtension() bool
 	UnderNetworkExtension() bool
 	ClearDNSCache()
 	ClearDNSCache()
 	ReadWIFIState() adapter.WIFIState
 	ReadWIFIState() adapter.WIFIState
 	process.Searcher
 	process.Searcher
 }
 }
-
-type NetworkInterface struct {
-	Index     int
-	MTU       int
-	Name      string
-	Addresses []netip.Prefix
-}

+ 3 - 3
experimental/libbox/service.go

@@ -192,14 +192,14 @@ func (w *platformInterfaceWrapper) UsePlatformInterfaceGetter() bool {
 	return w.iif.UsePlatformInterfaceGetter()
 	return w.iif.UsePlatformInterfaceGetter()
 }
 }
 
 
-func (w *platformInterfaceWrapper) Interfaces() ([]platform.NetworkInterface, error) {
+func (w *platformInterfaceWrapper) Interfaces() ([]control.Interface, error) {
 	interfaceIterator, err := w.iif.GetInterfaces()
 	interfaceIterator, err := w.iif.GetInterfaces()
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	var interfaces []platform.NetworkInterface
+	var interfaces []control.Interface
 	for _, netInterface := range iteratorToArray[*NetworkInterface](interfaceIterator) {
 	for _, netInterface := range iteratorToArray[*NetworkInterface](interfaceIterator) {
-		interfaces = append(interfaces, platform.NetworkInterface{
+		interfaces = append(interfaces, control.Interface{
 			Index:     int(netInterface.Index),
 			Index:     int(netInterface.Index),
 			MTU:       int(netInterface.MTU),
 			MTU:       int(netInterface.MTU),
 			Name:      netInterface.Name,
 			Name:      netInterface.Name,

+ 1 - 1
outbound/direct.go

@@ -51,7 +51,7 @@ func NewDirect(router adapter.Router, logger log.ContextLogger, tag string, opti
 		domainStrategy: dns.DomainStrategy(options.DomainStrategy),
 		domainStrategy: dns.DomainStrategy(options.DomainStrategy),
 		fallbackDelay:  time.Duration(options.FallbackDelay),
 		fallbackDelay:  time.Duration(options.FallbackDelay),
 		dialer:         outboundDialer,
 		dialer:         outboundDialer,
-		loopBack:       newLoopBackDetector(),
+		loopBack:       newLoopBackDetector(router),
 	}
 	}
 	if options.ProxyProtocol != 0 {
 	if options.ProxyProtocol != 0 {
 		return nil, E.New("Proxy Protocol is deprecated and removed in sing-box 1.6.0")
 		return nil, E.New("Proxy Protocol is deprecated and removed in sing-box 1.6.0")

+ 14 - 7
outbound/direct_loopback_detect.go

@@ -5,21 +5,22 @@ import (
 	"net/netip"
 	"net/netip"
 	"sync"
 	"sync"
 
 
+	"github.com/sagernet/sing-box/adapter"
 	M "github.com/sagernet/sing/common/metadata"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 	N "github.com/sagernet/sing/common/network"
 )
 )
 
 
 type loopBackDetector struct {
 type loopBackDetector struct {
-	// router           adapter.Router
+	router           adapter.Router
 	connAccess       sync.RWMutex
 	connAccess       sync.RWMutex
 	packetConnAccess sync.RWMutex
 	packetConnAccess sync.RWMutex
 	connMap          map[netip.AddrPort]netip.AddrPort
 	connMap          map[netip.AddrPort]netip.AddrPort
 	packetConnMap    map[uint16]uint16
 	packetConnMap    map[uint16]uint16
 }
 }
 
 
-func newLoopBackDetector( /*router adapter.Router*/ ) *loopBackDetector {
+func newLoopBackDetector(router adapter.Router) *loopBackDetector {
 	return &loopBackDetector{
 	return &loopBackDetector{
-		// router:        router,
+		router:        router,
 		connMap:       make(map[netip.AddrPort]netip.AddrPort),
 		connMap:       make(map[netip.AddrPort]netip.AddrPort),
 		packetConnMap: make(map[uint16]uint16),
 		packetConnMap: make(map[uint16]uint16),
 	}
 	}
@@ -31,12 +32,12 @@ func (l *loopBackDetector) NewConn(conn net.Conn) net.Conn {
 		return conn
 		return conn
 	}
 	}
 	if udpConn, isUDPConn := conn.(abstractUDPConn); isUDPConn {
 	if udpConn, isUDPConn := conn.(abstractUDPConn); isUDPConn {
-		/*if !source.Addr().IsLoopback() {
+		if !source.Addr().IsLoopback() {
 			_, err := l.router.InterfaceFinder().InterfaceByAddr(source.Addr())
 			_, err := l.router.InterfaceFinder().InterfaceByAddr(source.Addr())
 			if err != nil {
 			if err != nil {
 				return conn
 				return conn
 			}
 			}
-		}*/
+		}
 		if !N.IsPublicAddr(source.Addr()) {
 		if !N.IsPublicAddr(source.Addr()) {
 			return conn
 			return conn
 		}
 		}
@@ -57,6 +58,12 @@ func (l *loopBackDetector) NewPacketConn(conn N.NetPacketConn, destination M.Soc
 	if !source.IsValid() {
 	if !source.IsValid() {
 		return conn
 		return conn
 	}
 	}
+	if !source.Addr().IsLoopback() {
+		_, err := l.router.InterfaceFinder().InterfaceByAddr(source.Addr())
+		if err != nil {
+			return conn
+		}
+	}
 	l.packetConnAccess.Lock()
 	l.packetConnAccess.Lock()
 	l.packetConnMap[source.Port()] = destination.AddrPort().Port()
 	l.packetConnMap[source.Port()] = destination.AddrPort().Port()
 	l.packetConnAccess.Unlock()
 	l.packetConnAccess.Unlock()
@@ -74,12 +81,12 @@ func (l *loopBackDetector) CheckPacketConn(source netip.AddrPort, local netip.Ad
 	if !source.IsValid() {
 	if !source.IsValid() {
 		return false
 		return false
 	}
 	}
-	/*if !source.Addr().IsLoopback() {
+	if !source.Addr().IsLoopback() {
 		_, err := l.router.InterfaceFinder().InterfaceByAddr(source.Addr())
 		_, err := l.router.InterfaceFinder().InterfaceByAddr(source.Addr())
 		if err != nil {
 		if err != nil {
 			return false
 			return false
 		}
 		}
-	}*/
+	}
 	if N.IsPublicAddr(source.Addr()) {
 	if N.IsPublicAddr(source.Addr()) {
 		return false
 		return false
 	}
 	}

+ 0 - 54
route/interface_finder.go

@@ -1,54 +0,0 @@
-package route
-
-import (
-	"net"
-
-	"github.com/sagernet/sing/common/control"
-)
-
-var _ control.InterfaceFinder = (*myInterfaceFinder)(nil)
-
-type myInterfaceFinder struct {
-	interfaces []net.Interface
-}
-
-func (f *myInterfaceFinder) update() error {
-	ifs, err := net.Interfaces()
-	if err != nil {
-		return err
-	}
-	f.interfaces = ifs
-	return nil
-}
-
-func (f *myInterfaceFinder) updateInterfaces(interfaces []net.Interface) {
-	f.interfaces = interfaces
-}
-
-func (f *myInterfaceFinder) InterfaceIndexByName(name string) (interfaceIndex int, err error) {
-	for _, netInterface := range f.interfaces {
-		if netInterface.Name == name {
-			return netInterface.Index, nil
-		}
-	}
-	netInterface, err := net.InterfaceByName(name)
-	if err != nil {
-		return
-	}
-	f.update()
-	return netInterface.Index, nil
-}
-
-func (f *myInterfaceFinder) InterfaceNameByIndex(index int) (interfaceName string, err error) {
-	for _, netInterface := range f.interfaces {
-		if netInterface.Index == index {
-			return netInterface.Name, nil
-		}
-	}
-	netInterface, err := net.InterfaceByIndex(index)
-	if err != nil {
-		return
-	}
-	f.update()
-	return netInterface.Name, nil
-}

+ 6 - 11
route/router.go

@@ -79,7 +79,7 @@ type Router struct {
 	transportDomainStrategy            map[dns.Transport]dns.DomainStrategy
 	transportDomainStrategy            map[dns.Transport]dns.DomainStrategy
 	dnsReverseMapping                  *DNSReverseMapping
 	dnsReverseMapping                  *DNSReverseMapping
 	fakeIPStore                        adapter.FakeIPStore
 	fakeIPStore                        adapter.FakeIPStore
-	interfaceFinder                    myInterfaceFinder
+	interfaceFinder                    *control.DefaultInterfaceFinder
 	autoDetectInterface                bool
 	autoDetectInterface                bool
 	defaultInterface                   string
 	defaultInterface                   string
 	defaultMark                        int
 	defaultMark                        int
@@ -124,6 +124,7 @@ func NewRouter(
 		needFindProcess:       hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess,
 		needFindProcess:       hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess,
 		defaultDetour:         options.Final,
 		defaultDetour:         options.Final,
 		defaultDomainStrategy: dns.DomainStrategy(dnsOptions.Strategy),
 		defaultDomainStrategy: dns.DomainStrategy(dnsOptions.Strategy),
+		interfaceFinder:       control.NewDefaultInterfaceFinder(),
 		autoDetectInterface:   options.AutoDetectInterface,
 		autoDetectInterface:   options.AutoDetectInterface,
 		defaultInterface:      options.DefaultInterface,
 		defaultInterface:      options.DefaultInterface,
 		defaultMark:           options.DefaultMark,
 		defaultMark:           options.DefaultMark,
@@ -305,7 +306,7 @@ func NewRouter(
 			}
 			}
 			router.networkMonitor = networkMonitor
 			router.networkMonitor = networkMonitor
 			networkMonitor.RegisterCallback(func() {
 			networkMonitor.RegisterCallback(func() {
-				_ = router.interfaceFinder.update()
+				_ = router.interfaceFinder.Update()
 			})
 			})
 			interfaceMonitor, err := tun.NewDefaultInterfaceMonitor(router.networkMonitor, router.logger, tun.DefaultInterfaceMonitorOptions{
 			interfaceMonitor, err := tun.NewDefaultInterfaceMonitor(router.networkMonitor, router.logger, tun.DefaultInterfaceMonitorOptions{
 				OverrideAndroidVPN:    options.OverrideAndroidVPN,
 				OverrideAndroidVPN:    options.OverrideAndroidVPN,
@@ -1063,24 +1064,18 @@ func (r *Router) match0(ctx context.Context, metadata *adapter.InboundContext, d
 }
 }
 
 
 func (r *Router) InterfaceFinder() control.InterfaceFinder {
 func (r *Router) InterfaceFinder() control.InterfaceFinder {
-	return &r.interfaceFinder
+	return r.interfaceFinder
 }
 }
 
 
 func (r *Router) UpdateInterfaces() error {
 func (r *Router) UpdateInterfaces() error {
 	if r.platformInterface == nil || !r.platformInterface.UsePlatformInterfaceGetter() {
 	if r.platformInterface == nil || !r.platformInterface.UsePlatformInterfaceGetter() {
-		return r.interfaceFinder.update()
+		return r.interfaceFinder.Update()
 	} else {
 	} else {
 		interfaces, err := r.platformInterface.Interfaces()
 		interfaces, err := r.platformInterface.Interfaces()
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		r.interfaceFinder.updateInterfaces(common.Map(interfaces, func(it platform.NetworkInterface) net.Interface {
-			return net.Interface{
-				Name:  it.Name,
-				Index: it.Index,
-				MTU:   it.MTU,
-			}
-		}))
+		r.interfaceFinder.UpdateInterfaces(interfaces)
 		return nil
 		return nil
 	}
 	}
 }
 }