Browse Source

Fix: gRPC & HTTP/2 dialer (#445)

Jim Han 4 years ago
parent
commit
3ed14c2fcd
2 changed files with 15 additions and 12 deletions
  1. 12 10
      transport/internet/grpc/dial.go
  2. 3 2
      transport/internet/http/dialer.go

+ 12 - 10
transport/internet/grpc/dial.go

@@ -36,6 +36,7 @@ func init() {
 type dialerConf struct {
 	net.Destination
 	*internet.SocketConfig
+	*tls.Config
 }
 
 var (
@@ -46,14 +47,9 @@ var (
 func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
 	grpcSettings := streamSettings.ProtocolSettings.(*Config)
 
-	config := tls.ConfigFromStreamSettings(streamSettings)
-	var dialOption = grpc.WithInsecure()
+	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
 
-	if config != nil {
-		dialOption = grpc.WithTransportCredentials(credentials.NewTLS(config.GetTLSConfig()))
-	}
-
-	conn, err := getGrpcClient(ctx, dest, dialOption, streamSettings.SocketSettings)
+	conn, err := getGrpcClient(ctx, dest, tlsConfig, streamSettings.SocketSettings)
 
 	if err != nil {
 		return nil, newError("Cannot dial gRPC").Base(err)
@@ -76,7 +72,7 @@ func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *interne
 	return encoding.NewHunkConn(grpcService, nil), nil
 }
 
-func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.DialOption, sockopt *internet.SocketConfig) (*grpc.ClientConn, error) {
+func getGrpcClient(ctx context.Context, dest net.Destination, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (*grpc.ClientConn, error) {
 	globalDialerAccess.Lock()
 	defer globalDialerAccess.Unlock()
 
@@ -84,10 +80,16 @@ func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.Di
 		globalDialerMap = make(map[dialerConf]*grpc.ClientConn)
 	}
 
-	if client, found := globalDialerMap[dialerConf{dest, sockopt}]; found && client.GetState() != connectivity.Shutdown {
+	if client, found := globalDialerMap[dialerConf{dest, sockopt, tlsConfig}]; found && client.GetState() != connectivity.Shutdown {
 		return client, nil
 	}
 
+	dialOption := grpc.WithInsecure()
+
+	if tlsConfig != nil {
+		dialOption = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig.GetTLSConfig()))
+	}
+
 	conn, err := grpc.Dial(
 		gonet.JoinHostPort(dest.Address.String(), dest.Port.String()),
 		dialOption,
@@ -125,6 +127,6 @@ func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.Di
 			return internet.DialSystem(gctx, net.TCPDestination(address, port), sockopt)
 		}),
 	)
-	globalDialerMap[dialerConf{dest, sockopt}] = conn
+	globalDialerMap[dialerConf{dest, sockopt, tlsConfig}] = conn
 	return conn, err
 }

+ 3 - 2
transport/internet/http/dialer.go

@@ -21,6 +21,7 @@ import (
 type dialerConf struct {
 	net.Destination
 	*internet.SocketConfig
+	*tls.Config
 }
 
 var (
@@ -36,7 +37,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, tlsSettings *tls.C
 		globalDialerMap = make(map[dialerConf]*http.Client)
 	}
 
-	if client, found := globalDialerMap[dialerConf{dest, sockopt}]; found {
+	if client, found := globalDialerMap[dialerConf{dest, sockopt, tlsSettings}]; found {
 		return client, nil
 	}
 
@@ -92,7 +93,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, tlsSettings *tls.C
 		Transport: transport,
 	}
 
-	globalDialerMap[dialerConf{dest, sockopt}] = client
+	globalDialerMap[dialerConf{dest, sockopt, tlsSettings}] = client
 	return client, nil
 }