浏览代码

Fix reset outbound

世界 2 年之前
父节点
当前提交
1f63ce5dee
共有 6 个文件被更改,包括 63 次插入17 次删除
  1. 4 0
      adapter/router.go
  2. 9 1
      outbound/hysteria.go
  3. 9 1
      outbound/ssh.go
  4. 9 1
      outbound/wireguard.go
  5. 26 14
      route/router.go
  6. 6 0
      transport/wireguard/client_bind.go

+ 4 - 0
adapter/router.go

@@ -60,3 +60,7 @@ type DNSRule interface {
 	Rule
 	DisableCache() bool
 }
+
+type InterfaceUpdateListener interface {
+	InterfaceUpdated() error
+}

+ 9 - 1
outbound/hysteria.go

@@ -23,7 +23,10 @@ import (
 	N "github.com/sagernet/sing/common/network"
 )
 
-var _ adapter.Outbound = (*Hysteria)(nil)
+var (
+	_ adapter.Outbound                = (*Hysteria)(nil)
+	_ adapter.InterfaceUpdateListener = (*Hysteria)(nil)
+)
 
 type Hysteria struct {
 	myOutboundAdapter
@@ -236,6 +239,11 @@ func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
 	}
 }
 
+func (h *Hysteria) InterfaceUpdated() error {
+	h.Close()
+	return nil
+}
+
 func (h *Hysteria) Close() error {
 	h.connAccess.Lock()
 	defer h.connAccess.Unlock()

+ 9 - 1
outbound/ssh.go

@@ -21,7 +21,10 @@ import (
 	"golang.org/x/crypto/ssh"
 )
 
-var _ adapter.Outbound = (*SSH)(nil)
+var (
+	_ adapter.Outbound                = (*SSH)(nil)
+	_ adapter.InterfaceUpdateListener = (*SSH)(nil)
+)
 
 type SSH struct {
 	myOutboundAdapter
@@ -149,6 +152,11 @@ func (s *SSH) connect() (*ssh.Client, error) {
 	return client, nil
 }
 
+func (s *SSH) InterfaceUpdated() error {
+	common.Close(s.clientConn)
+	return nil
+}
+
 func (s *SSH) Close() error {
 	return common.Close(s.clientConn)
 }

+ 9 - 1
outbound/wireguard.go

@@ -26,7 +26,10 @@ import (
 	"golang.zx2c4.com/wireguard/device"
 )
 
-var _ adapter.Outbound = (*WireGuard)(nil)
+var (
+	_ adapter.Outbound                = (*WireGuard)(nil)
+	_ adapter.InterfaceUpdateListener = (*WireGuard)(nil)
+)
 
 type WireGuard struct {
 	myOutboundAdapter
@@ -134,6 +137,11 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
 	return outbound, nil
 }
 
+func (w *WireGuard) InterfaceUpdated() error {
+	w.bind.Reset()
+	return nil
+}
+
 func (w *WireGuard) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
 	switch network {
 	case N.NetworkTCP:

+ 26 - 14
route/router.go

@@ -262,20 +262,7 @@ func NewRouter(ctx context.Context, logger log.ContextLogger, dnsLogger log.Cont
 		if err != nil {
 			return nil, E.New("auto_detect_interface unsupported on current platform")
 		}
-		interfaceMonitor.RegisterCallback(func(event int) error {
-			if C.IsAndroid {
-				var vpnStatus string
-				if router.interfaceMonitor.AndroidVPNEnabled() {
-					vpnStatus = "enabled"
-				} else {
-					vpnStatus = "disabled"
-				}
-				router.logger.Info("updated default interface ", router.interfaceMonitor.DefaultInterfaceName(netip.IPv4Unspecified()), ", index ", router.interfaceMonitor.DefaultInterfaceIndex(netip.IPv4Unspecified()), ", vpn ", vpnStatus)
-			} else {
-				router.logger.Info("updated default interface ", router.interfaceMonitor.DefaultInterfaceName(netip.IPv4Unspecified()), ", index ", router.interfaceMonitor.DefaultInterfaceIndex(netip.IPv4Unspecified()))
-			}
-			return nil
-		})
+		interfaceMonitor.RegisterCallback(router.notifyNetworkUpdate)
 		router.interfaceMonitor = interfaceMonitor
 	}
 
@@ -1014,3 +1001,28 @@ func (r *Router) NewError(ctx context.Context, err error) {
 	}
 	r.logger.ErrorContext(ctx, err)
 }
+
+func (r *Router) notifyNetworkUpdate(int) error {
+	if C.IsAndroid {
+		var vpnStatus string
+		if r.interfaceMonitor.AndroidVPNEnabled() {
+			vpnStatus = "enabled"
+		} else {
+			vpnStatus = "disabled"
+		}
+		r.logger.Info("updated default interface ", r.interfaceMonitor.DefaultInterfaceName(netip.IPv4Unspecified()), ", index ", r.interfaceMonitor.DefaultInterfaceIndex(netip.IPv4Unspecified()), ", vpn ", vpnStatus)
+	} else {
+		r.logger.Info("updated default interface ", r.interfaceMonitor.DefaultInterfaceName(netip.IPv4Unspecified()), ", index ", r.interfaceMonitor.DefaultInterfaceIndex(netip.IPv4Unspecified()))
+	}
+
+	for _, outbound := range r.outbounds {
+		listener, isListener := outbound.(adapter.InterfaceUpdateListener)
+		if isListener {
+			err := listener.InterfaceUpdated()
+			if err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}

+ 6 - 0
transport/wireguard/client_bind.go

@@ -100,6 +100,12 @@ func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
 	return
 }
 
+func (c *ClientBind) Reset() {
+	c.connAccess.Lock()
+	defer c.connAccess.Unlock()
+	common.Close(common.PtrOrNil(c.conn))
+}
+
 func (c *ClientBind) Close() error {
 	c.connAccess.Lock()
 	defer c.connAccess.Unlock()