Переглянути джерело

refactor: connection manager

世界 11 місяців тому
батько
коміт
fd299a0961
11 змінених файлів з 569 додано та 52 видалено
  1. 14 0
      adapter/connections.go
  2. 11 7
      box.go
  3. 1 1
      protocol/direct/inbound.go
  4. 332 0
      route/conn.go
  5. 128 0
      route/conn_monitor.go
  6. 43 0
      route/conn_monitor_test.go
  7. 3 3
      route/dns.go
  8. 4 4
      route/geo_resources.go
  9. 8 0
      route/network.go
  10. 14 28
      route/route.go
  11. 11 9
      route/router.go

+ 14 - 0
adapter/connections.go

@@ -0,0 +1,14 @@
+package adapter
+
+import (
+	"context"
+	"net"
+
+	N "github.com/sagernet/sing/common/network"
+)
+
+type ConnectionManager interface {
+	Lifecycle
+	NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata InboundContext, onClose N.CloseHandlerFunc)
+	NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata InboundContext, onClose N.CloseHandlerFunc)
+}

+ 11 - 7
box.go

@@ -36,9 +36,10 @@ type Box struct {
 	logFactory log.Factory
 	logger     log.ContextLogger
 	network    *route.NetworkManager
-	router     *route.Router
 	inbound    *inbound.Manager
 	outbound   *outbound.Manager
+	connection *route.ConnectionManager
+	router     *route.Router
 	services   []adapter.LifecycleService
 	done       chan struct{}
 }
@@ -128,6 +129,8 @@ func New(options Options) (*Box, error) {
 		return nil, E.Cause(err, "initialize network manager")
 	}
 	service.MustRegister[adapter.NetworkManager](ctx, networkManager)
+	connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection"))
+	service.MustRegister[adapter.ConnectionManager](ctx, connectionManager)
 	router, err := route.NewRouter(ctx, logFactory, routeOptions, common.PtrValueOrDefault(options.DNS))
 	if err != nil {
 		return nil, E.Cause(err, "initialize router")
@@ -238,9 +241,10 @@ func New(options Options) (*Box, error) {
 	}
 	return &Box{
 		network:    networkManager,
-		router:     router,
 		inbound:    inboundManager,
 		outbound:   outboundManager,
+		connection: connectionManager,
+		router:     router,
 		createdAt:  createdAt,
 		logFactory: logFactory,
 		logger:     logFactory.Logger(),
@@ -299,11 +303,11 @@ func (s *Box) preStart() error {
 	if err != nil {
 		return err
 	}
-	err = adapter.Start(adapter.StartStateInitialize, s.network, s.router, s.outbound, s.inbound)
+	err = adapter.Start(adapter.StartStateInitialize, s.network, s.connection, s.router, s.outbound, s.inbound)
 	if err != nil {
 		return err
 	}
-	err = adapter.Start(adapter.StartStateStart, s.outbound, s.network, s.router)
+	err = adapter.Start(adapter.StartStateStart, s.outbound, s.network, s.connection, s.router)
 	if err != nil {
 		return err
 	}
@@ -323,7 +327,7 @@ func (s *Box) start() error {
 	if err != nil {
 		return err
 	}
-	err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.router, s.inbound)
+	err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.connection, s.router, s.inbound)
 	if err != nil {
 		return err
 	}
@@ -331,7 +335,7 @@ func (s *Box) start() error {
 	if err != nil {
 		return err
 	}
-	err = adapter.Start(adapter.StartStateStarted, s.network, s.router, s.outbound, s.inbound)
+	err = adapter.Start(adapter.StartStateStarted, s.network, s.connection, s.router, s.outbound, s.inbound)
 	if err != nil {
 		return err
 	}
@@ -350,7 +354,7 @@ func (s *Box) Close() error {
 		close(s.done)
 	}
 	err := common.Close(
-		s.inbound, s.outbound, s.router, s.network,
+		s.inbound, s.outbound, s.router, s.connection, s.network,
 	)
 	for _, lifecycleService := range s.services {
 		err = E.Append(err, lifecycleService.Close(), func(err error) error {

+ 1 - 1
protocol/direct/inbound.go

@@ -83,7 +83,7 @@ func (i *Inbound) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
 		destination = i.overrideDestination
 	case 2:
 		destination = i.overrideDestination
-		destination.Port = source.Port
+		destination.Port = i.listener.UDPAddr().Port
 	case 3:
 		destination = source
 		destination.Port = i.overrideDestination.Port

+ 332 - 0
route/conn.go

@@ -0,0 +1,332 @@
+package route
+
+import (
+	"context"
+	"io"
+	"net"
+	"net/netip"
+	"sync/atomic"
+
+	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/dialer"
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/bufio"
+	E "github.com/sagernet/sing/common/exceptions"
+	"github.com/sagernet/sing/common/logger"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+)
+
+var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
+
+type ConnectionManager struct {
+	logger  logger.ContextLogger
+	monitor *ConnectionMonitor
+}
+
+func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
+	return &ConnectionManager{
+		logger:  logger,
+		monitor: NewConnectionMonitor(),
+	}
+}
+
+func (m *ConnectionManager) Start(stage adapter.StartStage) error {
+	if stage != adapter.StartStateInitialize {
+		return nil
+	}
+	return m.monitor.Start()
+}
+
+func (m *ConnectionManager) Close() error {
+	return m.monitor.Close()
+}
+
+func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
+	ctx = adapter.WithContext(ctx, &metadata)
+	var (
+		remoteConn net.Conn
+		err        error
+	)
+	if len(metadata.DestinationAddresses) > 0 {
+		remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
+	} else {
+		remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
+	}
+	if err != nil {
+		N.CloseOnHandshakeFailure(conn, onClose, err)
+		m.logger.ErrorContext(ctx, "open outbound connection: ", err)
+		return
+	}
+	err = N.ReportConnHandshakeSuccess(conn, remoteConn)
+	if err != nil {
+		remoteConn.Close()
+		N.CloseOnHandshakeFailure(conn, onClose, err)
+		m.logger.ErrorContext(ctx, "report handshake success: ", err)
+		return
+	}
+	var done atomic.Bool
+	if ctx.Done() != nil {
+		onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))
+	}
+	go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
+	go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose)
+}
+
+func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
+	originSource := source
+	originDestination := destination
+	var readCounters, writeCounters []N.CountFunc
+	for {
+		source, readCounters = N.UnwrapCountReader(source, readCounters)
+		destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
+		if cachedSrc, isCached := source.(N.CachedReader); isCached {
+			cachedBuffer := cachedSrc.ReadCached()
+			if cachedBuffer != nil {
+				dataLen := cachedBuffer.Len()
+				_, err := destination.Write(cachedBuffer.Bytes())
+				cachedBuffer.Release()
+				if err != nil {
+					m.logger.ErrorContext(ctx, "connection upload payload: ", err)
+					if done.Swap(true) {
+						if onClose != nil {
+							onClose(err)
+						}
+					}
+					common.Close(originSource, originDestination)
+					return
+				}
+				for _, counter := range readCounters {
+					counter(int64(dataLen))
+				}
+				for _, counter := range writeCounters {
+					counter(int64(dataLen))
+				}
+			}
+			continue
+		}
+		break
+	}
+	_, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
+	if err != nil {
+		common.Close(originSource, originDestination)
+	} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
+		err = duplexDst.CloseWrite()
+		if err != nil {
+			common.Close(originSource, originDestination)
+		}
+	} else {
+		common.Close(originDestination)
+	}
+	if done.Swap(true) {
+		if onClose != nil {
+			onClose(err)
+		}
+		common.Close(originSource, originDestination)
+	}
+	if !direction {
+		if err == nil {
+			m.logger.DebugContext(ctx, "connection upload finished")
+		} else if !E.IsClosedOrCanceled(err) {
+			m.logger.ErrorContext(ctx, "connection upload closed: ", err)
+		} else {
+			m.logger.TraceContext(ctx, "connection upload closed")
+		}
+	} else {
+		if err == nil {
+			m.logger.DebugContext(ctx, "connection download finished")
+		} else if !E.IsClosedOrCanceled(err) {
+			m.logger.ErrorContext(ctx, "connection download closed: ", err)
+		} else {
+			m.logger.TraceContext(ctx, "connection download closed")
+		}
+	}
+}
+
+func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
+	ctx = adapter.WithContext(ctx, &metadata)
+	var (
+		remotePacketConn   net.PacketConn
+		remoteConn         net.Conn
+		destinationAddress netip.Addr
+		err                error
+	)
+	if metadata.UDPConnect {
+		if len(metadata.DestinationAddresses) > 0 {
+			if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer {
+				remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
+			} else {
+				remoteConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
+			}
+		} else {
+			remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
+		}
+		if err != nil {
+			N.CloseOnHandshakeFailure(conn, onClose, err)
+			m.logger.ErrorContext(ctx, "open outbound packet connection: ", err)
+			return
+		}
+		remotePacketConn = bufio.NewUnbindPacketConn(remoteConn)
+		connRemoteAddr := M.AddrFromNet(remoteConn.RemoteAddr())
+		if connRemoteAddr != metadata.Destination.Addr {
+			destinationAddress = connRemoteAddr
+		}
+	} else {
+		if len(metadata.DestinationAddresses) > 0 {
+			remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
+		} else {
+			remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination)
+		}
+		if err != nil {
+			N.CloseOnHandshakeFailure(conn, onClose, err)
+			m.logger.ErrorContext(ctx, "listen outbound packet connection: ", err)
+			return
+		}
+	}
+	err = N.ReportPacketConnHandshakeSuccess(conn, remotePacketConn)
+	if err != nil {
+		conn.Close()
+		remotePacketConn.Close()
+		m.logger.ErrorContext(ctx, "report handshake success: ", err)
+		return
+	}
+	if destinationAddress.IsValid() {
+		var originDestination M.Socksaddr
+		if metadata.RouteOriginalDestination.IsValid() {
+			originDestination = metadata.RouteOriginalDestination
+		} else {
+			originDestination = metadata.Destination
+		}
+		if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) {
+			if metadata.UDPDisableDomainUnmapping {
+				remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
+			} else {
+				remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
+			}
+		}
+		if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
+			natConn.UpdateDestination(destinationAddress)
+		}
+	}
+	destination := bufio.NewPacketConn(remotePacketConn)
+	var done atomic.Bool
+	if ctx.Done() != nil {
+		onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))
+	}
+	go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose)
+	go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
+}
+
+func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
+	_, err := bufio.CopyPacket(destination, source)
+	/*var readCounters, writeCounters []N.CountFunc
+	var cachedPackets []*N.PacketBuffer
+	originSource := source
+	for {
+		source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
+		destination, writeCounters = N.UnwrapCountPacketWriter(destination, writeCounters)
+		if cachedReader, isCached := source.(N.CachedPacketReader); isCached {
+			packet := cachedReader.ReadCachedPacket()
+			if packet != nil {
+				cachedPackets = append(cachedPackets, packet)
+				continue
+			}
+		}
+		break
+	}
+	var handled bool
+	if natConn, isNatConn := source.(udpnat.Conn); isNatConn {
+		natConn.SetHandler(&udpHijacker{
+			ctx:           ctx,
+			logger:        m.logger,
+			source:        natConn,
+			destination:   destination,
+			direction:     direction,
+			readCounters:  readCounters,
+			writeCounters: writeCounters,
+			done:          done,
+			onClose:       onClose,
+		})
+		handled = true
+	}
+	if cachedPackets != nil {
+		_, err := bufio.WritePacketWithPool(originSource, destination, cachedPackets, readCounters, writeCounters)
+		if err != nil {
+			common.Close(source, destination)
+			m.logger.ErrorContext(ctx, "packet upload payload: ", err)
+			return
+		}
+	}
+	if handled {
+		return
+	}
+	_, err := bufio.CopyPacketWithCounters(destination, source, originSource, readCounters, writeCounters)*/
+	if !direction {
+		if E.IsClosedOrCanceled(err) {
+			m.logger.TraceContext(ctx, "packet upload closed")
+		} else {
+			m.logger.DebugContext(ctx, "packet upload closed: ", err)
+		}
+	} else {
+		if E.IsClosedOrCanceled(err) {
+			m.logger.TraceContext(ctx, "packet download closed")
+		} else {
+			m.logger.DebugContext(ctx, "packet download closed: ", err)
+		}
+	}
+	if !done.Swap(true) {
+		if onClose != nil {
+			onClose(err)
+		}
+	}
+	common.Close(source, destination)
+}
+
+/*type udpHijacker struct {
+	ctx           context.Context
+	logger        logger.ContextLogger
+	source        io.Closer
+	destination   N.PacketWriter
+	direction     bool
+	readCounters  []N.CountFunc
+	writeCounters []N.CountFunc
+	done          *atomic.Bool
+	onClose       N.CloseHandlerFunc
+}
+
+func (u *udpHijacker) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
+	dataLen := buffer.Len()
+	for _, counter := range u.readCounters {
+		counter(int64(dataLen))
+	}
+	err := u.destination.WritePacket(buffer, source)
+	if err != nil {
+		common.Close(u.source, u.destination)
+		u.logger.DebugContext(u.ctx, "packet upload closed: ", err)
+		return
+	}
+	for _, counter := range u.writeCounters {
+		counter(int64(dataLen))
+	}
+}
+
+func (u *udpHijacker) Close() error {
+	var err error
+	if !u.done.Swap(true) {
+		err = common.Close(u.source, u.destination)
+		if u.onClose != nil {
+			u.onClose(net.ErrClosed)
+		}
+	}
+	if u.direction {
+		u.logger.TraceContext(u.ctx, "packet  download closed")
+	} else {
+		u.logger.TraceContext(u.ctx, "packet upload closed")
+	}
+	return err
+}
+
+func (u *udpHijacker) Upstream() any {
+	return u.destination
+}
+*/

+ 128 - 0
route/conn_monitor.go

@@ -0,0 +1,128 @@
+package route
+
+import (
+	"context"
+	"io"
+	"reflect"
+	"sync"
+	"time"
+
+	N "github.com/sagernet/sing/common/network"
+	"github.com/sagernet/sing/common/x/list"
+)
+
+type ConnectionMonitor struct {
+	access      sync.RWMutex
+	reloadChan  chan struct{}
+	connections list.List[*monitorEntry]
+}
+
+type monitorEntry struct {
+	ctx    context.Context
+	closer io.Closer
+}
+
+func NewConnectionMonitor() *ConnectionMonitor {
+	return &ConnectionMonitor{
+		reloadChan: make(chan struct{}, 1),
+	}
+}
+
+func (m *ConnectionMonitor) Add(ctx context.Context, closer io.Closer) N.CloseHandlerFunc {
+	m.access.Lock()
+	defer m.access.Unlock()
+	element := m.connections.PushBack(&monitorEntry{
+		ctx:    ctx,
+		closer: closer,
+	})
+	select {
+	case <-m.reloadChan:
+		return nil
+	default:
+		select {
+		case m.reloadChan <- struct{}{}:
+		default:
+		}
+	}
+	return func(it error) {
+		m.access.Lock()
+		defer m.access.Unlock()
+		m.connections.Remove(element)
+		select {
+		case <-m.reloadChan:
+		default:
+			select {
+			case m.reloadChan <- struct{}{}:
+			default:
+			}
+		}
+	}
+}
+
+func (m *ConnectionMonitor) Start() error {
+	go m.monitor()
+	return nil
+}
+
+func (m *ConnectionMonitor) Close() error {
+	m.access.Lock()
+	defer m.access.Unlock()
+	close(m.reloadChan)
+	for element := m.connections.Front(); element != nil; element = element.Next() {
+		element.Value.closer.Close()
+	}
+	return nil
+}
+
+func (m *ConnectionMonitor) monitor() {
+	var (
+		selectCases []reflect.SelectCase
+		elements    []*list.Element[*monitorEntry]
+	)
+	rootCase := reflect.SelectCase{
+		Dir:  reflect.SelectRecv,
+		Chan: reflect.ValueOf(m.reloadChan),
+	}
+	for {
+		m.access.RLock()
+		if m.connections.Len() == 0 {
+			m.access.RUnlock()
+			if _, loaded := <-m.reloadChan; !loaded {
+				return
+			} else {
+				continue
+			}
+		}
+		if len(elements) < m.connections.Len() {
+			elements = make([]*list.Element[*monitorEntry], 0, m.connections.Len())
+		}
+		if len(selectCases) < m.connections.Len()+1 {
+			selectCases = make([]reflect.SelectCase, 0, m.connections.Len()+1)
+		}
+		elements = elements[:0]
+		selectCases = selectCases[:1]
+		selectCases[0] = rootCase
+		for element := m.connections.Front(); element != nil; element = element.Next() {
+			elements = append(elements, element)
+			selectCases = append(selectCases, reflect.SelectCase{
+				Dir:  reflect.SelectRecv,
+				Chan: reflect.ValueOf(element.Value.ctx.Done()),
+			})
+		}
+		m.access.RUnlock()
+		selected, _, loaded := reflect.Select(selectCases)
+		if selected == 0 {
+			if !loaded {
+				return
+			} else {
+				time.Sleep(time.Second)
+				continue
+			}
+		}
+		element := elements[selected-1]
+		m.access.Lock()
+		m.connections.Remove(element)
+		m.access.Unlock()
+		element.Value.closer.Close() // maybe go close
+	}
+}

+ 43 - 0
route/conn_monitor_test.go

@@ -0,0 +1,43 @@
+package route_test
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/sagernet/sing-box/route"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestMonitor(t *testing.T) {
+	t.Parallel()
+	var closer myCloser
+	closer.Add(1)
+	monitor := route.NewConnectionMonitor()
+	require.NoError(t, monitor.Start())
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	monitor.Add(ctx, &closer)
+	done := make(chan struct{})
+	go func() {
+		closer.Wait()
+		close(done)
+	}()
+	select {
+	case <-done:
+	case <-time.After(time.Second + 100*time.Millisecond):
+		t.Fatal("timeout")
+	}
+	cancel()
+	require.NoError(t, monitor.Close())
+}
+
+type myCloser struct {
+	sync.WaitGroup
+}
+
+func (c *myCloser) Close() error {
+	c.Done()
+	return nil
+}

+ 3 - 3
route/dns.go

@@ -32,15 +32,15 @@ func (r *Router) hijackDNSStream(ctx context.Context, conn net.Conn, metadata ad
 }
 
 func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetBuffers []*N.PacketBuffer, metadata adapter.InboundContext) {
-	if uConn, isUDPNAT2 := conn.(*udpnat.Conn); isUDPNAT2 {
+	if natConn, isNatConn := conn.(udpnat.Conn); isNatConn {
 		metadata.Destination = M.Socksaddr{}
 		for _, packet := range packetBuffers {
 			buffer := packet.Buffer
 			destination := packet.Destination
 			N.PutPacketBuffer(packet)
-			go ExchangeDNSPacket(ctx, r, uConn, buffer, metadata, destination)
+			go ExchangeDNSPacket(ctx, r, natConn, buffer, metadata, destination)
 		}
-		uConn.SetHandler(&dnsHijacker{
+		natConn.SetHandler(&dnsHijacker{
 			router:   r,
 			conn:     conn,
 			ctx:      ctx,

+ 4 - 4
route/geo_resources.go

@@ -145,13 +145,13 @@ func (r *Router) downloadGeoIPDatabase(savePath string) error {
 	r.logger.Info("downloading geoip database")
 	var detour adapter.Outbound
 	if r.geoIPOptions.DownloadDetour != "" {
-		outbound, loaded := r.outboundManager.Outbound(r.geoIPOptions.DownloadDetour)
+		outbound, loaded := r.outbound.Outbound(r.geoIPOptions.DownloadDetour)
 		if !loaded {
 			return E.New("detour outbound not found: ", r.geoIPOptions.DownloadDetour)
 		}
 		detour = outbound
 	} else {
-		detour = r.outboundManager.Default()
+		detour = r.outbound.Default()
 	}
 
 	if parentDir := filepath.Dir(savePath); parentDir != "" {
@@ -200,13 +200,13 @@ func (r *Router) downloadGeositeDatabase(savePath string) error {
 	r.logger.Info("downloading geosite database")
 	var detour adapter.Outbound
 	if r.geositeOptions.DownloadDetour != "" {
-		outbound, loaded := r.outboundManager.Outbound(r.geositeOptions.DownloadDetour)
+		outbound, loaded := r.outbound.Outbound(r.geositeOptions.DownloadDetour)
 		if !loaded {
 			return E.New("detour outbound not found: ", r.geositeOptions.DownloadDetour)
 		}
 		detour = outbound
 	} else {
-		detour = r.outboundManager.Default()
+		detour = r.outbound.Default()
 	}
 
 	if parentDir := filepath.Dir(savePath); parentDir != "" {

+ 8 - 0
route/network.go

@@ -48,6 +48,7 @@ type NetworkManager struct {
 	powerListener     winpowrprof.EventListener
 	pauseManager      pause.Manager
 	platformInterface platform.Interface
+	inboundManager    adapter.InboundManager
 	outboundManager   adapter.OutboundManager
 	wifiState         adapter.WIFIState
 	started           bool
@@ -354,6 +355,13 @@ func (r *NetworkManager) WIFIState() adapter.WIFIState {
 func (r *NetworkManager) ResetNetwork() {
 	conntrack.Close()
 
+	for _, inbound := range r.inboundManager.Inbounds() {
+		listener, isListener := inbound.(adapter.InterfaceUpdateListener)
+		if isListener {
+			listener.InterfaceUpdated()
+		}
+	}
+
 	for _, outbound := range r.outboundManager.Outbounds() {
 		listener, isListener := outbound.(adapter.InterfaceUpdateListener)
 		if isListener {

+ 14 - 28
route/route.go

@@ -11,7 +11,6 @@ import (
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
-	"github.com/sagernet/sing-box/adapter/outbound"
 	"github.com/sagernet/sing-box/common/conntrack"
 	"github.com/sagernet/sing-box/common/process"
 	"github.com/sagernet/sing-box/common/sniff"
@@ -58,7 +57,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
 		if metadata.LastInbound == metadata.InboundDetour {
 			return E.New("routing loop on detour: ", metadata.InboundDetour)
 		}
-		detour, loaded := r.inboundManager.Get(metadata.InboundDetour)
+		detour, loaded := r.inbound.Get(metadata.InboundDetour)
 		if !loaded {
 			return E.New("inbound detour not found: ", metadata.InboundDetour)
 		}
@@ -96,7 +95,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
 		switch action := selectedRule.Action().(type) {
 		case *rule.RuleActionRoute:
 			var loaded bool
-			selectedOutbound, loaded = r.outboundManager.Outbound(action.Outbound)
+			selectedOutbound, loaded = r.outbound.Outbound(action.Outbound)
 			if !loaded {
 				buf.ReleaseMulti(buffers)
 				return E.New("outbound not found: ", action.Outbound)
@@ -118,7 +117,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
 		}
 	}
 	if selectedRule == nil {
-		defaultOutbound := r.outboundManager.Default()
+		defaultOutbound := r.outbound.Default()
 		if !common.Contains(defaultOutbound.Network(), N.NetworkTCP) {
 			buf.ReleaseMulti(buffers)
 			return E.New("TCP is not supported by default outbound: ", defaultOutbound.Tag())
@@ -148,19 +147,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
 		}
 		return nil
 	}
-	// TODO
-	err = outbound.NewConnection(ctx, selectedOutbound, conn, metadata)
-	if err != nil {
-		conn.Close()
-		if onClose != nil {
-			onClose(err)
-		}
-		return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]"))
-	} else {
-		if onClose != nil {
-			onClose(nil)
-		}
-	}
+	r.connection.NewConnection(ctx, selectedOutbound, conn, metadata, onClose)
 	return nil
 }
 
@@ -199,7 +186,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
 		if metadata.LastInbound == metadata.InboundDetour {
 			return E.New("routing loop on detour: ", metadata.InboundDetour)
 		}
-		detour, loaded := r.inboundManager.Get(metadata.InboundDetour)
+		detour, loaded := r.inbound.Get(metadata.InboundDetour)
 		if !loaded {
 			return E.New("inbound detour not found: ", metadata.InboundDetour)
 		}
@@ -233,7 +220,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
 		switch action := selectedRule.Action().(type) {
 		case *rule.RuleActionRoute:
 			var loaded bool
-			selectedOutbound, loaded = r.outboundManager.Outbound(action.Outbound)
+			selectedOutbound, loaded = r.outbound.Outbound(action.Outbound)
 			if !loaded {
 				N.ReleaseMultiPacketBuffer(packetBuffers)
 				return E.New("outbound not found: ", action.Outbound)
@@ -252,7 +239,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
 		}
 	}
 	if selectedRule == nil || selectReturn {
-		defaultOutbound := r.outboundManager.Default()
+		defaultOutbound := r.outbound.Default()
 		if !common.Contains(defaultOutbound.Network(), N.NetworkUDP) {
 			N.ReleaseMultiPacketBuffer(packetBuffers)
 			return E.New("UDP is not supported by outbound: ", defaultOutbound.Tag())
@@ -278,12 +265,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
 		}
 		return nil
 	}
-	// TODO
-	err = outbound.NewPacketConnection(ctx, selectedOutbound, conn, metadata)
-	N.CloseOnHandshakeFailure(conn, onClose, err)
-	if err != nil {
-		return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]"))
-	}
+	r.connection.NewPacketConnection(ctx, selectedOutbound, conn, metadata, onClose)
 	return nil
 }
 
@@ -450,8 +432,12 @@ match:
 			}
 			metadata.NetworkStrategy = routeOptions.NetworkStrategy
 			metadata.FallbackDelay = routeOptions.FallbackDelay
-			metadata.UDPDisableDomainUnmapping = routeOptions.UDPDisableDomainUnmapping
-			metadata.UDPConnect = routeOptions.UDPConnect
+			if routeOptions.UDPDisableDomainUnmapping {
+				metadata.UDPDisableDomainUnmapping = true
+			}
+			if routeOptions.UDPConnect {
+				metadata.UDPConnect = true
+			}
 		}
 		switch action := currentRule.Action().(type) {
 		case *rule.RuleActionSniff:

+ 11 - 9
route/router.go

@@ -38,9 +38,10 @@ type Router struct {
 	ctx                     context.Context
 	logger                  log.ContextLogger
 	dnsLogger               log.ContextLogger
-	inboundManager          adapter.InboundManager
-	outboundManager         adapter.OutboundManager
-	networkManager          adapter.NetworkManager
+	inbound                 adapter.InboundManager
+	outbound                adapter.OutboundManager
+	connection              adapter.ConnectionManager
+	network                 adapter.NetworkManager
 	rules                   []adapter.Rule
 	needGeoIPDatabase       bool
 	needGeositeDatabase     bool
@@ -74,9 +75,10 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
 		ctx:                   ctx,
 		logger:                logFactory.NewLogger("router"),
 		dnsLogger:             logFactory.NewLogger("dns"),
-		inboundManager:        service.FromContext[adapter.InboundManager](ctx),
-		outboundManager:       service.FromContext[adapter.OutboundManager](ctx),
-		networkManager:        service.FromContext[adapter.NetworkManager](ctx),
+		inbound:               service.FromContext[adapter.InboundManager](ctx),
+		outbound:              service.FromContext[adapter.OutboundManager](ctx),
+		connection:            service.FromContext[adapter.ConnectionManager](ctx),
+		network:               service.FromContext[adapter.NetworkManager](ctx),
 		rules:                 make([]adapter.Rule, 0, len(options.Rules)),
 		dnsRules:              make([]adapter.DNSRule, 0, len(dnsOptions.Rules)),
 		ruleSetMap:            make(map[string]adapter.RuleSet),
@@ -260,7 +262,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
 				Context: ctx,
 				Name:    "local",
 				Address: "local",
-				Dialer:  common.Must1(dialer.NewDefault(router.networkManager, option.DialerOptions{})),
+				Dialer:  common.Must1(dialer.NewDefault(router.network, option.DialerOptions{})),
 			})))
 		}
 		defaultTransport = transports[0]
@@ -405,7 +407,7 @@ func (r *Router) Start(stage adapter.StartStage) error {
 				monitor.Start("initialize process searcher")
 				searcher, err := process.NewSearcher(process.Config{
 					Logger:         r.logger,
-					PackageManager: r.networkManager.PackageManager(),
+					PackageManager: r.network.PackageManager(),
 				})
 				monitor.Finish()
 				if err != nil {
@@ -507,7 +509,7 @@ func (r *Router) SetTracker(tracker adapter.ConnectionTracker) {
 }
 
 func (r *Router) ResetNetwork() {
-	r.networkManager.ResetNetwork()
+	r.network.ResetNetwork()
 	for _, transport := range r.transports {
 		transport.Reset()
 	}