Browse Source

Remove TLS requirement for gRPC client

Hellojack 2 năm trước cách đây
mục cha
commit
ec2d0b6b3c

+ 0 - 3
transport/v2ray/transport.go

@@ -48,9 +48,6 @@ func NewClientTransport(ctx context.Context, dialer N.Dialer, serverAddr M.Socks
 	case C.V2RayTransportTypeHTTP:
 		return v2rayhttp.NewClient(ctx, dialer, serverAddr, options.HTTPOptions, tlsConfig)
 	case C.V2RayTransportTypeGRPC:
-		if tlsConfig == nil {
-			return nil, C.ErrTLSRequired
-		}
 		return NewGRPCClient(ctx, dialer, serverAddr, options.GRPCOptions, tlsConfig)
 	case C.V2RayTransportTypeWebsocket:
 		return v2raywebsocket.NewClient(ctx, dialer, serverAddr, options.WebsocketOptions, tlsConfig), nil

+ 3 - 1
transport/v2raygrpc/client.go

@@ -36,7 +36,9 @@ type Client struct {
 func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayGRPCOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) {
 	var dialOptions []grpc.DialOption
 	if tlsConfig != nil {
-		tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
+		if len(tlsConfig.NextProtos()) == 0 {
+			tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
+		}
 		dialOptions = append(dialOptions, grpc.WithTransportCredentials(NewTLSTransportCredentials(tlsConfig)))
 	} else {
 		dialOptions = append(dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials()))

+ 35 - 33
transport/v2raygrpclite/client.go

@@ -2,7 +2,6 @@ package v2raygrpclite
 
 import (
 	"context"
-	"fmt"
 	"io"
 	"net"
 	"net/http"
@@ -13,6 +12,7 @@ import (
 	"github.com/sagernet/sing-box/common/tls"
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing-box/transport/v2rayhttp"
+	F "github.com/sagernet/sing/common/format"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 
@@ -31,56 +31,56 @@ type Client struct {
 	ctx        context.Context
 	dialer     N.Dialer
 	serverAddr M.Socksaddr
-	transport  http.RoundTripper
+	transport  *http2.Transport
 	options    option.V2RayGRPCOptions
 	url        *url.URL
 }
 
 func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayGRPCOptions, tlsConfig tls.Config) adapter.V2RayClientTransport {
-	var transport http.RoundTripper
-	if tlsConfig == nil {
-		transport = &http.Transport{
-			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
-				return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
-			},
-		}
-	} else {
-		tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
-		transport = &http2.Transport{
-			ReadIdleTimeout: time.Duration(options.IdleTimeout),
-			PingTimeout:     time.Duration(options.PingTimeout),
-			DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.STDConfig) (net.Conn, error) {
-				conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
-				if err != nil {
-					return nil, err
-				}
-				return tls.ClientHandshake(ctx, conn, tlsConfig)
-			},
-		}
-	}
-	return &Client{
+	client := &Client{
 		ctx:        ctx,
 		dialer:     dialer,
 		serverAddr: serverAddr,
 		options:    options,
-		transport:  transport,
+		transport: &http2.Transport{
+			ReadIdleTimeout:    time.Duration(options.IdleTimeout),
+			PingTimeout:        time.Duration(options.PingTimeout),
+			DisableCompression: true,
+		},
 		url: &url.URL{
 			Scheme: "https",
 			Host:   serverAddr.String(),
-			Path:   fmt.Sprintf("/%s/Tun", url.QueryEscape(options.ServiceName)),
+			Path:   F.ToString("/", url.QueryEscape(options.ServiceName), "/Tun"),
 		},
 	}
+
+	if tlsConfig == nil {
+		client.transport.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.STDConfig) (net.Conn, error) {
+			return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
+		}
+	} else {
+		if len(tlsConfig.NextProtos()) == 0 {
+			tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
+		}
+		client.transport.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.STDConfig) (net.Conn, error) {
+			conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
+			if err != nil {
+				return nil, err
+			}
+			return tls.ClientHandshake(ctx, conn, tlsConfig)
+		}
+	}
+
+	return client
 }
 
 func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
 	pipeInReader, pipeInWriter := io.Pipe()
 	request := &http.Request{
-		Method:     http.MethodPost,
-		Body:       pipeInReader,
-		URL:        c.url,
-		Proto:      "HTTP/2",
-		ProtoMajor: 2,
-		Header:     defaultClientHeader,
+		Method: http.MethodPost,
+		Body:   pipeInReader,
+		URL:    c.url,
+		Header: defaultClientHeader,
 	}
 	request = request.WithContext(ctx)
 	conn := newLateGunConn(pipeInWriter)
@@ -96,6 +96,8 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
 }
 
 func (c *Client) Close() error {
-	v2rayhttp.CloseIdleConnections(c.transport)
+	if c.transport != nil {
+		v2rayhttp.CloseIdleConnections(c.transport)
+	}
 	return nil
 }

+ 1 - 0
transport/v2raygrpclite/conn.go

@@ -117,6 +117,7 @@ func (c *GunConn) WriteBuffer(buffer *buf.Buffer) error {
 	dataLen := buffer.Len()
 	varLen := rw.UVariantLen(uint64(dataLen))
 	header := buffer.ExtendHeader(6 + varLen)
+	_ = header[6]
 	header[0] = 0x00
 	binary.BigEndian.PutUint32(header[1:5], uint32(1+varLen+dataLen))
 	header[5] = 0x0A

+ 3 - 1
transport/v2rayhttp/client.go

@@ -43,7 +43,9 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 			},
 		}
 	} else {
-		tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
+		if len(tlsConfig.NextProtos()) == 0 {
+			tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
+		}
 		transport = &http2.Transport{
 			ReadIdleTimeout: time.Duration(options.IdleTimeout),
 			PingTimeout:     time.Duration(options.PingTimeout),