Browse Source

Fixing tcp connestions leak

- always use HandshakeContext instead of Handshake

- pickup dailer dropped ctx

- rename HandshakeContextAddress to HandshakeAddressContext
deorth-kku 1 year ago
parent
commit
cae94570df

+ 4 - 4
proxy/dokodemo/dokodemo.go

@@ -71,8 +71,8 @@ func (d *DokodemoDoor) policy() policy.Session {
 	return p
 }
 
-type hasHandshakeAddress interface {
-	HandshakeAddress() net.Address
+type hasHandshakeAddressContext interface {
+	HandshakeAddressContext(ctx context.Context) net.Address
 }
 
 // Process implements proxy.Inbound.
@@ -89,8 +89,8 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
 		if outbound := session.OutboundFromContext(ctx); outbound != nil && outbound.Target.IsValid() {
 			dest = outbound.Target
 			destinationOverridden = true
-		} else if handshake, ok := conn.(hasHandshakeAddress); ok {
-			addr := handshake.HandshakeAddress()
+		} else if handshake, ok := conn.(hasHandshakeAddressContext); ok {
+			addr := handshake.HandshakeAddressContext(ctx)
 			if addr != nil {
 				dest.Address = addr
 				destinationOverridden = true

+ 1 - 1
proxy/http/client.go

@@ -308,7 +308,7 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u
 
 	nextProto := ""
 	if tlsConn, ok := iConn.(*tls.Conn); ok {
-		if err := tlsConn.Handshake(); err != nil {
+		if err := tlsConn.HandshakeContext(ctx); err != nil {
 			rawConn.Close()
 			return nil, err
 		}

+ 1 - 1
transport/internet/http/dialer.go

@@ -87,7 +87,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
 			} else {
 				cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
 			}
-			if err := cn.Handshake(); err != nil {
+			if err := cn.HandshakeContext(ctx); err != nil {
 				newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
 				return nil, err
 			}

+ 1 - 1
transport/internet/tcp/dialer.go

@@ -24,7 +24,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 		tlsConfig := config.GetTLSConfig(tls.WithDestination(dest))
 		if fingerprint := tls.GetFingerprint(config.Fingerprint); fingerprint != nil {
 			conn = tls.UClient(conn, tlsConfig, fingerprint)
-			if err := conn.(*tls.UConn).Handshake(); err != nil {
+			if err := conn.(*tls.UConn).HandshakeContext(ctx); err != nil {
 				return nil, err
 			}
 		} else {

+ 1 - 1
transport/internet/tls/grpc.go

@@ -65,7 +65,7 @@ func (c *grpcUtls) ClientHandshake(ctx context.Context, authority string, rawCon
 	conn := UClient(rawConn, cfg, c.fingerprint).(*UConn)
 	errChannel := make(chan error, 1)
 	go func() {
-		errChannel <- conn.Handshake()
+		errChannel <- conn.HandshakeContext(ctx)
 		close(errChannel)
 	}()
 	select {

+ 28 - 8
transport/internet/tls/tls.go

@@ -1,9 +1,11 @@
 package tls
 
 import (
+	"context"
 	"crypto/rand"
 	"crypto/tls"
 	"math/big"
+	"time"
 
 	utls "github.com/refraction-networking/utls"
 	"github.com/xtls/xray-core/common/buf"
@@ -14,7 +16,7 @@ import (
 
 type Interface interface {
 	net.Conn
-	Handshake() error
+	HandshakeContext(ctx context.Context) error
 	VerifyHostname(host string) error
 	NegotiatedProtocol() (name string, mutual bool)
 }
@@ -25,6 +27,16 @@ type Conn struct {
 	*tls.Conn
 }
 
+const tlsCloseTimeout = 250 * time.Millisecond
+
+func (c *Conn) Close() error {
+	timer := time.AfterFunc(tlsCloseTimeout, func() {
+		c.Conn.NetConn().Close()
+	})
+	defer timer.Stop()
+	return c.Conn.Close()
+}
+
 func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error {
 	mb = buf.Compact(mb)
 	mb, err := buf.WriteMultiBuffer(c, mb)
@@ -32,8 +44,8 @@ func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error {
 	return err
 }
 
-func (c *Conn) HandshakeAddress() net.Address {
-	if err := c.Handshake(); err != nil {
+func (c *Conn) HandshakeAddressContext(ctx context.Context) net.Address {
+	if err := c.HandshakeContext(ctx); err != nil {
 		return nil
 	}
 	state := c.ConnectionState()
@@ -64,8 +76,16 @@ type UConn struct {
 	*utls.UConn
 }
 
-func (c *UConn) HandshakeAddress() net.Address {
-	if err := c.Handshake(); err != nil {
+func (c *UConn) Close() error {
+	timer := time.AfterFunc(tlsCloseTimeout, func() {
+		c.Conn.NetConn().Close()
+	})
+	defer timer.Stop()
+	return c.Conn.Close()
+}
+
+func (c *UConn) HandshakeAddressContext(ctx context.Context) net.Address {
+	if err := c.HandshakeContext(ctx); err != nil {
 		return nil
 	}
 	state := c.ConnectionState()
@@ -77,7 +97,7 @@ func (c *UConn) HandshakeAddress() net.Address {
 
 // WebsocketHandshake basically calls UConn.Handshake inside it but it will only send
 // http/1.1 in its ALPN.
-func (c *UConn) WebsocketHandshake() error {
+func (c *UConn) WebsocketHandshakeContext(ctx context.Context) error {
 	// Build the handshake state. This will apply every variable of the TLS of the
 	// fingerprint in the UConn
 	if err := c.BuildHandshakeState(); err != nil {
@@ -99,7 +119,7 @@ func (c *UConn) WebsocketHandshake() error {
 	if err := c.BuildHandshakeState(); err != nil {
 		return err
 	}
-	return c.Handshake()
+	return c.HandshakeContext(ctx)
 }
 
 func (c *UConn) NegotiatedProtocol() (name string, mutual bool) {
@@ -118,7 +138,7 @@ func copyConfig(c *tls.Config) *utls.Config {
 		ServerName:            c.ServerName,
 		InsecureSkipVerify:    c.InsecureSkipVerify,
 		VerifyPeerCertificate: c.VerifyPeerCertificate,
-		KeyLogWriter:	       c.KeyLogWriter,
+		KeyLogWriter:          c.KeyLogWriter,
 	}
 }
 

+ 2 - 2
transport/internet/websocket/dialer.go

@@ -96,7 +96,7 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
 				}
 				// TLS and apply the handshake
 				cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
-				if err := cn.WebsocketHandshake(); err != nil {
+				if err := cn.WebsocketHandshakeContext(ctx); err != nil {
 					newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
 					return nil, err
 				}
@@ -147,7 +147,7 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
 		header.Set("Sec-WebSocket-Protocol", base64.RawURLEncoding.EncodeToString(ed))
 	}
 
-	conn, resp, err := dialer.Dial(uri, header)
+	conn, resp, err := dialer.DialContext(ctx, uri, header)
 	if err != nil {
 		var reason string
 		if resp != nil {