Parcourir la source

Improve HTTPS DNS transport

世界 il y a 1 semaine
Parent
commit
41b30c91d9
3 fichiers modifiés avec 115 ajouts et 41 suppressions
  1. 31 9
      common/tls/client.go
  2. 4 32
      dns/transport/https.go
  3. 80 0
      dns/transport/https_transport.go

+ 31 - 9
common/tls/client.go

@@ -53,26 +53,48 @@ func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, e
 	return tlsConn, nil
 }
 
-type Dialer struct {
+type Dialer interface {
+	N.Dialer
+	DialTLSContext(ctx context.Context, destination M.Socksaddr) (Conn, error)
+}
+
+type defaultDialer struct {
 	dialer N.Dialer
 	config Config
 }
 
-func NewDialer(dialer N.Dialer, config Config) N.Dialer {
-	return &Dialer{dialer, config}
+func NewDialer(dialer N.Dialer, config Config) Dialer {
+	return &defaultDialer{dialer, config}
 }
 
-func (d *Dialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
-	if network != N.NetworkTCP {
+func (d *defaultDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
+	if N.NetworkName(network) != N.NetworkTCP {
 		return nil, os.ErrInvalid
 	}
-	conn, err := d.dialer.DialContext(ctx, network, destination)
+	return d.DialTLSContext(ctx, destination)
+}
+
+func (d *defaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
+	return nil, os.ErrInvalid
+}
+
+func (d *defaultDialer) DialTLSContext(ctx context.Context, destination M.Socksaddr) (Conn, error) {
+	return d.dialContext(ctx, destination)
+}
+
+func (d *defaultDialer) dialContext(ctx context.Context, destination M.Socksaddr) (Conn, error) {
+	conn, err := d.dialer.DialContext(ctx, N.NetworkTCP, destination)
+	if err != nil {
+		return nil, err
+	}
+	tlsConn, err := aTLS.ClientHandshake(ctx, conn, d.config)
 	if err != nil {
+		conn.Close()
 		return nil, err
 	}
-	return ClientHandshake(ctx, conn, d.config)
+	return tlsConn, nil
 }
 
-func (d *Dialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
-	return nil, os.ErrInvalid
+func (d *defaultDialer) Upstream() any {
+	return d.dialer
 }

+ 4 - 32
dns/transport/https.go

@@ -25,7 +25,6 @@ import (
 	"github.com/sagernet/sing/common/logger"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
-	aTLS "github.com/sagernet/sing/common/tls"
 	sHTTP "github.com/sagernet/sing/protocol/http"
 
 	mDNS "github.com/miekg/dns"
@@ -47,7 +46,7 @@ type HTTPSTransport struct {
 	destination      *url.URL
 	headers          http.Header
 	transportAccess  sync.Mutex
-	transport        *http.Transport
+	transport        *HTTPSTransportWrapper
 	transportResetAt time.Time
 }
 
@@ -62,11 +61,8 @@ func NewHTTPS(ctx context.Context, logger log.ContextLogger, tag string, options
 	if err != nil {
 		return nil, err
 	}
-	if common.Error(tlsConfig.Config()) == nil && !common.Contains(tlsConfig.NextProtos(), http2.NextProtoTLS) {
-		tlsConfig.SetNextProtos(append(tlsConfig.NextProtos(), http2.NextProtoTLS))
-	}
-	if !common.Contains(tlsConfig.NextProtos(), "http/1.1") {
-		tlsConfig.SetNextProtos(append(tlsConfig.NextProtos(), "http/1.1"))
+	if len(tlsConfig.NextProtos()) == 0 {
+		tlsConfig.SetNextProtos([]string{http2.NextProtoTLS, "http/1.1"})
 	}
 	headers := options.Headers.Build()
 	host := headers.Get("Host")
@@ -124,37 +120,13 @@ func NewHTTPSRaw(
 	serverAddr M.Socksaddr,
 	tlsConfig tls.Config,
 ) *HTTPSTransport {
-	var transport *http.Transport
-	if tlsConfig != nil {
-		transport = &http.Transport{
-			ForceAttemptHTTP2: true,
-			DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
-				tcpConn, hErr := dialer.DialContext(ctx, network, serverAddr)
-				if hErr != nil {
-					return nil, hErr
-				}
-				tlsConn, hErr := aTLS.ClientHandshake(ctx, tcpConn, tlsConfig)
-				if hErr != nil {
-					tcpConn.Close()
-					return nil, hErr
-				}
-				return tlsConn, nil
-			},
-		}
-	} else {
-		transport = &http.Transport{
-			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
-				return dialer.DialContext(ctx, network, serverAddr)
-			},
-		}
-	}
 	return &HTTPSTransport{
 		TransportAdapter: adapter,
 		logger:           logger,
 		dialer:           dialer,
 		destination:      destination,
 		headers:          headers,
-		transport:        transport,
+		transport:        NewHTTPSTransportWrapper(tls.NewDialer(dialer, tlsConfig), serverAddr),
 	}
 }
 

+ 80 - 0
dns/transport/https_transport.go

@@ -0,0 +1,80 @@
+package transport
+
+import (
+	"context"
+	"errors"
+	"net"
+	"net/http"
+	"sync/atomic"
+
+	"github.com/sagernet/sing-box/common/tls"
+	E "github.com/sagernet/sing/common/exceptions"
+	M "github.com/sagernet/sing/common/metadata"
+
+	"golang.org/x/net/http2"
+)
+
+var errFallback = E.New("fallback to HTTP/1.1")
+
+type HTTPSTransportWrapper struct {
+	http2Transport *http2.Transport
+	httpTransport  *http.Transport
+	fallback       *atomic.Bool
+}
+
+func NewHTTPSTransportWrapper(dialer tls.Dialer, serverAddr M.Socksaddr) *HTTPSTransportWrapper {
+	var fallback atomic.Bool
+	return &HTTPSTransportWrapper{
+		http2Transport: &http2.Transport{
+			DialTLSContext: func(ctx context.Context, _, _ string, _ *tls.STDConfig) (net.Conn, error) {
+				tlsConn, err := dialer.DialTLSContext(ctx, serverAddr)
+				if err != nil {
+					return nil, err
+				}
+				state := tlsConn.ConnectionState()
+				if state.NegotiatedProtocol == http2.NextProtoTLS {
+					return tlsConn, nil
+				}
+				tlsConn.Close()
+				fallback.Store(true)
+				return nil, errFallback
+			},
+		},
+		httpTransport: &http.Transport{
+			DialTLSContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
+				return dialer.DialTLSContext(ctx, serverAddr)
+			},
+		},
+		fallback: &fallback,
+	}
+}
+
+func (h *HTTPSTransportWrapper) RoundTrip(request *http.Request) (*http.Response, error) {
+	if h.fallback.Load() {
+		return h.httpTransport.RoundTrip(request)
+	} else {
+		response, err := h.http2Transport.RoundTrip(request)
+		if err != nil {
+			if errors.Is(err, errFallback) {
+				return h.httpTransport.RoundTrip(request)
+			}
+			return nil, err
+		}
+		return response, nil
+	}
+}
+
+func (h *HTTPSTransportWrapper) CloseIdleConnections() {
+	h.http2Transport.CloseIdleConnections()
+	h.httpTransport.CloseIdleConnections()
+}
+
+func (h *HTTPSTransportWrapper) Clone() *HTTPSTransportWrapper {
+	return &HTTPSTransportWrapper{
+		httpTransport: h.httpTransport,
+		http2Transport: &http2.Transport{
+			DialTLSContext: h.http2Transport.DialTLSContext,
+		},
+		fallback: h.fallback,
+	}
+}