Prechádzať zdrojové kódy

Add custom path to gRPC (#1815)

Hirbod Behnam 2 rokov pred
rodič
commit
526c6789ed

+ 38 - 2
transport/internet/grpc/config.go

@@ -2,6 +2,7 @@ package grpc
 
 import (
 	"net/url"
+	"strings"
 
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/transport/internet"
@@ -15,6 +16,41 @@ func init() {
 	}))
 }
 
-func (c *Config) getNormalizedName() string {
-	return url.PathEscape(c.ServiceName)
+func (c *Config) getServiceName() string {
+	// Normal old school config
+	if !strings.HasPrefix(c.ServiceName, "/") {
+		return url.PathEscape(c.ServiceName)
+	}
+	// Otherwise new custom paths
+	rawServiceName := c.ServiceName[1:strings.LastIndex(c.ServiceName, "/")] // trim from first to last '/'
+	serviceNameParts := strings.Split(rawServiceName, "/")
+	for i := range serviceNameParts {
+		serviceNameParts[i] = url.PathEscape(serviceNameParts[i])
+	}
+	return strings.Join(serviceNameParts, "/")
+}
+
+func (c *Config) getTunStreamName() string {
+	// Normal old school config
+	if !strings.HasPrefix(c.ServiceName, "/") {
+		return "Tun"
+	}
+	// Otherwise new custom paths
+	endingPath := c.ServiceName[strings.LastIndex(c.ServiceName, "/")+1:] // from the last '/' to end of string
+	return url.PathEscape(strings.Split(endingPath, "|")[0])
+}
+
+func (c *Config) getTunMultiStreamName() string {
+	// Normal old school config
+	if !strings.HasPrefix(c.ServiceName, "/") {
+		return "TunMulti"
+	}
+	// Otherwise new custom paths
+	endingPath := c.ServiceName[strings.LastIndex(c.ServiceName, "/")+1:] // from the last '/' to end of string
+	streamNames := strings.Split(endingPath, "|")
+	if len(streamNames) == 1 { // client side. Service name is the full path to multi tun
+		return url.PathEscape(streamNames[0])
+	} else { // server side. The second part is the path to multi tun
+		return url.PathEscape(streamNames[1])
+	}
 }

+ 111 - 0
transport/internet/grpc/config_test.go

@@ -0,0 +1,111 @@
+package grpc
+
+import (
+	"github.com/stretchr/testify/assert"
+	"testing"
+)
+
+func TestConfig_GetServiceName(t *testing.T) {
+	tests := []struct {
+		TestName    string
+		ServiceName string
+		Expected    string
+	}{
+		{
+			TestName:    "simple no absolute path",
+			ServiceName: "hello",
+			Expected:    "hello",
+		},
+		{
+			TestName:    "escape no absolute path",
+			ServiceName: "hello/world!",
+			Expected:    "hello%2Fworld%21",
+		},
+		{
+			TestName:    "absolute path",
+			ServiceName: "/my/sample/path/a|b",
+			Expected:    "my/sample/path",
+		},
+		{
+			TestName:    "escape absolute path",
+			ServiceName: "/hello /world!/a|b",
+			Expected:    "hello%20/world%21",
+		},
+	}
+	for _, test := range tests {
+		t.Run(test.TestName, func(t *testing.T) {
+			config := Config{ServiceName: test.ServiceName}
+			assert.Equal(t, test.Expected, config.getServiceName())
+		})
+	}
+}
+
+func TestConfig_GetTunStreamName(t *testing.T) {
+	tests := []struct {
+		TestName    string
+		ServiceName string
+		Expected    string
+	}{
+		{
+			TestName:    "no absolute path",
+			ServiceName: "hello",
+			Expected:    "Tun",
+		},
+		{
+			TestName:    "absolute path server",
+			ServiceName: "/my/sample/path/tun_service|multi_service",
+			Expected:    "tun_service",
+		},
+		{
+			TestName:    "absolute path client",
+			ServiceName: "/my/sample/path/tun_service",
+			Expected:    "tun_service",
+		},
+		{
+			TestName:    "escape absolute path client",
+			ServiceName: "/m y/sa !mple/pa\\th/tun\\_serv!ice",
+			Expected:    "tun%5C_serv%21ice",
+		},
+	}
+	for _, test := range tests {
+		t.Run(test.TestName, func(t *testing.T) {
+			config := Config{ServiceName: test.ServiceName}
+			assert.Equal(t, test.Expected, config.getTunStreamName())
+		})
+	}
+}
+
+func TestConfig_GetTunMultiStreamName(t *testing.T) {
+	tests := []struct {
+		TestName    string
+		ServiceName string
+		Expected    string
+	}{
+		{
+			TestName:    "no absolute path",
+			ServiceName: "hello",
+			Expected:    "TunMulti",
+		},
+		{
+			TestName:    "absolute path server",
+			ServiceName: "/my/sample/path/tun_service|multi_service",
+			Expected:    "multi_service",
+		},
+		{
+			TestName:    "absolute path client",
+			ServiceName: "/my/sample/path/multi_service",
+			Expected:    "multi_service",
+		},
+		{
+			TestName:    "escape absolute path client",
+			ServiceName: "/m y/sa !mple/pa\\th/mu%lti\\_serv!ice",
+			Expected:    "mu%25lti%5C_serv%21ice",
+		},
+	}
+	for _, test := range tests {
+		t.Run(test.TestName, func(t *testing.T) {
+			config := Config{ServiceName: test.ServiceName}
+			assert.Equal(t, test.Expected, config.getTunMultiStreamName())
+		})
+	}
+}

+ 4 - 3
transport/internet/grpc/dial.go

@@ -54,15 +54,16 @@ func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *interne
 	}
 	client := encoding.NewGRPCServiceClient(conn)
 	if grpcSettings.MultiMode {
-		newError("using gRPC multi mode").AtDebug().WriteToLog()
-		grpcService, err := client.(encoding.GRPCServiceClientX).TunMultiCustomName(ctx, grpcSettings.getNormalizedName())
+		newError("using gRPC multi mode service name: `" + grpcSettings.getServiceName() + "` stream name: `" + grpcSettings.getTunMultiStreamName() + "`").AtDebug().WriteToLog()
+		grpcService, err := client.(encoding.GRPCServiceClientX).TunMultiCustomName(ctx, grpcSettings.getServiceName(), grpcSettings.getTunMultiStreamName())
 		if err != nil {
 			return nil, newError("Cannot dial gRPC").Base(err)
 		}
 		return encoding.NewMultiHunkConn(grpcService, nil), nil
 	}
 
-	grpcService, err := client.(encoding.GRPCServiceClientX).TunCustomName(ctx, grpcSettings.getNormalizedName())
+	newError("using gRPC tun mode service name: `" + grpcSettings.getServiceName() + "` stream name: `" + grpcSettings.getTunStreamName() + "`").AtDebug().WriteToLog()
+	grpcService, err := client.(encoding.GRPCServiceClientX).TunCustomName(ctx, grpcSettings.getServiceName(), grpcSettings.getTunStreamName())
 	if err != nil {
 		return nil, newError("Cannot dial gRPC").Base(err)
 	}

+ 11 - 11
transport/internet/grpc/encoding/customSeviceName.go

@@ -6,20 +6,20 @@ import (
 	"google.golang.org/grpc"
 )
 
-func ServerDesc(name string) grpc.ServiceDesc {
+func ServerDesc(name, tun, tunMulti string) grpc.ServiceDesc {
 	return grpc.ServiceDesc{
 		ServiceName: name,
 		HandlerType: (*GRPCServiceServer)(nil),
 		Methods:     []grpc.MethodDesc{},
 		Streams: []grpc.StreamDesc{
 			{
-				StreamName:    "Tun",
+				StreamName:    tun,
 				Handler:       _GRPCService_Tun_Handler,
 				ServerStreams: true,
 				ClientStreams: true,
 			},
 			{
-				StreamName:    "TunMulti",
+				StreamName:    tunMulti,
 				Handler:       _GRPCService_TunMulti_Handler,
 				ServerStreams: true,
 				ClientStreams: true,
@@ -29,8 +29,8 @@ func ServerDesc(name string) grpc.ServiceDesc {
 	}
 }
 
-func (c *gRPCServiceClient) TunCustomName(ctx context.Context, name string, opts ...grpc.CallOption) (GRPCService_TunClient, error) {
-	stream, err := c.cc.NewStream(ctx, &ServerDesc(name).Streams[0], "/"+name+"/Tun", opts...)
+func (c *gRPCServiceClient) TunCustomName(ctx context.Context, name, tun string, opts ...grpc.CallOption) (GRPCService_TunClient, error) {
+	stream, err := c.cc.NewStream(ctx, &ServerDesc(name, tun, "").Streams[0], "/"+name+"/"+tun, opts...)
 	if err != nil {
 		return nil, err
 	}
@@ -38,8 +38,8 @@ func (c *gRPCServiceClient) TunCustomName(ctx context.Context, name string, opts
 	return x, nil
 }
 
-func (c *gRPCServiceClient) TunMultiCustomName(ctx context.Context, name string, opts ...grpc.CallOption) (GRPCService_TunMultiClient, error) {
-	stream, err := c.cc.NewStream(ctx, &ServerDesc(name).Streams[1], "/"+name+"/TunMulti", opts...)
+func (c *gRPCServiceClient) TunMultiCustomName(ctx context.Context, name, tunMulti string, opts ...grpc.CallOption) (GRPCService_TunMultiClient, error) {
+	stream, err := c.cc.NewStream(ctx, &ServerDesc(name, "", tunMulti).Streams[1], "/"+name+"/"+tunMulti, opts...)
 	if err != nil {
 		return nil, err
 	}
@@ -48,13 +48,13 @@ func (c *gRPCServiceClient) TunMultiCustomName(ctx context.Context, name string,
 }
 
 type GRPCServiceClientX interface {
-	TunCustomName(ctx context.Context, name string, opts ...grpc.CallOption) (GRPCService_TunClient, error)
-	TunMultiCustomName(ctx context.Context, name string, opts ...grpc.CallOption) (GRPCService_TunMultiClient, error)
+	TunCustomName(ctx context.Context, name, tun string, opts ...grpc.CallOption) (GRPCService_TunClient, error)
+	TunMultiCustomName(ctx context.Context, name, tunMulti string, opts ...grpc.CallOption) (GRPCService_TunMultiClient, error)
 	Tun(ctx context.Context, opts ...grpc.CallOption) (GRPCService_TunClient, error)
 	TunMulti(ctx context.Context, opts ...grpc.CallOption) (GRPCService_TunMultiClient, error)
 }
 
-func RegisterGRPCServiceServerX(s *grpc.Server, srv GRPCServiceServer, name string) {
-	desc := ServerDesc(name)
+func RegisterGRPCServiceServerX(s *grpc.Server, srv GRPCServiceServer, name, tun, tunMulti string) {
+	desc := ServerDesc(name, tun, tunMulti)
 	s.RegisterService(&desc, srv)
 }

+ 2 - 1
transport/internet/grpc/hub.go

@@ -125,7 +125,8 @@ func Listen(ctx context.Context, address net.Address, port net.Port, settings *i
 			}
 		}
 
-		encoding.RegisterGRPCServiceServerX(s, listener, grpcSettings.getNormalizedName())
+		newError("gRPC listen for service name `" + grpcSettings.getServiceName() + "` tun `" + grpcSettings.getTunStreamName() + "` multi tun `" + grpcSettings.getTunMultiStreamName() + "`").AtDebug().WriteToLog()
+		encoding.RegisterGRPCServiceServerX(s, listener, grpcSettings.getServiceName(), grpcSettings.getTunStreamName(), grpcSettings.getTunMultiStreamName())
 
 		if config := reality.ConfigFromStreamSettings(settings); config != nil {
 			streamListener = goreality.NewListener(streamListener, config.GetREALITYConfig())