浏览代码

wireguard: Fix events chan leak

世界 1 年之前
父节点
当前提交
9bac18bcd1
共有 3 个文件被更改,包括 54 次插入52 次删除
  1. 5 2
      outbound/wireguard.go
  2. 2 6
      transport/wireguard/device_stack.go
  3. 47 44
      transport/wireguard/device_system.go

+ 5 - 2
outbound/wireguard.go

@@ -152,6 +152,10 @@ func (w *WireGuard) start() error {
 		}
 		bind = wireguard.NewClientBind(w.ctx, w, w.listener, isConnect, connectAddr, reserved)
 	}
+	err = w.tunDevice.Start()
+	if err != nil {
+		return err
+	}
 	wgDevice := device.NewDevice(w.tunDevice, bind, &device.Logger{
 		Verbosef: func(format string, args ...interface{}) {
 			w.logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
@@ -170,7 +174,7 @@ func (w *WireGuard) start() error {
 	}
 	w.device = wgDevice
 	w.pauseCallback = w.pauseManager.RegisterCallback(w.onPauseUpdated)
-	return w.tunDevice.Start()
+	return nil
 }
 
 func (w *WireGuard) Close() error {
@@ -180,7 +184,6 @@ func (w *WireGuard) Close() error {
 	if w.pauseCallback != nil {
 		w.pauseManager.UnregisterCallback(w.pauseCallback)
 	}
-	w.tunDevice.Close()
 	return nil
 }
 

+ 2 - 6
transport/wireguard/device_stack.go

@@ -230,17 +230,13 @@ func (w *StackDevice) Events() <-chan wgTun.Event {
 }
 
 func (w *StackDevice) Close() error {
-	select {
-	case <-w.done:
-		return os.ErrClosed
-	default:
-	}
+	close(w.done)
+	close(w.events)
 	w.stack.Close()
 	for _, endpoint := range w.stack.CleanupEndpoints() {
 		endpoint.Abort()
 	}
 	w.stack.Wait()
-	close(w.done)
 	return nil
 }
 

+ 47 - 44
transport/wireguard/device_system.go

@@ -6,6 +6,7 @@ import (
 	"net"
 	"net/netip"
 	"os"
+	"sync"
 
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/common/dialer"
@@ -21,14 +22,16 @@ import (
 var _ Device = (*SystemDevice)(nil)
 
 type SystemDevice struct {
-	dialer      N.Dialer
-	device      tun.Tun
-	batchDevice tun.LinuxTUN
-	name        string
-	mtu         int
-	events      chan wgTun.Event
-	addr4       netip.Addr
-	addr6       netip.Addr
+	dialer         N.Dialer
+	device         tun.Tun
+	batchDevice    tun.LinuxTUN
+	name           string
+	mtu            uint32
+	inet4Addresses []netip.Prefix
+	inet6Addresses []netip.Prefix
+	gso            bool
+	events         chan wgTun.Event
+	closeOnce      sync.Once
 }
 
 func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes []netip.Prefix, mtu uint32, gso bool) (*SystemDevice, error) {
@@ -44,43 +47,17 @@ func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes
 	if interfaceName == "" {
 		interfaceName = tun.CalculateInterfaceName("wg")
 	}
-	tunInterface, err := tun.New(tun.Options{
-		Name:         interfaceName,
-		Inet4Address: inet4Addresses,
-		Inet6Address: inet6Addresses,
-		MTU:          mtu,
-		GSO:          gso,
-	})
-	if err != nil {
-		return nil, err
-	}
-	var inet4Address netip.Addr
-	var inet6Address netip.Addr
-	if len(inet4Addresses) > 0 {
-		inet4Address = inet4Addresses[0].Addr()
-	}
-	if len(inet6Addresses) > 0 {
-		inet6Address = inet6Addresses[0].Addr()
-	}
-	var batchDevice tun.LinuxTUN
-	if gso {
-		batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN)
-		if !isBatchTUN {
-			return nil, E.New("GSO is not supported on current platform")
-		}
-		batchDevice = batchTUN
-	}
+
 	return &SystemDevice{
 		dialer: common.Must1(dialer.NewDefault(router, option.DialerOptions{
 			BindInterface: interfaceName,
 		})),
-		device:      tunInterface,
-		batchDevice: batchDevice,
-		name:        interfaceName,
-		mtu:         int(mtu),
-		events:      make(chan wgTun.Event),
-		addr4:       inet4Address,
-		addr6:       inet6Address,
+		name:           interfaceName,
+		mtu:            mtu,
+		inet4Addresses: inet4Addresses,
+		inet6Addresses: inet6Addresses,
+		gso:            gso,
+		events:         make(chan wgTun.Event),
 	}, nil
 }
 
@@ -93,14 +70,39 @@ func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr
 }
 
 func (w *SystemDevice) Inet4Address() netip.Addr {
-	return w.addr4
+	if len(w.inet4Addresses) == 0 {
+		return netip.Addr{}
+	}
+	return w.inet4Addresses[0].Addr()
 }
 
 func (w *SystemDevice) Inet6Address() netip.Addr {
-	return w.addr6
+	if len(w.inet6Addresses) == 0 {
+		return netip.Addr{}
+	}
+	return w.inet6Addresses[0].Addr()
 }
 
 func (w *SystemDevice) Start() error {
+	tunInterface, err := tun.New(tun.Options{
+		Name:         w.name,
+		Inet4Address: w.inet4Addresses,
+		Inet6Address: w.inet6Addresses,
+		MTU:          w.mtu,
+		GSO:          w.gso,
+	})
+	if err != nil {
+		return err
+	}
+	w.device = tunInterface
+	if w.gso {
+		batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN)
+		if !isBatchTUN {
+			tunInterface.Close()
+			return E.New("GSO is not supported on current platform")
+		}
+		w.batchDevice = batchTUN
+	}
 	w.events <- wgTun.EventUp
 	return nil
 }
@@ -143,7 +145,7 @@ func (w *SystemDevice) Flush() error {
 }
 
 func (w *SystemDevice) MTU() (int, error) {
-	return w.mtu, nil
+	return int(w.mtu), nil
 }
 
 func (w *SystemDevice) Name() (string, error) {
@@ -155,6 +157,7 @@ func (w *SystemDevice) Events() <-chan wgTun.Event {
 }
 
 func (w *SystemDevice) Close() error {
+	close(w.events)
 	return w.device.Close()
 }