Browse Source

Transport: Add HTTP3 to HTTP (#3819)

yuhan6665 1 year ago
parent
commit
3632e83faa

+ 1 - 1
infra/conf/transport_internet.go

@@ -650,7 +650,7 @@ func (p TransportProtocol) Build() (string, error) {
 		return "mkcp", nil
 	case "ws", "websocket":
 		return "websocket", nil
-	case "h2", "http":
+	case "h2", "h3", "http":
 		return "http", nil
 	case "grpc", "gun":
 		return "grpc", nil

+ 119 - 56
transport/internet/http/dialer.go

@@ -9,6 +9,8 @@ import (
 	"sync"
 	"time"
 
+	"github.com/quic-go/quic-go"
+	"github.com/quic-go/quic-go/http3"
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
 	c "github.com/xtls/xray-core/common/ctx"
@@ -24,6 +26,13 @@ import (
 	"golang.org/x/net/http2"
 )
 
+// defines the maximum time an idle TCP session can survive in the tunnel, so
+// it should be consistent across HTTP versions and with other transports.
+const connIdleTimeout = 300 * time.Second
+
+// consistent with quic-go
+const h3KeepalivePeriod = 10 * time.Second
+
 type dialerConf struct {
 	net.Destination
 	*internet.MemoryStreamConfig
@@ -48,72 +57,129 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
 	if tlsConfigs == nil && realityConfigs == nil {
 		return nil, errors.New("TLS or REALITY must be enabled for http transport.").AtWarning()
 	}
+	isH3 := tlsConfigs != nil && (len(tlsConfigs.NextProtocol) == 1 && tlsConfigs.NextProtocol[0] == "h3")
+	if isH3 {
+		dest.Network = net.Network_UDP
+	}
 	sockopt := streamSettings.SocketSettings
 
 	if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found {
 		return client, nil
 	}
 
-	transport := &http2.Transport{
-		DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
-			rawHost, rawPort, err := net.SplitHostPort(addr)
-			if err != nil {
-				return nil, err
-			}
-			if len(rawPort) == 0 {
-				rawPort = "443"
-			}
-			port, err := net.PortFromString(rawPort)
-			if err != nil {
-				return nil, err
-			}
-			address := net.ParseAddress(rawHost)
+	var transport http.RoundTripper
+	if isH3 {
+		quicConfig := &quic.Config{
+			MaxIdleTimeout: connIdleTimeout,
 
-			hctx = c.ContextWithID(hctx, c.IDFromContext(ctx))
-			hctx = session.ContextWithOutbounds(hctx, session.OutboundsFromContext(ctx))
-			hctx = session.ContextWithTimeoutOnly(hctx, true)
+			// these two are defaults of quic-go/http3. the default of quic-go (no
+			// http3) is different, so it is hardcoded here for clarity.
+			// https://github.com/quic-go/quic-go/blob/b8ea5c798155950fb5bbfdd06cad1939c9355878/http3/client.go#L36-L39
+			MaxIncomingStreams: -1,
+			KeepAlivePeriod:    h3KeepalivePeriod,
+		}
+		roundTripper := &http3.RoundTripper{
+			QUICConfig:      quicConfig,
+			TLSClientConfig: tlsConfigs.GetTLSConfig(tls.WithDestination(dest)),
+			Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
+				conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
+				if err != nil {
+					return nil, err
+				}
 
-			pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt)
-			if err != nil {
-				errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
-				return nil, err
-			}
+				var udpConn net.PacketConn
+				var udpAddr *net.UDPAddr
 
-			if realityConfigs != nil {
-				return reality.UClient(pconn, realityConfigs, hctx, dest)
-			}
+				switch c := conn.(type) {
+				case *internet.PacketConnWrapper:
+					var ok bool
+					udpConn, ok = c.Conn.(*net.UDPConn)
+					if !ok {
+						return nil, errors.New("PacketConnWrapper does not contain a UDP connection")
+					}
+					udpAddr, err = net.ResolveUDPAddr("udp", c.Dest.String())
+					if err != nil {
+						return nil, err
+					}
+				case *net.UDPConn:
+					udpConn = c
+					udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
+					if err != nil {
+						return nil, err
+					}
+				default:
+					udpConn = &internet.FakePacketConn{c}
+					udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
+					if err != nil {
+						return nil, err
+					}
+				}
 
-			var cn tls.Interface
-			if fingerprint := tls.GetFingerprint(tlsConfigs.Fingerprint); fingerprint != nil {
-				cn = tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
-			} else {
-				cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
-			}
-			if err := cn.HandshakeContext(ctx); err != nil {
-				errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
-				return nil, err
-			}
-			if !tlsConfig.InsecureSkipVerify {
-				if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
+				return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
+			},
+		}
+		transport = roundTripper
+	} else {
+		transportH2 := &http2.Transport{
+			DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
+				rawHost, rawPort, err := net.SplitHostPort(addr)
+				if err != nil {
+					return nil, err
+				}
+				if len(rawPort) == 0 {
+					rawPort = "443"
+				}
+				port, err := net.PortFromString(rawPort)
+				if err != nil {
+					return nil, err
+				}
+				address := net.ParseAddress(rawHost)
+	
+				hctx = c.ContextWithID(hctx, c.IDFromContext(ctx))
+				hctx = session.ContextWithOutbounds(hctx, session.OutboundsFromContext(ctx))
+				hctx = session.ContextWithTimeoutOnly(hctx, true)
+	
+				pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt)
+				if err != nil {
 					errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
 					return nil, err
 				}
-			}
-			negotiatedProtocol := cn.NegotiatedProtocol()
-			if negotiatedProtocol != http2.NextProtoTLS {
-				return nil, errors.New("http2: unexpected ALPN protocol " + negotiatedProtocol + "; want q" + http2.NextProtoTLS).AtError()
-			}
-			return cn, nil
-		},
-	}
-
-	if tlsConfigs != nil {
-		transport.TLSClientConfig = tlsConfigs.GetTLSConfig(tls.WithDestination(dest))
-	}
-
-	if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 {
-		transport.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout)
-		transport.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout)
+	
+				if realityConfigs != nil {
+					return reality.UClient(pconn, realityConfigs, hctx, dest)
+				}
+	
+				var cn tls.Interface
+				if fingerprint := tls.GetFingerprint(tlsConfigs.Fingerprint); fingerprint != nil {
+					cn = tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
+				} else {
+					cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
+				}
+				if err := cn.HandshakeContext(ctx); err != nil {
+					errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
+					return nil, err
+				}
+				if !tlsConfig.InsecureSkipVerify {
+					if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
+						errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
+						return nil, err
+					}
+				}
+				negotiatedProtocol := cn.NegotiatedProtocol()
+				if negotiatedProtocol != http2.NextProtoTLS {
+					return nil, errors.New("http2: unexpected ALPN protocol " + negotiatedProtocol + "; want q" + http2.NextProtoTLS).AtError()
+				}
+				return cn, nil
+			},
+		}
+		if tlsConfigs != nil {
+			transportH2.TLSClientConfig = tlsConfigs.GetTLSConfig(tls.WithDestination(dest))
+		}
+		if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 {
+			transportH2.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout)
+			transportH2.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout)
+		}
+		transport = transportH2
 	}
 
 	client := &http.Client{
@@ -158,9 +224,6 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 			Host:   dest.NetAddr(),
 			Path:   httpSettings.getNormalizedPath(),
 		},
-		Proto:      "HTTP/2",
-		ProtoMajor: 2,
-		ProtoMinor: 0,
 		Header:     httpHeaders,
 	}
 	// Disable any compression method from server.

+ 78 - 0
transport/internet/http/http_test.go

@@ -12,6 +12,7 @@ import (
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol/tls/cert"
 	"github.com/xtls/xray-core/testing/servers/tcp"
+	"github.com/xtls/xray-core/testing/servers/udp"
 	"github.com/xtls/xray-core/transport/internet"
 	. "github.com/xtls/xray-core/transport/internet/http"
 	"github.com/xtls/xray-core/transport/internet/stat"
@@ -92,3 +93,80 @@ func TestHTTPConnection(t *testing.T) {
 		t.Error(r)
 	}
 }
+
+func TestH3Connection(t *testing.T) {
+	port := udp.PickPort()
+
+	listener, err := Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{
+		ProtocolName:     "http",
+		ProtocolSettings: &Config{},
+		SecurityType:     "tls",
+		SecuritySettings: &tls.Config{
+			NextProtocol: []string{"h3"},
+			Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("www.example.com")))},
+		},
+	}, func(conn stat.Connection) {
+		go func() {
+			defer conn.Close()
+
+			b := buf.New()
+			defer b.Release()
+
+			for {
+				if _, err := b.ReadFrom(conn); err != nil {
+					return
+				}
+				_, err := conn.Write(b.Bytes())
+				common.Must(err)
+			}
+		}()
+	})
+	common.Must(err)
+
+	defer listener.Close()
+
+	time.Sleep(time.Second)
+
+	dctx := context.Background()
+	conn, err := Dial(dctx, net.TCPDestination(net.LocalHostIP, port), &internet.MemoryStreamConfig{
+		ProtocolName:     "http",
+		ProtocolSettings: &Config{},
+		SecurityType:     "tls",
+		SecuritySettings: &tls.Config{
+			NextProtocol: []string{"h3"},
+			ServerName:    "www.example.com",
+			AllowInsecure: true,
+		},
+	})
+	common.Must(err)
+	defer conn.Close()
+
+	const N = 1024
+	b1 := make([]byte, N)
+	common.Must2(rand.Read(b1))
+	b2 := buf.New()
+
+	nBytes, err := conn.Write(b1)
+	common.Must(err)
+	if nBytes != N {
+		t.Error("write: ", nBytes)
+	}
+
+	b2.Clear()
+	common.Must2(b2.ReadFullFrom(conn, N))
+	if r := cmp.Diff(b2.Bytes(), b1); r != "" {
+		t.Error(r)
+	}
+
+	nBytes, err = conn.Write(b1)
+	common.Must(err)
+	if nBytes != N {
+		t.Error("write: ", nBytes)
+	}
+
+	b2.Clear()
+	common.Must2(b2.ReadFullFrom(conn, N))
+	if r := cmp.Diff(b2.Bytes(), b1); r != "" {
+		t.Error(r)
+	}
+}

+ 111 - 70
transport/internet/http/hub.go

@@ -2,11 +2,14 @@ package http
 
 import (
 	"context"
+	gotls "crypto/tls"
 	"io"
 	"net/http"
 	"strings"
 	"time"
 
+	"github.com/quic-go/quic-go"
+	"github.com/quic-go/quic-go/http3"
 	goreality "github.com/xtls/reality"
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/errors"
@@ -23,10 +26,12 @@ import (
 )
 
 type Listener struct {
-	server  *http.Server
-	handler internet.ConnHandler
-	local   net.Addr
-	config  *Config
+	server   *http.Server
+	h3server *http3.Server
+	handler  internet.ConnHandler
+	local    net.Addr
+	config   *Config
+	isH3     bool
 }
 
 func (l *Listener) Addr() net.Addr {
@@ -34,7 +39,14 @@ func (l *Listener) Addr() net.Addr {
 }
 
 func (l *Listener) Close() error {
-	return l.server.Close()
+	if l.h3server != nil {
+		if err := l.h3server.Close(); err != nil {
+			return err
+		}
+	} else if l.server != nil {
+		return l.server.Close()
+	}
+	return errors.New("listener does not have an HTTP/3 server or h2 server")
 }
 
 type flushWriter struct {
@@ -119,43 +131,33 @@ func (l *Listener) ServeHTTP(writer http.ResponseWriter, request *http.Request)
 
 func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
 	httpSettings := streamSettings.ProtocolSettings.(*Config)
-	var listener *Listener
-	if port == net.Port(0) { // unix
-		listener = &Listener{
-			handler: handler,
-			local: &net.UnixAddr{
-				Name: address.Domain(),
-				Net:  "unix",
-			},
-			config: httpSettings,
-		}
-	} else { // tcp
-		listener = &Listener{
-			handler: handler,
-			local: &net.TCPAddr{
-				IP:   address.IP(),
-				Port: int(port),
-			},
-			config: httpSettings,
-		}
-	}
-
-	var server *http.Server
 	config := tls.ConfigFromStreamSettings(streamSettings)
+	var tlsConfig *gotls.Config
 	if config == nil {
-		h2s := &http2.Server{}
-
-		server = &http.Server{
-			Addr:              serial.Concat(address, ":", port),
-			Handler:           h2c.NewHandler(listener, h2s),
-			ReadHeaderTimeout: time.Second * 4,
+		tlsConfig = &gotls.Config{}
+	} else {
+		tlsConfig = config.GetTLSConfig()
+	}
+	isH3 := len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"
+	listener := &Listener{
+		handler: handler,
+		config: httpSettings,
+		isH3: isH3,
+	}
+	if port == net.Port(0) { // unix
+		listener.local = &net.UnixAddr{
+			Name: address.Domain(),
+			Net:  "unix",
+		}
+	} else if isH3 { // udp
+		listener.local = &net.UDPAddr{
+			IP:   address.IP(),
+			Port: int(port),
 		}
 	} else {
-		server = &http.Server{
-			Addr:              serial.Concat(address, ":", port),
-			TLSConfig:         config.GetTLSConfig(tls.WithNextProto("h2")),
-			Handler:           listener,
-			ReadHeaderTimeout: time.Second * 4,
+		listener.local = &net.TCPAddr{
+			IP:   address.IP(),
+			Port: int(port),
 		}
 	}
 
@@ -163,45 +165,84 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
 		errors.LogWarning(ctx, "accepting PROXY protocol")
 	}
 
-	listener.server = server
-	go func() {
-		var streamListener net.Listener
-		var err error
-		if port == net.Port(0) { // unix
-			streamListener, err = internet.ListenSystem(ctx, &net.UnixAddr{
-				Name: address.Domain(),
-				Net:  "unix",
-			}, streamSettings.SocketSettings)
-			if err != nil {
-				errors.LogErrorInner(ctx, err, "failed to listen on ", address)
-				return
-			}
-		} else { // tcp
-			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
-				IP:   address.IP(),
-				Port: int(port),
-			}, streamSettings.SocketSettings)
-			if err != nil {
-				errors.LogErrorInner(ctx, err, "failed to listen on ", address, ":", port)
-				return
-			}
+	if isH3 {
+		Conn, err := internet.ListenSystemPacket(context.Background(), listener.local, streamSettings.SocketSettings)
+		if err != nil {
+			return nil,  errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err)
 		}
+		h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil)
+		if err != nil {
+			return nil, errors.New("failed to listen QUIC(for SH3) on ", address, ":", port).Base(err)
+		}
+		errors.LogInfo(ctx, "listening QUIC(for SH3) on ", address, ":", port)
 
-		if config == nil {
-			if config := reality.ConfigFromStreamSettings(streamSettings); config != nil {
-				streamListener = goreality.NewListener(streamListener, config.GetREALITYConfig())
+		listener.h3server = &http3.Server{
+			Handler: listener,
+		}
+		go func() {
+			if err := listener.h3server.ServeListener(h3listener); err != nil {
+				errors.LogWarningInner(ctx, err, "failed to serve http3 for splithttp")
 			}
-			err = server.Serve(streamListener)
-			if err != nil {
-				errors.LogInfoInner(ctx, err, "stopping serving H2C or REALITY H2")
+		}()
+	} else {
+		var server *http.Server
+		if config == nil {
+			h2s := &http2.Server{}
+	
+			server = &http.Server{
+				Addr:              serial.Concat(address, ":", port),
+				Handler:           h2c.NewHandler(listener, h2s),
+				ReadHeaderTimeout: time.Second * 4,
 			}
 		} else {
-			err = server.ServeTLS(streamListener, "", "")
-			if err != nil {
-				errors.LogInfoInner(ctx, err, "stopping serving TLS H2")
+			server = &http.Server{
+				Addr:              serial.Concat(address, ":", port),
+				TLSConfig:         config.GetTLSConfig(tls.WithNextProto("h2")),
+				Handler:           listener,
+				ReadHeaderTimeout: time.Second * 4,
 			}
 		}
-	}()
+	
+		listener.server = server
+		go func() {
+			var streamListener net.Listener
+			var err error
+			if port == net.Port(0) { // unix
+				streamListener, err = internet.ListenSystem(ctx, &net.UnixAddr{
+					Name: address.Domain(),
+					Net:  "unix",
+				}, streamSettings.SocketSettings)
+				if err != nil {
+					errors.LogErrorInner(ctx, err, "failed to listen on ", address)
+					return
+				}
+			} else { // tcp
+				streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
+					IP:   address.IP(),
+					Port: int(port),
+				}, streamSettings.SocketSettings)
+				if err != nil {
+					errors.LogErrorInner(ctx, err, "failed to listen on ", address, ":", port)
+					return
+				}
+			}
+	
+			if config == nil {
+				if config := reality.ConfigFromStreamSettings(streamSettings); config != nil {
+					streamListener = goreality.NewListener(streamListener, config.GetREALITYConfig())
+				}
+				err = server.Serve(streamListener)
+				if err != nil {
+					errors.LogInfoInner(ctx, err, "stopping serving H2C or REALITY H2")
+				}
+			} else {
+				err = server.ServeTLS(streamListener, "", "")
+				if err != nil {
+					errors.LogInfoInner(ctx, err, "stopping serving TLS H2")
+				}
+			}
+		}()	
+	}
 
 	return listener, nil
 }

+ 7 - 1
transport/internet/splithttp/hub.go

@@ -365,7 +365,13 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
 
 // Addr implements net.Listener.Addr().
 func (ln *Listener) Addr() net.Addr {
-	return ln.listener.Addr()
+	if ln.h3listener != nil {
+		return ln.h3listener.Addr()
+	}
+	if ln.listener != nil {
+		return ln.listener.Addr()
+	}
+	return nil
 }
 
 // Close implements net.Listener.Close().