Просмотр исходного кода

SplitHTTP: Server supports HTTP/3 (#3554)

Co-authored-by: mmmray <[email protected]>
ll11l1lIllIl1lll 1 год назад
Родитель
Сommit
a0040f13dd
2 измененных файлов с 113 добавлено и 33 удалено
  1. 73 33
      transport/internet/splithttp/hub.go
  2. 40 0
      transport/internet/splithttp/splithttp_test.go

+ 73 - 33
transport/internet/splithttp/hub.go

@@ -11,6 +11,8 @@ import (
 	"sync"
 	"time"
 
+	"github.com/quic-go/quic-go"
+	"github.com/quic-go/quic-go/http3"
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
@@ -233,10 +235,13 @@ func (c *httpResponseBodyWriter) Close() error {
 
 type Listener struct {
 	sync.Mutex
-	server   http.Server
-	listener net.Listener
-	config   *Config
-	addConn  internet.ConnHandler
+	server     http.Server
+	h3server   *http3.Server
+	listener   net.Listener
+	h3listener *quic.EarlyListener
+	config     *Config
+	addConn    internet.ConnHandler
+	isH3       bool
 }
 
 func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
@@ -253,6 +258,17 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
 	var listener net.Listener
 	var err error
 	var localAddr = gonet.TCPAddr{}
+	handler := &requestHandler{
+		host:      shSettings.Host,
+		path:      shSettings.GetNormalizedPath(),
+		ln:        l,
+		sessionMu: &sync.Mutex{},
+		sessions:  sync.Map{},
+		localAddr: localAddr,
+	}
+	tlsConfig := getTLSConfig(streamSettings)
+	l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"
+
 
 	if port == net.Port(0) { // unix
 		listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
@@ -263,6 +279,29 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
 			return nil, errors.New("failed to listen unix domain socket(for SH) on ", address).Base(err)
 		}
 		errors.LogInfo(ctx, "listening unix domain socket(for SH) on ", address)
+	} else if l.isH3 { // quic
+		Conn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
+			IP:   address.IP(),
+			Port: int(port),
+		}, streamSettings.SocketSettings)
+		if err != nil {
+			return nil,  errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err)
+		}
+		h3listener, err := quic.ListenEarly(Conn,tlsConfig, nil)
+		if err != nil {
+			return nil, errors.New("failed to listen QUIC(for SH3) on ", address, ":", port).Base(err)
+		}
+		l.h3listener = h3listener
+		errors.LogInfo(ctx, "listening QUIC(for SH3) on ", address, ":", port)
+
+		l.h3server = &http3.Server{
+			Handler: handler,
+		}
+		go func() {
+			if err := l.h3server.ServeListener(l.h3listener); err != nil {
+				errors.LogWarningInner(ctx, err, "failed to serve http3 for splithttp")
+			}
+		}()
 	} else { // tcp
 		localAddr = gonet.TCPAddr{
 			IP:   address.IP(),
@@ -275,41 +314,29 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
 		if err != nil {
 			return nil, errors.New("failed to listen TCP(for SH) on ", address, ":", port).Base(err)
 		}
+		l.listener = listener
 		errors.LogInfo(ctx, "listening TCP(for SH) on ", address, ":", port)
-	}
 
+		// h2cHandler can handle both plaintext HTTP/1.1 and h2c
+		h2cHandler := h2c.NewHandler(handler, &http2.Server{})
+		l.server = http.Server{
+			Handler:           h2cHandler,
+			ReadHeaderTimeout: time.Second * 4,
+			MaxHeaderBytes:    8192,
+		}
+		go func() {
+			if err := l.server.Serve(l.listener); err != nil {
+				errors.LogWarningInner(ctx, err, "failed to serve http for splithttp")
+			}
+		}()
+	}
+	l.listener = listener
 	if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
 		if tlsConfig := config.GetTLSConfig(); tlsConfig != nil {
 			listener = tls.NewListener(listener, tlsConfig)
 		}
 	}
 
-	handler := &requestHandler{
-		host:      shSettings.Host,
-		path:      shSettings.GetNormalizedPath(),
-		ln:        l,
-		sessionMu: &sync.Mutex{},
-		sessions:  sync.Map{},
-		localAddr: localAddr,
-	}
-
-	// h2cHandler can handle both plaintext HTTP/1.1 and h2c
-	h2cHandler := h2c.NewHandler(handler, &http2.Server{})
-
-	l.listener = listener
-
-	l.server = http.Server{
-		Handler:           h2cHandler,
-		ReadHeaderTimeout: time.Second * 4,
-		MaxHeaderBytes:    8192,
-	}
-
-	go func() {
-		if err := l.server.Serve(l.listener); err != nil {
-			errors.LogWarningInner(ctx, err, "failed to serve http for splithttp")
-		}
-	}()
-
 	return l, err
 }
 
@@ -320,9 +347,22 @@ func (ln *Listener) Addr() net.Addr {
 
 // Close implements net.Listener.Close().
 func (ln *Listener) Close() error {
-	return ln.listener.Close()
+	if ln.h3server != nil {
+		if err := ln.h3server.Close(); err != nil {
+			return err
+		}
+	} else if ln.listener != nil {
+		return ln.listener.Close()
+	}
+	return errors.New("listener does not have an HTTP/3 server or a net.listener")
+}
+func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *tls.Config {
+	config := v2tls.ConfigFromStreamSettings(streamSettings)
+	if config == nil {
+		return &tls.Config{}
+	}
+	return config.GetTLSConfig()
 }
-
 func init() {
 	common.Must(internet.RegisterTransportListener(protocolName, ListenSH))
 }

+ 40 - 0
transport/internet/splithttp/splithttp_test.go

@@ -14,6 +14,7 @@ import (
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol/tls/cert"
 	"github.com/xtls/xray-core/testing/servers/tcp"
+	"github.com/xtls/xray-core/testing/servers/udp"
 	"github.com/xtls/xray-core/transport/internet"
 	. "github.com/xtls/xray-core/transport/internet/splithttp"
 	"github.com/xtls/xray-core/transport/internet/stat"
@@ -204,3 +205,42 @@ func Test_listenSHAndDial_H2C(t *testing.T) {
 		t.Error("Expected h2 but got:", resp.ProtoMajor)
 	}
 }
+
+func Test_listenSHAndDial_QUIC(t *testing.T) {
+	if runtime.GOARCH == "arm64" {
+		return
+	}
+
+	listenPort := udp.PickPort()
+
+	start := time.Now()
+
+	streamSettings := &internet.MemoryStreamConfig{
+		ProtocolName: "splithttp",
+		ProtocolSettings: &Config{
+			Path: "shs",
+		},
+		SecurityType: "tls",
+		SecuritySettings: &tls.Config{
+			AllowInsecure: true,
+			Certificate:   []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))},
+			NextProtocol:  []string{"h3"},
+		},
+	}
+	listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
+		go func() {
+			_ = conn.Close()
+		}()
+	})
+	common.Must(err)
+	defer listen.Close()
+
+	conn, err := Dial(context.Background(), net.UDPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
+	common.Must(err)
+	_ = conn.Close()
+
+	end := time.Now()
+	if !end.Before(start.Add(time.Second * 5)) {
+		t.Error("end: ", end, " start: ", start)
+	}
+}