1
0
世界 1 жил өмнө
parent
commit
98ff897f35

+ 1 - 0
adapter/v2ray.go

@@ -22,4 +22,5 @@ type V2RayServerTransportHandler interface {
 
 type V2RayClientTransport interface {
 	DialContext(ctx context.Context) (net.Conn, error)
+	Close() error
 }

+ 2 - 2
outbound/hysteria.go

@@ -130,8 +130,8 @@ func (h *Hysteria) NewPacketConnection(ctx context.Context, conn N.PacketConn, m
 	return NewPacketConnection(ctx, h, conn, metadata)
 }
 
-func (h *Hysteria) InterfaceUpdated() error {
-	return h.client.CloseWithError(E.New("network changed"))
+func (h *Hysteria) InterfaceUpdated() {
+	h.client.CloseWithError(E.New("network changed"))
 }
 
 func (h *Hysteria) Close() error {

+ 2 - 2
outbound/hysteria2.go

@@ -116,8 +116,8 @@ func (h *Hysteria2) NewPacketConnection(ctx context.Context, conn N.PacketConn,
 	return NewPacketConnection(ctx, h, conn, metadata)
 }
 
-func (h *Hysteria2) InterfaceUpdated() error {
-	return h.client.CloseWithError(E.New("network changed"))
+func (h *Hysteria2) InterfaceUpdated() {
+	h.client.CloseWithError(E.New("network changed"))
 }
 
 func (h *Hysteria2) Close() error {

+ 3 - 0
outbound/trojan.go

@@ -108,6 +108,9 @@ func (h *Trojan) NewPacketConnection(ctx context.Context, conn N.PacketConn, met
 }
 
 func (h *Trojan) InterfaceUpdated() {
+	if h.transport != nil {
+		h.transport.Close()
+	}
 	if h.multiplexDialer != nil {
 		h.multiplexDialer.Reset()
 	}

+ 3 - 0
outbound/vless.go

@@ -127,6 +127,9 @@ func (h *VLESS) NewPacketConnection(ctx context.Context, conn N.PacketConn, meta
 }
 
 func (h *VLESS) InterfaceUpdated() {
+	if h.transport != nil {
+		h.transport.Close()
+	}
 	if h.multiplexDialer != nil {
 		h.multiplexDialer.Reset()
 	}

+ 3 - 0
outbound/vmess.go

@@ -103,6 +103,9 @@ func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogg
 }
 
 func (h *VMess) InterfaceUpdated() {
+	if h.transport != nil {
+		h.transport.Close()
+	}
 	if h.multiplexDialer != nil {
 		h.multiplexDialer.Reset()
 	}

+ 10 - 6
transport/v2raygrpc/client.go

@@ -72,12 +72,6 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 	}, nil
 }
 
-func (c *Client) Close() error {
-	return common.Close(
-		common.PtrOrNil(c.conn),
-	)
-}
-
 func (c *Client) connect() (*grpc.ClientConn, error) {
 	conn := c.conn
 	if conn != nil && conn.GetState() != connectivity.Shutdown {
@@ -113,3 +107,13 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
 	}
 	return NewGRPCConn(stream, cancel), nil
 }
+
+func (c *Client) Close() error {
+	c.connAccess.Lock()
+	defer c.connAccess.Unlock()
+	if c.conn != nil {
+		c.conn.Close()
+		c.conn = nil
+	}
+	return nil
+}

+ 1 - 3
transport/v2raygrpclite/client.go

@@ -109,8 +109,6 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
 }
 
 func (c *Client) Close() error {
-	if c.transport != nil {
-		v2rayhttp.CloseIdleConnections(c.transport)
-	}
+	v2rayhttp.ResetTransport(c.transport)
 	return nil
 }

+ 1 - 1
transport/v2rayhttp/client.go

@@ -155,6 +155,6 @@ func (c *Client) dialHTTP2(ctx context.Context) (net.Conn, error) {
 }
 
 func (c *Client) Close() error {
-	CloseIdleConnections(c.transport)
+	c.transport = ResetTransport(c.transport)
 	return nil
 }

+ 47 - 0
transport/v2rayhttp/force_close.go

@@ -0,0 +1,47 @@
+package v2rayhttp
+
+import (
+	"net/http"
+	"reflect"
+	"sync"
+	"unsafe"
+
+	E "github.com/sagernet/sing/common/exceptions"
+
+	"golang.org/x/net/http2"
+)
+
+type clientConnPool struct {
+	t     *http2.Transport
+	mu    sync.Mutex
+	conns map[string][]*http2.ClientConn // key is host:port
+}
+
+type efaceWords struct {
+	typ  unsafe.Pointer
+	data unsafe.Pointer
+}
+
+func ResetTransport(rawTransport http.RoundTripper) http.RoundTripper {
+	switch transport := rawTransport.(type) {
+	case *http.Transport:
+		transport.CloseIdleConnections()
+		return transport.Clone()
+	case *http2.Transport:
+		connPool := transportConnPool(transport)
+		p := (*clientConnPool)((*efaceWords)(unsafe.Pointer(&connPool)).data)
+		p.mu.Lock()
+		defer p.mu.Unlock()
+		for _, vv := range p.conns {
+			for _, cc := range vv {
+				cc.Close()
+			}
+		}
+		return transport
+	default:
+		panic(E.New("unknown transport type: ", reflect.TypeOf(transport)))
+	}
+}
+
+//go:linkname transportConnPool golang.org/x/net/http2.(*Transport).connPool
+func transportConnPool(t *http2.Transport) http2.ClientConnPool

+ 4 - 0
transport/v2rayhttpupgrade/client.go

@@ -116,3 +116,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
 	}
 	return conn, nil
 }
+
+func (c *Client) Close() error {
+	return nil
+}

+ 11 - 1
transport/v2rayquic/client.go

@@ -97,5 +97,15 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
 }
 
 func (c *Client) Close() error {
-	return common.Close(c.conn, c.rawConn)
+	c.connAccess.Lock()
+	defer c.connAccess.Unlock()
+	if c.conn != nil {
+		c.conn.CloseWithError(0, "")
+	}
+	if c.rawConn != nil {
+		c.rawConn.Close()
+	}
+	c.conn = nil
+	c.rawConn = nil
+	return nil
 }

+ 4 - 0
transport/v2raywebsocket/client.go

@@ -127,3 +127,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
 		return &EarlyWebsocketConn{Client: c, ctx: ctx, create: make(chan struct{})}, nil
 	}
 }
+
+func (c *Client) Close() error {
+	return nil
+}