浏览代码

Fix wireguard reconnect

世界 2 年之前
父节点
当前提交
1fbe7c54bf
共有 4 个文件被更改,包括 36 次插入42 次删除
  1. 13 0
      outbound/default.go
  2. 1 1
      outbound/wireguard.go
  3. 22 19
      transport/wireguard/client_bind.go
  4. 0 22
      transport/wireguard/error.go

+ 13 - 0
outbound/default.go

@@ -39,6 +39,10 @@ func (a *myOutboundAdapter) Network() []string {
 	return a.network
 }
 
+func (a *myOutboundAdapter) NewError(ctx context.Context, err error) {
+	NewError(a.logger, ctx, err)
+}
+
 func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error {
 	ctx = adapter.WithContext(ctx, &metadata)
 	var outConn net.Conn
@@ -121,3 +125,12 @@ func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) erro
 	}
 	return bufio.CopyConn(ctx, conn, serverConn)
 }
+
+func NewError(logger log.ContextLogger, ctx context.Context, err error) {
+	common.Close(err)
+	if E.IsClosedOrCanceled(err) {
+		logger.DebugContext(ctx, "connection closed: ", err)
+		return
+	}
+	logger.ErrorContext(ctx, err)
+}

+ 1 - 1
outbound/wireguard.go

@@ -64,7 +64,7 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
 			connectAddr = options.ServerOptions.Build()
 		}
 	}
-	outbound.bind = wireguard.NewClientBind(ctx, dialer.New(router, options.DialerOptions), isConnect, connectAddr, reserved)
+	outbound.bind = wireguard.NewClientBind(ctx, outbound, dialer.New(router, options.DialerOptions), isConnect, connectAddr, reserved)
 	localPrefixes := common.Map(options.LocalAddress, option.ListenPrefix.Build)
 	if len(localPrefixes) == 0 {
 		return nil, E.New("missing local address")

+ 22 - 19
transport/wireguard/client_bind.go

@@ -7,8 +7,8 @@ import (
 	"sync"
 
 	"github.com/sagernet/sing/common"
-	"github.com/sagernet/sing/common/buf"
 	"github.com/sagernet/sing/common/bufio"
+	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 	"github.com/sagernet/wireguard-go/conn"
@@ -18,6 +18,7 @@ var _ conn.Bind = (*ClientBind)(nil)
 
 type ClientBind struct {
 	ctx                 context.Context
+	errorHandler        E.Handler
 	dialer              N.Dialer
 	reservedForEndpoint map[M.Socksaddr][3]uint8
 	connAccess          sync.Mutex
@@ -28,9 +29,10 @@ type ClientBind struct {
 	reserved            [3]uint8
 }
 
-func NewClientBind(ctx context.Context, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
+func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
 	return &ClientBind{
 		ctx:                 ctx,
+		errorHandler:        errorHandler,
 		dialer:              dialer,
 		reservedForEndpoint: make(map[M.Socksaddr][3]uint8),
 		isConnect:           isConnect,
@@ -67,10 +69,10 @@ func (c *ClientBind) connect() (*wireConn, error) {
 	if c.isConnect {
 		udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, c.connectAddr)
 		if err != nil {
-			return nil, &wireError{err}
+			return nil, err
 		}
 		c.conn = &wireConn{
-			NetPacketConn: &bufio.UnbindPacketConn{
+			PacketConn: &bufio.UnbindPacketConn{
 				ExtendedConn: bufio.NewExtendedConn(udpConn),
 				Addr:         c.connectAddr,
 			},
@@ -79,11 +81,11 @@ func (c *ClientBind) connect() (*wireConn, error) {
 	} else {
 		udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
 		if err != nil {
-			return nil, &wireError{err}
+			return nil, err
 		}
 		c.conn = &wireConn{
-			NetPacketConn: bufio.NewPacketConn(udpConn),
-			done:          make(chan struct{}),
+			PacketConn: bufio.NewPacketConn(udpConn),
+			done:       make(chan struct{}),
 		}
 	}
 	return c.conn, nil
@@ -102,30 +104,31 @@ func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint1
 func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
 	udpConn, err := c.connect()
 	if err != nil {
-		err = &wireError{err}
+		select {
+		case <-c.done:
+			return
+		default:
+		}
+		c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server"))
+		err = nil
 		return
 	}
-	buffer := buf.With(b)
-	destination, err := udpConn.ReadPacket(buffer)
+	n, addr, err := udpConn.ReadFrom(b)
 	if err != nil {
 		udpConn.Close()
 		select {
 		case <-c.done:
 		default:
-			err = &wireError{err}
+			c.errorHandler.NewError(context.Background(), E.Cause(err, "read packet"))
 		}
 		return
 	}
-	n = buffer.Len()
-	if buffer.Start() > 0 {
-		copy(b, buffer.Bytes())
-	}
 	if n > 3 {
 		b[1] = 0
 		b[2] = 0
 		b[3] = 0
 	}
-	ep = Endpoint(destination)
+	ep = Endpoint(M.SocksaddrFromNet(addr))
 	return
 }
 
@@ -167,7 +170,7 @@ func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
 		b[2] = reserved[1]
 		b[3] = reserved[2]
 	}
-	err = udpConn.WritePacket(buf.As(b), destination)
+	_, err = udpConn.WriteTo(b, destination)
 	if err != nil {
 		udpConn.Close()
 	}
@@ -179,7 +182,7 @@ func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
 }
 
 type wireConn struct {
-	N.NetPacketConn
+	net.PacketConn
 	access sync.Mutex
 	done   chan struct{}
 }
@@ -192,7 +195,7 @@ func (w *wireConn) Close() error {
 		return net.ErrClosed
 	default:
 	}
-	w.NetPacketConn.Close()
+	w.PacketConn.Close()
 	close(w.done)
 	return nil
 }

+ 0 - 22
transport/wireguard/error.go

@@ -1,22 +0,0 @@
-package wireguard
-
-import "net"
-
-type wireError struct {
-	cause error
-}
-
-func (w *wireError) Error() string {
-	return w.cause.Error()
-}
-
-func (w *wireError) Timeout() bool {
-	if cause, causeNet := w.cause.(net.Error); causeNet {
-		return cause.Timeout()
-	}
-	return false
-}
-
-func (w *wireError) Temporary() bool {
-	return true
-}