Browse Source

Fix h2c transport

世界 2 years ago
parent
commit
5510c474c7
3 changed files with 29 additions and 8 deletions
  1. 2 0
      transport/v2raygrpc/client.go
  2. 14 4
      transport/v2raygrpclite/server.go
  3. 13 4
      transport/v2rayhttp/server.go

+ 2 - 0
transport/v2raygrpc/client.go

@@ -13,6 +13,7 @@ import (
 	M "github.com/sagernet/sing/common/metadata"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 	N "github.com/sagernet/sing/common/network"
 
 
+	"golang.org/x/net/http2"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/backoff"
 	"google.golang.org/grpc/backoff"
 	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/connectivity"
@@ -34,6 +35,7 @@ type Client struct {
 func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayGRPCOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) {
 func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayGRPCOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) {
 	var dialOptions []grpc.DialOption
 	var dialOptions []grpc.DialOption
 	if tlsConfig != nil {
 	if tlsConfig != nil {
+		tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
 		dialOptions = append(dialOptions, grpc.WithTransportCredentials(NewTLSTransportCredentials(tlsConfig)))
 		dialOptions = append(dialOptions, grpc.WithTransportCredentials(NewTLSTransportCredentials(tlsConfig)))
 	} else {
 	} else {
 		dialOptions = append(dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials()))
 		dialOptions = append(dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials()))

+ 14 - 4
transport/v2raygrpclite/server.go

@@ -19,6 +19,7 @@ import (
 	sHttp "github.com/sagernet/sing/protocol/http"
 	sHttp "github.com/sagernet/sing/protocol/http"
 
 
 	"golang.org/x/net/http2"
 	"golang.org/x/net/http2"
+	"golang.org/x/net/http2/h2c"
 )
 )
 
 
 var _ adapter.V2RayServerTransport = (*Server)(nil)
 var _ adapter.V2RayServerTransport = (*Server)(nil)
@@ -27,6 +28,8 @@ type Server struct {
 	handler      N.TCPConnectionHandler
 	handler      N.TCPConnectionHandler
 	errorHandler E.Handler
 	errorHandler E.Handler
 	httpServer   *http.Server
 	httpServer   *http.Server
+	h2Server     *http2.Server
+	h2cHandler   http.Handler
 	path         string
 	path         string
 }
 }
 
 
@@ -39,10 +42,12 @@ func NewServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig t
 		handler:      handler,
 		handler:      handler,
 		errorHandler: errorHandler,
 		errorHandler: errorHandler,
 		path:         fmt.Sprintf("/%s/Tun", url.QueryEscape(options.ServiceName)),
 		path:         fmt.Sprintf("/%s/Tun", url.QueryEscape(options.ServiceName)),
+		h2Server:     new(http2.Server),
 	}
 	}
 	server.httpServer = &http.Server{
 	server.httpServer = &http.Server{
 		Handler: server,
 		Handler: server,
 	}
 	}
+	server.h2cHandler = h2c.NewHandler(server, server.h2Server)
 	if tlsConfig != nil {
 	if tlsConfig != nil {
 		stdConfig, err := tlsConfig.Config()
 		stdConfig, err := tlsConfig.Config()
 		if err != nil {
 		if err != nil {
@@ -57,7 +62,12 @@ func NewServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig t
 }
 }
 
 
 func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
+	if request.Method == "PRI" && len(request.Header) == 0 && request.URL.Path == "*" && request.Proto == "HTTP/2.0" {
+		s.h2cHandler.ServeHTTP(writer, request)
+		return
+	}
 	if request.URL.Path != s.path {
 	if request.URL.Path != s.path {
+		request.Write(os.Stdout)
 		writer.WriteHeader(http.StatusNotFound)
 		writer.WriteHeader(http.StatusNotFound)
 		s.badRequest(request, E.New("bad path: ", request.URL.Path))
 		s.badRequest(request, E.New("bad path: ", request.URL.Path))
 		return
 		return
@@ -86,13 +96,13 @@ func (s *Server) badRequest(request *http.Request, err error) {
 }
 }
 
 
 func (s *Server) Serve(listener net.Listener) error {
 func (s *Server) Serve(listener net.Listener) error {
+	err := http2.ConfigureServer(s.httpServer, s.h2Server)
+	if err != nil {
+		return err
+	}
 	if s.httpServer.TLSConfig == nil {
 	if s.httpServer.TLSConfig == nil {
 		return s.httpServer.Serve(listener)
 		return s.httpServer.Serve(listener)
 	} else {
 	} else {
-		err := http2.ConfigureServer(s.httpServer, &http2.Server{})
-		if err != nil {
-			return err
-		}
 		return s.httpServer.ServeTLS(listener, "", "")
 		return s.httpServer.ServeTLS(listener, "", "")
 	}
 	}
 }
 }

+ 13 - 4
transport/v2rayhttp/server.go

@@ -18,6 +18,7 @@ import (
 	sHttp "github.com/sagernet/sing/protocol/http"
 	sHttp "github.com/sagernet/sing/protocol/http"
 
 
 	"golang.org/x/net/http2"
 	"golang.org/x/net/http2"
+	"golang.org/x/net/http2/h2c"
 )
 )
 
 
 var _ adapter.V2RayServerTransport = (*Server)(nil)
 var _ adapter.V2RayServerTransport = (*Server)(nil)
@@ -27,6 +28,8 @@ type Server struct {
 	handler      N.TCPConnectionHandler
 	handler      N.TCPConnectionHandler
 	errorHandler E.Handler
 	errorHandler E.Handler
 	httpServer   *http.Server
 	httpServer   *http.Server
+	h2Server     *http2.Server
+	h2cHandler   http.Handler
 	host         []string
 	host         []string
 	path         string
 	path         string
 	method       string
 	method       string
@@ -42,6 +45,7 @@ func NewServer(ctx context.Context, options option.V2RayHTTPOptions, tlsConfig t
 		ctx:          ctx,
 		ctx:          ctx,
 		handler:      handler,
 		handler:      handler,
 		errorHandler: errorHandler,
 		errorHandler: errorHandler,
+		h2Server:     new(http2.Server),
 		host:         options.Host,
 		host:         options.Host,
 		path:         options.Path,
 		path:         options.Path,
 		method:       options.Method,
 		method:       options.Method,
@@ -61,6 +65,7 @@ func NewServer(ctx context.Context, options option.V2RayHTTPOptions, tlsConfig t
 		ReadHeaderTimeout: C.TCPTimeout,
 		ReadHeaderTimeout: C.TCPTimeout,
 		MaxHeaderBytes:    http.DefaultMaxHeaderBytes,
 		MaxHeaderBytes:    http.DefaultMaxHeaderBytes,
 	}
 	}
+	server.h2cHandler = h2c.NewHandler(server, server.h2Server)
 	if tlsConfig != nil {
 	if tlsConfig != nil {
 		stdConfig, err := tlsConfig.Config()
 		stdConfig, err := tlsConfig.Config()
 		if err != nil {
 		if err != nil {
@@ -72,6 +77,10 @@ func NewServer(ctx context.Context, options option.V2RayHTTPOptions, tlsConfig t
 }
 }
 
 
 func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
+	if request.Method == "PRI" && len(request.Header) == 0 && request.URL.Path == "*" && request.Proto == "HTTP/2.0" {
+		s.h2cHandler.ServeHTTP(writer, request)
+		return
+	}
 	host := request.Host
 	host := request.Host
 	if len(s.host) > 0 && !common.Contains(s.host, host) {
 	if len(s.host) > 0 && !common.Contains(s.host, host) {
 		writer.WriteHeader(http.StatusBadRequest)
 		writer.WriteHeader(http.StatusBadRequest)
@@ -124,13 +133,13 @@ func (s *Server) badRequest(request *http.Request, err error) {
 }
 }
 
 
 func (s *Server) Serve(listener net.Listener) error {
 func (s *Server) Serve(listener net.Listener) error {
+	err := http2.ConfigureServer(s.httpServer, s.h2Server)
+	if err != nil {
+		return err
+	}
 	if s.httpServer.TLSConfig == nil {
 	if s.httpServer.TLSConfig == nil {
 		return s.httpServer.Serve(listener)
 		return s.httpServer.Serve(listener)
 	} else {
 	} else {
-		err := http2.ConfigureServer(s.httpServer, &http2.Server{})
-		if err != nil {
-			return err
-		}
 		return s.httpServer.ServeTLS(listener, "", "")
 		return s.httpServer.ServeTLS(listener, "", "")
 	}
 	}
 }
 }