浏览代码

Fix wireguard client bind

世界 1 年之前
父节点
当前提交
e08c052fc9
共有 1 个文件被更改,包括 23 次插入3 次删除
  1. 23 3
      transport/wireguard/client_bind.go

+ 23 - 3
transport/wireguard/client_bind.go

@@ -12,6 +12,8 @@ import (
 	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
+	"github.com/sagernet/sing/service"
+	"github.com/sagernet/sing/service/pause"
 	"github.com/sagernet/wireguard-go/conn"
 )
 
@@ -19,6 +21,9 @@ var _ conn.Bind = (*ClientBind)(nil)
 
 type ClientBind struct {
 	ctx                 context.Context
+	pauseManager        pause.Manager
+	bindCtx             context.Context
+	bindDone            context.CancelFunc
 	errorHandler        E.Handler
 	dialer              N.Dialer
 	reservedForEndpoint map[netip.AddrPort][3]uint8
@@ -33,6 +38,7 @@ type ClientBind struct {
 func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind {
 	return &ClientBind{
 		ctx:                 ctx,
+		pauseManager:        service.FromContext[pause.Manager](ctx),
 		errorHandler:        errorHandler,
 		dialer:              dialer,
 		reservedForEndpoint: make(map[netip.AddrPort][3]uint8),
@@ -55,6 +61,11 @@ func (c *ClientBind) connect() (*wireConn, error) {
 	}
 	c.connAccess.Lock()
 	defer c.connAccess.Unlock()
+	select {
+	case <-c.done:
+		return nil, net.ErrClosed
+	default:
+	}
 	serverConn = c.conn
 	if serverConn != nil {
 		select {
@@ -65,7 +76,7 @@ func (c *ClientBind) connect() (*wireConn, error) {
 		}
 	}
 	if c.isConnect {
-		udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr))
+		udpConn, err := c.dialer.DialContext(c.bindCtx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr))
 		if err != nil {
 			return nil, err
 		}
@@ -74,7 +85,7 @@ func (c *ClientBind) connect() (*wireConn, error) {
 			done:       make(chan struct{}),
 		}
 	} else {
-		udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
+		udpConn, err := c.dialer.ListenPacket(c.bindCtx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
 		if err != nil {
 			return nil, err
 		}
@@ -92,6 +103,7 @@ func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint1
 		c.done = make(chan struct{})
 	default:
 	}
+	c.bindCtx, c.bindDone = context.WithCancel(c.ctx)
 	return []conn.ReceiveFunc{c.receive}, 0, nil
 }
 
@@ -105,6 +117,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
 		}
 		c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server"))
 		err = nil
+		c.pauseManager.WaitActive()
 		time.Sleep(time.Second)
 		return
 	}
@@ -130,12 +143,17 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
 }
 
 func (c *ClientBind) Close() error {
-	common.Close(common.PtrOrNil(c.conn))
 	select {
 	case <-c.done:
 	default:
 		close(c.done)
 	}
+	if c.bindDone != nil {
+		c.bindDone()
+	}
+	c.connAccess.Lock()
+	defer c.connAccess.Unlock()
+	common.Close(common.PtrOrNil(c.conn))
 	return nil
 }
 
@@ -146,6 +164,8 @@ func (c *ClientBind) SetMark(mark uint32) error {
 func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
 	udpConn, err := c.connect()
 	if err != nil {
+		c.pauseManager.WaitActive()
+		time.Sleep(time.Second)
 		return err
 	}
 	destination := netip.AddrPort(ep.(Endpoint))