Browse Source

cancel failed grpc connection (#707)

Co-authored-by: Shelikhoo <[email protected]>
yuhan6665 4 years ago
parent
commit
0f79126379
1 changed files with 14 additions and 5 deletions
  1. 14 5
      transport/internet/grpc/dial.go

+ 14 - 5
transport/internet/grpc/dial.go

@@ -39,6 +39,8 @@ type dialerConf struct {
 	*internet.MemoryStreamConfig
 }
 
+type dialerCanceller func()
+
 var (
 	globalDialerMap    map[dialerConf]*grpc.ClientConn
 	globalDialerAccess sync.Mutex
@@ -47,8 +49,7 @@ var (
 func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
 	grpcSettings := streamSettings.ProtocolSettings.(*Config)
 
-	conn, err := getGrpcClient(ctx, dest, streamSettings)
-
+	conn, canceller, err := getGrpcClient(ctx, dest, streamSettings)
 	if err != nil {
 		return nil, newError("Cannot dial gRPC").Base(err)
 	}
@@ -57,6 +58,7 @@ func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *interne
 		newError("using gRPC multi mode").AtDebug().WriteToLog()
 		grpcService, err := client.(encoding.GRPCServiceClientX).TunMultiCustomName(ctx, grpcSettings.getNormalizedName())
 		if err != nil {
+			canceller()
 			return nil, newError("Cannot dial gRPC").Base(err)
 		}
 		return encoding.NewMultiHunkConn(grpcService, nil), nil
@@ -64,13 +66,14 @@ func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *interne
 
 	grpcService, err := client.(encoding.GRPCServiceClientX).TunCustomName(ctx, grpcSettings.getNormalizedName())
 	if err != nil {
+		canceller()
 		return nil, newError("Cannot dial gRPC").Base(err)
 	}
 
 	return encoding.NewHunkConn(grpcService, nil), nil
 }
 
-func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*grpc.ClientConn, error) {
+func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*grpc.ClientConn, dialerCanceller, error) {
 	globalDialerAccess.Lock()
 	defer globalDialerAccess.Unlock()
 
@@ -81,8 +84,14 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in
 	sockopt := streamSettings.SocketSettings
 	grpcSettings := streamSettings.ProtocolSettings.(*Config)
 
+	canceller := func() {
+		globalDialerAccess.Lock()
+		defer globalDialerAccess.Unlock()
+		delete(globalDialerMap, dialerConf{dest, streamSettings})
+	}
+
 	if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found && client.GetState() != connectivity.Shutdown {
-		return client, nil
+		return client, canceller, nil
 	}
 
 	var dialOptions = []grpc.DialOption{
@@ -147,5 +156,5 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in
 		dialOptions...,
 	)
 	globalDialerMap[dialerConf{dest, streamSettings}] = conn
-	return conn, err
+	return conn, canceller, err
 }