Преглед на файлове

Add support for ech retry configs

neletor преди 2 месеца
родител
ревизия
30f7ceec79

+ 18 - 8
common/tls/client.go

@@ -2,10 +2,11 @@ package tls
 
 import (
 	"context"
+	"crypto/tls"
+	"errors"
 	"net"
 	"os"
 
-	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/common/badtls"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/option"
@@ -14,7 +15,7 @@ import (
 	aTLS "github.com/sagernet/sing/common/tls"
 )
 
-func NewDialerFromOptions(ctx context.Context, router adapter.Router, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) {
+func NewDialerFromOptions(ctx context.Context, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) {
 	if !options.Enabled {
 		return dialer, nil
 	}
@@ -79,20 +80,29 @@ func (d *defaultDialer) ListenPacket(ctx context.Context, destination M.Socksadd
 }
 
 func (d *defaultDialer) DialTLSContext(ctx context.Context, destination M.Socksaddr) (Conn, error) {
-	return d.dialContext(ctx, destination)
+	return d.dialContext(ctx, destination, true)
 }
 
-func (d *defaultDialer) dialContext(ctx context.Context, destination M.Socksaddr) (Conn, error) {
+func (d *defaultDialer) dialContext(ctx context.Context, destination M.Socksaddr, echRetry bool) (Conn, error) {
 	conn, err := d.dialer.DialContext(ctx, N.NetworkTCP, destination)
 	if err != nil {
 		return nil, err
 	}
 	tlsConn, err := aTLS.ClientHandshake(ctx, conn, d.config)
-	if err != nil {
-		conn.Close()
-		return nil, err
+	if err == nil {
+		return tlsConn, nil
 	}
-	return tlsConn, nil
+	conn.Close()
+	if echRetry {
+		var echErr *tls.ECHRejectionError
+		if errors.As(err, &echErr) && len(echErr.RetryConfigList) > 0 {
+			if echConfig, isECH := d.config.(ECHCapableConfig); isECH {
+				echConfig.SetECHConfigList(echErr.RetryConfigList)
+			}
+		}
+		return d.dialContext(ctx, destination, false)
+	}
+	return nil, err
 }
 
 func (d *defaultDialer) Upstream() any {

+ 3 - 8
dns/transport/tls.go

@@ -30,7 +30,7 @@ func RegisterTLS(registry *dns.TransportRegistry) {
 type TLSTransport struct {
 	dns.TransportAdapter
 	logger      logger.ContextLogger
-	dialer      N.Dialer
+	dialer      tls.Dialer
 	serverAddr  M.Socksaddr
 	tlsConfig   tls.Config
 	access      sync.Mutex
@@ -67,7 +67,7 @@ func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer
 	return &TLSTransport{
 		TransportAdapter: adapter,
 		logger:           logger,
-		dialer:           dialer,
+		dialer:           tls.NewDialer(dialer, tlsConfig),
 		serverAddr:       serverAddr,
 		tlsConfig:        tlsConfig,
 	}
@@ -100,15 +100,10 @@ func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
 			return response, nil
 		}
 	}
-	tcpConn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
+	tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr)
 	if err != nil {
 		return nil, err
 	}
-	tlsConn, err := tls.ClientHandshake(ctx, tcpConn, t.tlsConfig)
-	if err != nil {
-		tcpConn.Close()
-		return nil, err
-	}
 	return t.exchange(message, &tlsDNSConn{Conn: tlsConn})
 }
 

+ 4 - 12
protocol/anytls/outbound.go

@@ -27,7 +27,7 @@ func RegisterOutbound(registry *outbound.Registry) {
 
 type Outbound struct {
 	outbound.Adapter
-	dialer    N.Dialer
+	dialer    tls.Dialer
 	server    M.Socksaddr
 	tlsConfig tls.Config
 	client    *anytls.Client
@@ -66,7 +66,8 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
 	if err != nil {
 		return nil, err
 	}
-	outbound.dialer = outboundDialer
+
+	outbound.dialer = tls.NewDialer(outboundDialer, tlsConfig)
 
 	client, err := anytls.NewClient(ctx, anytls.ClientConfig{
 		Password:                 options.Password,
@@ -99,16 +100,7 @@ func (d anytlsDialer) ListenPacket(ctx context.Context, destination M.Socksaddr)
 }
 
 func (h *Outbound) dialOut(ctx context.Context) (net.Conn, error) {
-	conn, err := h.dialer.DialContext(ctx, N.NetworkTCP, h.server)
-	if err != nil {
-		return nil, err
-	}
-	tlsConn, err := tls.ClientHandshake(ctx, conn, h.tlsConfig)
-	if err != nil {
-		common.Close(tlsConn, conn)
-		return nil, err
-	}
-	return tlsConn, nil
+	return h.dialer.DialTLSContext(ctx, h.server)
 }
 
 func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {

+ 1 - 1
protocol/http/outbound.go

@@ -34,7 +34,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
 	if err != nil {
 		return nil, err
 	}
-	detour, err := tls.NewDialerFromOptions(ctx, router, outboundDialer, options.Server, common.PtrValueOrDefault(options.TLS))
+	detour, err := tls.NewDialerFromOptions(ctx, outboundDialer, options.Server, common.PtrValueOrDefault(options.TLS))
 	if err != nil {
 		return nil, err
 	}

+ 4 - 3
protocol/trojan/outbound.go

@@ -34,6 +34,7 @@ type Outbound struct {
 	key             [56]byte
 	multiplexDialer *mux.Client
 	tlsConfig       tls.Config
+	tlsDialer       tls.Dialer
 	transport       adapter.V2RayClientTransport
 }
 
@@ -54,6 +55,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
 		if err != nil {
 			return nil, err
 		}
+		outbound.tlsDialer = tls.NewDialer(outboundDialer, outbound.tlsConfig)
 	}
 	if options.Transport != nil {
 		outbound.transport, err = v2ray.NewClientTransport(ctx, outbound.dialer, outbound.serverAddr, common.PtrValueOrDefault(options.Transport), outbound.tlsConfig)
@@ -121,11 +123,10 @@ func (h *trojanDialer) DialContext(ctx context.Context, network string, destinat
 	var err error
 	if h.transport != nil {
 		conn, err = h.transport.DialContext(ctx)
+	} else if h.tlsDialer != nil {
+		conn, err = h.tlsDialer.DialTLSContext(ctx, h.serverAddr)
 	} else {
 		conn, err = h.dialer.DialContext(ctx, N.NetworkTCP, h.serverAddr)
-		if err == nil && h.tlsConfig != nil {
-			conn, err = tls.ClientHandshake(ctx, conn, h.tlsConfig)
-		}
 	}
 	if err != nil {
 		common.Close(conn)

+ 6 - 6
protocol/vless/outbound.go

@@ -35,6 +35,7 @@ type Outbound struct {
 	serverAddr      M.Socksaddr
 	multiplexDialer *mux.Client
 	tlsConfig       tls.Config
+	tlsDialer       tls.Dialer
 	transport       adapter.V2RayClientTransport
 	packetAddr      bool
 	xudp            bool
@@ -56,6 +57,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
 		if err != nil {
 			return nil, err
 		}
+		outbound.tlsDialer = tls.NewDialer(outboundDialer, outbound.tlsConfig)
 	}
 	if options.Transport != nil {
 		outbound.transport, err = v2ray.NewClientTransport(ctx, outbound.dialer, outbound.serverAddr, common.PtrValueOrDefault(options.Transport), outbound.tlsConfig)
@@ -140,11 +142,10 @@ func (h *vlessDialer) DialContext(ctx context.Context, network string, destinati
 	var err error
 	if h.transport != nil {
 		conn, err = h.transport.DialContext(ctx)
+	} else if h.tlsDialer != nil {
+		conn, err = h.tlsDialer.DialTLSContext(ctx, h.serverAddr)
 	} else {
 		conn, err = h.dialer.DialContext(ctx, N.NetworkTCP, h.serverAddr)
-		if err == nil && h.tlsConfig != nil {
-			conn, err = tls.ClientHandshake(ctx, conn, h.tlsConfig)
-		}
 	}
 	if err != nil {
 		return nil, err
@@ -183,11 +184,10 @@ func (h *vlessDialer) ListenPacket(ctx context.Context, destination M.Socksaddr)
 	var err error
 	if h.transport != nil {
 		conn, err = h.transport.DialContext(ctx)
+	} else if h.tlsDialer != nil {
+		conn, err = h.tlsDialer.DialTLSContext(ctx, h.serverAddr)
 	} else {
 		conn, err = h.dialer.DialContext(ctx, N.NetworkTCP, h.serverAddr)
-		if err == nil && h.tlsConfig != nil {
-			conn, err = tls.ClientHandshake(ctx, conn, h.tlsConfig)
-		}
 	}
 	if err != nil {
 		common.Close(conn)

+ 6 - 6
protocol/vmess/outbound.go

@@ -35,6 +35,7 @@ type Outbound struct {
 	serverAddr      M.Socksaddr
 	multiplexDialer *mux.Client
 	tlsConfig       tls.Config
+	tlsDialer       tls.Dialer
 	transport       adapter.V2RayClientTransport
 	packetAddr      bool
 	xudp            bool
@@ -56,6 +57,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
 		if err != nil {
 			return nil, err
 		}
+		outbound.tlsDialer = tls.NewDialer(outboundDialer, outbound.tlsConfig)
 	}
 	if options.Transport != nil {
 		outbound.transport, err = v2ray.NewClientTransport(ctx, outbound.dialer, outbound.serverAddr, common.PtrValueOrDefault(options.Transport), outbound.tlsConfig)
@@ -154,11 +156,10 @@ func (h *vmessDialer) DialContext(ctx context.Context, network string, destinati
 	var err error
 	if h.transport != nil {
 		conn, err = h.transport.DialContext(ctx)
+	} else if h.tlsDialer != nil {
+		conn, err = h.tlsDialer.DialTLSContext(ctx, h.serverAddr)
 	} else {
 		conn, err = h.dialer.DialContext(ctx, N.NetworkTCP, h.serverAddr)
-		if err == nil && h.tlsConfig != nil {
-			conn, err = tls.ClientHandshake(ctx, conn, h.tlsConfig)
-		}
 	}
 	if err != nil {
 		common.Close(conn)
@@ -182,11 +183,10 @@ func (h *vmessDialer) ListenPacket(ctx context.Context, destination M.Socksaddr)
 	var err error
 	if h.transport != nil {
 		conn, err = h.transport.DialContext(ctx)
+	} else if h.tlsDialer != nil {
+		conn, err = h.tlsDialer.DialTLSContext(ctx, h.serverAddr)
 	} else {
 		conn, err = h.dialer.DialContext(ctx, N.NetworkTCP, h.serverAddr)
-		if err == nil && h.tlsConfig != nil {
-			conn, err = tls.ClientHandshake(ctx, conn, h.tlsConfig)
-		}
 	}
 	if err != nil {
 		return nil, err

+ 2 - 8
transport/v2raygrpclite/client.go

@@ -29,7 +29,6 @@ var defaultClientHeader = http.Header{
 
 type Client struct {
 	ctx        context.Context
-	dialer     N.Dialer
 	serverAddr M.Socksaddr
 	transport  *http2.Transport
 	options    option.V2RayGRPCOptions
@@ -46,7 +45,6 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 	}
 	client := &Client{
 		ctx:        ctx,
-		dialer:     dialer,
 		serverAddr: serverAddr,
 		options:    options,
 		transport: &http2.Transport{
@@ -62,7 +60,6 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 		},
 		host: host,
 	}
-
 	if tlsConfig == nil {
 		client.transport.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.STDConfig) (net.Conn, error) {
 			return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
@@ -71,12 +68,9 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 		if len(tlsConfig.NextProtos()) == 0 {
 			tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
 		}
+		tlsDialer := tls.NewDialer(dialer, tlsConfig)
 		client.transport.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.STDConfig) (net.Conn, error) {
-			conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
-			if err != nil {
-				return nil, err
-			}
-			return tls.ClientHandshake(ctx, conn, tlsConfig)
+			return tlsDialer.DialTLSContext(ctx, M.ParseSocksaddr(addr))
 		}
 	}
 

+ 2 - 5
transport/v2rayhttp/client.go

@@ -47,15 +47,12 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 		if len(tlsConfig.NextProtos()) == 0 {
 			tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
 		}
+		tlsDialer := tls.NewDialer(dialer, tlsConfig)
 		transport = &http2.Transport{
 			ReadIdleTimeout: time.Duration(options.IdleTimeout),
 			PingTimeout:     time.Duration(options.PingTimeout),
 			DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.STDConfig) (net.Conn, error) {
-				conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
-				if err != nil {
-					return nil, err
-				}
-				return tls.ClientHandshake(ctx, conn, tlsConfig)
+				return tlsDialer.DialTLSContext(ctx, M.ParseSocksaddr(addr))
 			},
 		}
 	}

+ 1 - 8
transport/v2rayhttpupgrade/client.go

@@ -23,7 +23,6 @@ var _ adapter.V2RayClientTransport = (*Client)(nil)
 
 type Client struct {
 	dialer     N.Dialer
-	tlsConfig  tls.Config
 	serverAddr M.Socksaddr
 	requestURL url.URL
 	headers    http.Header
@@ -35,6 +34,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 		if len(tlsConfig.NextProtos()) == 0 {
 			tlsConfig.SetNextProtos([]string{"http/1.1"})
 		}
+		dialer = tls.NewDialer(dialer, tlsConfig)
 	}
 	var host string
 	if options.Host != "" {
@@ -65,7 +65,6 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 	}
 	return &Client{
 		dialer:     dialer,
-		tlsConfig:  tlsConfig,
 		serverAddr: serverAddr,
 		requestURL: requestURL,
 		headers:    headers,
@@ -78,12 +77,6 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
 	if err != nil {
 		return nil, err
 	}
-	if c.tlsConfig != nil {
-		conn, err = tls.ClientHandshake(ctx, conn, c.tlsConfig)
-		if err != nil {
-			return nil, err
-		}
-	}
 	request := &http.Request{
 		Method: http.MethodGet,
 		URL:    &c.requestURL,

+ 1 - 8
transport/v2raywebsocket/client.go

@@ -26,7 +26,6 @@ var _ adapter.V2RayClientTransport = (*Client)(nil)
 
 type Client struct {
 	dialer              N.Dialer
-	tlsConfig           tls.Config
 	serverAddr          M.Socksaddr
 	requestURL          url.URL
 	headers             http.Header
@@ -39,6 +38,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 		if len(tlsConfig.NextProtos()) == 0 {
 			tlsConfig.SetNextProtos([]string{"http/1.1"})
 		}
+		dialer = tls.NewDialer(dialer, tlsConfig)
 	}
 	var requestURL url.URL
 	if tlsConfig == nil {
@@ -65,7 +65,6 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
 	}
 	return &Client{
 		dialer,
-		tlsConfig,
 		serverAddr,
 		requestURL,
 		headers,
@@ -79,12 +78,6 @@ func (c *Client) dialContext(ctx context.Context, requestURL *url.URL, headers h
 	if err != nil {
 		return nil, err
 	}
-	if c.tlsConfig != nil {
-		conn, err = tls.ClientHandshake(ctx, conn, c.tlsConfig)
-		if err != nil {
-			return nil, err
-		}
-	}
 	var deadlineConn net.Conn
 	if deadline.NeedAdditionalReadDeadline(conn) {
 		deadlineConn = deadline.NewConn(conn)