浏览代码

Dialer: Set TimeoutOnly for `gctx` and `hctx`

https://github.com/XTLS/Xray-core/issues/2232#issuecomment-1694570914

Thank @cty123 for testing

Fixes https://github.com/XTLS/Xray-core/issues/2232

BTW: Use `uConn.HandshakeContext(ctx)` in REALITY
RPRX 2 年之前
父节点
当前提交
d92002ad12
共有 3 个文件被更改,包括 14 次插入12 次删除
  1. 7 5
      transport/internet/grpc/dial.go
  2. 6 6
      transport/internet/http/dialer.go
  3. 1 1
      transport/internet/reality/reality.go

+ 7 - 5
transport/internet/grpc/dial.go

@@ -98,16 +98,13 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in
 			MinConnectTimeout: 5 * time.Second,
 		}),
 		grpc.WithContextDialer(func(gctx context.Context, s string) (gonet.Conn, error) {
-			gctx = session.ContextWithID(gctx, session.IDFromContext(ctx))
-			gctx = session.ContextWithOutbound(gctx, session.OutboundFromContext(ctx))
-
-			rawHost, rawPort, err := net.SplitHostPort(s)
 			select {
 			case <-gctx.Done():
 				return nil, gctx.Err()
 			default:
 			}
 
+			rawHost, rawPort, err := net.SplitHostPort(s)
 			if err != nil {
 				return nil, err
 			}
@@ -119,9 +116,14 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in
 				return nil, err
 			}
 			address := net.ParseAddress(rawHost)
+
+			gctx = session.ContextWithID(gctx, session.IDFromContext(ctx))
+			gctx = session.ContextWithOutbound(gctx, session.OutboundFromContext(ctx))
+			gctx = session.ContextWithTimeoutOnly(gctx, true)
+
 			c, err := internet.DialSystem(gctx, net.TCPDestination(address, port), sockopt)
 			if err == nil && realityConfig != nil {
-				return reality.UClient(c, realityConfig, ctx, dest)
+				return reality.UClient(c, realityConfig, gctx, dest)
 			}
 			return c, err
 		}),

+ 6 - 6
transport/internet/http/dialer.go

@@ -53,7 +53,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
 	}
 
 	transport := &http2.Transport{
-		DialTLS: func(network string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
+		DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
 			rawHost, rawPort, err := net.SplitHostPort(addr)
 			if err != nil {
 				return nil, err
@@ -67,18 +67,18 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
 			}
 			address := net.ParseAddress(rawHost)
 
-			dctx := context.Background()
-			dctx = session.ContextWithID(dctx, session.IDFromContext(ctx))
-			dctx = session.ContextWithOutbound(dctx, session.OutboundFromContext(ctx))
+			hctx = session.ContextWithID(hctx, session.IDFromContext(ctx))
+			hctx = session.ContextWithOutbound(hctx, session.OutboundFromContext(ctx))
+			hctx = session.ContextWithTimeoutOnly(hctx, true)
 
-			pconn, err := internet.DialSystem(dctx, net.TCPDestination(address, port), sockopt)
+			pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt)
 			if err != nil {
 				newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
 				return nil, err
 			}
 
 			if realityConfigs != nil {
-				return reality.UClient(pconn, realityConfigs, ctx, dest)
+				return reality.UClient(pconn, realityConfigs, hctx, dest)
 			}
 
 			var cn tls.Interface

+ 1 - 1
transport/internet/reality/reality.go

@@ -156,7 +156,7 @@ func UClient(c net.Conn, config *Config, ctx context.Context, dest net.Destinati
 		aead.Seal(hello.SessionId[:0], hello.Random[20:], hello.SessionId[:16], hello.Raw)
 		copy(hello.Raw[39:], hello.SessionId)
 	}
-	if err := uConn.Handshake(); err != nil {
+	if err := uConn.HandshakeContext(ctx); err != nil {
 		return nil, err
 	}
 	if config.Show {