Browse Source

Add ECH support for uTLS

Restia-Ashbell 4 months ago
parent
commit
c0588c30d7

+ 16 - 16
common/tls/ech.go

@@ -25,7 +25,7 @@ import (
 	"golang.org/x/crypto/cryptobyte"
 )
 
-func parseECHClientConfig(ctx context.Context, stdConfig *STDClientConfig, options option.OutboundTLSOptions) (Config, error) {
+func parseECHClientConfig(ctx context.Context, clientConfig ECHCapableConfig, options option.OutboundTLSOptions) (Config, error) {
 	var echConfig []byte
 	if len(options.ECH.Config) > 0 {
 		echConfig = []byte(strings.Join(options.ECH.Config, "\n"))
@@ -45,12 +45,12 @@ func parseECHClientConfig(ctx context.Context, stdConfig *STDClientConfig, optio
 		if block == nil || block.Type != "ECH CONFIGS" || len(rest) > 0 {
 			return nil, E.New("invalid ECH configs pem")
 		}
-		stdConfig.config.EncryptedClientHelloConfigList = block.Bytes
-		return stdConfig, nil
+		clientConfig.SetECHConfigList(block.Bytes)
+		return clientConfig, nil
 	} else {
-		return &STDECHClientConfig{
-			STDClientConfig: stdConfig,
-			dnsRouter:       service.FromContext[adapter.DNSRouter](ctx),
+		return &ECHClientConfig{
+			ECHCapableConfig: clientConfig,
+			dnsRouter:        service.FromContext[adapter.DNSRouter](ctx),
 		}, nil
 	}
 }
@@ -102,15 +102,15 @@ func reloadECHKeys(echKeyPath string, tlsConfig *tls.Config) error {
 	return nil
 }
 
-type STDECHClientConfig struct {
-	*STDClientConfig
+type ECHClientConfig struct {
+	ECHCapableConfig
 	access     sync.Mutex
 	dnsRouter  adapter.DNSRouter
 	lastTTL    time.Duration
 	lastUpdate time.Time
 }
 
-func (s *STDECHClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) {
+func (s *ECHClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) {
 	tlsConn, err := s.fetchAndHandshake(ctx, conn)
 	if err != nil {
 		return nil, err
@@ -122,17 +122,17 @@ func (s *STDECHClientConfig) ClientHandshake(ctx context.Context, conn net.Conn)
 	return tlsConn, nil
 }
 
-func (s *STDECHClientConfig) fetchAndHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) {
+func (s *ECHClientConfig) fetchAndHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) {
 	s.access.Lock()
 	defer s.access.Unlock()
-	if len(s.config.EncryptedClientHelloConfigList) == 0 || s.lastTTL == 0 || time.Now().Sub(s.lastUpdate) > s.lastTTL {
+	if len(s.ECHConfigList()) == 0 || s.lastTTL == 0 || time.Now().Sub(s.lastUpdate) > s.lastTTL {
 		message := &mDNS.Msg{
 			MsgHdr: mDNS.MsgHdr{
 				RecursionDesired: true,
 			},
 			Question: []mDNS.Question{
 				{
-					Name:   mDNS.Fqdn(s.config.ServerName),
+					Name:   mDNS.Fqdn(s.ServerName()),
 					Qtype:  mDNS.TypeHTTPS,
 					Qclass: mDNS.ClassINET,
 				},
@@ -157,21 +157,21 @@ func (s *STDECHClientConfig) fetchAndHandshake(ctx context.Context, conn net.Con
 						}
 						s.lastTTL = time.Duration(rr.Header().Ttl) * time.Second
 						s.lastUpdate = time.Now()
-						s.config.EncryptedClientHelloConfigList = echConfigList
+						s.SetECHConfigList(echConfigList)
 						break match
 					}
 				}
 			}
 		}
-		if len(s.config.EncryptedClientHelloConfigList) == 0 {
+		if len(s.ECHConfigList()) == 0 {
 			return nil, E.New("no ECH config found in DNS records")
 		}
 	}
 	return s.Client(conn)
 }
 
-func (s *STDECHClientConfig) Clone() Config {
-	return &STDECHClientConfig{STDClientConfig: s.STDClientConfig.Clone().(*STDClientConfig), dnsRouter: s.dnsRouter, lastUpdate: s.lastUpdate}
+func (s *ECHClientConfig) Clone() Config {
+	return &ECHClientConfig{ECHCapableConfig: s.ECHCapableConfig.Clone().(ECHCapableConfig), dnsRouter: s.dnsRouter, lastUpdate: s.lastUpdate}
 }
 
 func UnmarshalECHKeys(raw []byte) ([]tls.EncryptedClientHelloKey, error) {

+ 6 - 0
common/tls/ech_keygen.go → common/tls/ech_shared.go

@@ -11,6 +11,12 @@ import (
 	"github.com/cloudflare/circl/kem"
 )
 
+type ECHCapableConfig interface {
+	Config
+	ECHConfigList() []byte
+	SetECHConfigList([]byte)
+}
+
 func ECHKeygenDefault(serverName string) (configPem string, keyPem string, err error) {
 	cipherSuites := []echCipherSuite{
 		{

+ 1 - 1
common/tls/ech_stub.go

@@ -10,7 +10,7 @@ import (
 	E "github.com/sagernet/sing/common/exceptions"
 )
 
-func parseECHClientConfig(ctx context.Context, options option.OutboundTLSOptions, tlsConfig *tls.Config) (Config, error) {
+func parseECHClientConfig(ctx context.Context, clientConfig ECHCapableConfig, options option.OutboundTLSOptions) (Config, error) {
 	return nil, E.New("ECH requires go1.24, please recompile your binary.")
 }
 

+ 1 - 1
common/tls/reality_client.go

@@ -74,7 +74,7 @@ func NewRealityClient(ctx context.Context, serverAddress string, options option.
 	if decodedLen > 8 {
 		return nil, E.New("invalid short_id")
 	}
-	return &RealityClientConfig{ctx, uClient, publicKey, shortID}, nil
+	return &RealityClientConfig{ctx, uClient.(*UTLSClientConfig), publicKey, shortID}, nil
 }
 
 func (e *RealityClientConfig) ServerName() string {

+ 25 - 19
common/tls/std_client.go

@@ -24,35 +24,43 @@ type STDClientConfig struct {
 	recordFragment        bool
 }
 
-func (s *STDClientConfig) ServerName() string {
-	return s.config.ServerName
+func (c *STDClientConfig) ServerName() string {
+	return c.config.ServerName
 }
 
-func (s *STDClientConfig) SetServerName(serverName string) {
-	s.config.ServerName = serverName
+func (c *STDClientConfig) SetServerName(serverName string) {
+	c.config.ServerName = serverName
 }
 
-func (s *STDClientConfig) NextProtos() []string {
-	return s.config.NextProtos
+func (c *STDClientConfig) NextProtos() []string {
+	return c.config.NextProtos
 }
 
-func (s *STDClientConfig) SetNextProtos(nextProto []string) {
-	s.config.NextProtos = nextProto
+func (c *STDClientConfig) SetNextProtos(nextProto []string) {
+	c.config.NextProtos = nextProto
 }
 
-func (s *STDClientConfig) Config() (*STDConfig, error) {
-	return s.config, nil
+func (c *STDClientConfig) Config() (*STDConfig, error) {
+	return c.config, nil
 }
 
-func (s *STDClientConfig) Client(conn net.Conn) (Conn, error) {
-	if s.recordFragment {
-		conn = tf.NewConn(conn, s.ctx, s.fragment, s.recordFragment, s.fragmentFallbackDelay)
+func (c *STDClientConfig) Client(conn net.Conn) (Conn, error) {
+	if c.recordFragment {
+		conn = tf.NewConn(conn, c.ctx, c.fragment, c.recordFragment, c.fragmentFallbackDelay)
 	}
-	return tls.Client(conn, s.config), nil
+	return tls.Client(conn, c.config), nil
 }
 
-func (s *STDClientConfig) Clone() Config {
-	return &STDClientConfig{s.ctx, s.config.Clone(), s.fragment, s.fragmentFallbackDelay, s.recordFragment}
+func (c *STDClientConfig) Clone() Config {
+	return &STDClientConfig{c.ctx, c.config.Clone(), c.fragment, c.fragmentFallbackDelay, c.recordFragment}
+}
+
+func (c *STDClientConfig) ECHConfigList() []byte {
+	return c.config.EncryptedClientHelloConfigList
+}
+
+func (c *STDClientConfig) SetECHConfigList(EncryptedClientHelloConfigList []byte) {
+	c.config.EncryptedClientHelloConfigList = EncryptedClientHelloConfigList
 }
 
 func NewSTDClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
@@ -69,9 +77,7 @@ func NewSTDClient(ctx context.Context, serverAddress string, options option.Outb
 	var tlsConfig tls.Config
 	tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
 	tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
-	if options.DisableSNI {
-		tlsConfig.ServerName = "127.0.0.1"
-	} else {
+	if !options.DisableSNI {
 		tlsConfig.ServerName = serverName
 	}
 	if options.Insecure {

+ 38 - 29
common/tls/utls_client.go

@@ -8,7 +8,6 @@ import (
 	"crypto/x509"
 	"math/rand"
 	"net"
-	"net/netip"
 	"os"
 	"strings"
 	"time"
@@ -32,46 +31,54 @@ type UTLSClientConfig struct {
 	recordFragment        bool
 }
 
-func (e *UTLSClientConfig) ServerName() string {
-	return e.config.ServerName
+func (c *UTLSClientConfig) ServerName() string {
+	return c.config.ServerName
 }
 
-func (e *UTLSClientConfig) SetServerName(serverName string) {
-	e.config.ServerName = serverName
+func (c *UTLSClientConfig) SetServerName(serverName string) {
+	c.config.ServerName = serverName
 }
 
-func (e *UTLSClientConfig) NextProtos() []string {
-	return e.config.NextProtos
+func (c *UTLSClientConfig) NextProtos() []string {
+	return c.config.NextProtos
 }
 
-func (e *UTLSClientConfig) SetNextProtos(nextProto []string) {
+func (c *UTLSClientConfig) SetNextProtos(nextProto []string) {
 	if len(nextProto) == 1 && nextProto[0] == http2.NextProtoTLS {
 		nextProto = append(nextProto, "http/1.1")
 	}
-	e.config.NextProtos = nextProto
+	c.config.NextProtos = nextProto
 }
 
-func (e *UTLSClientConfig) Config() (*STDConfig, error) {
+func (c *UTLSClientConfig) Config() (*STDConfig, error) {
 	return nil, E.New("unsupported usage for uTLS")
 }
 
-func (e *UTLSClientConfig) Client(conn net.Conn) (Conn, error) {
-	if e.recordFragment {
-		conn = tf.NewConn(conn, e.ctx, e.fragment, e.recordFragment, e.fragmentFallbackDelay)
+func (c *UTLSClientConfig) Client(conn net.Conn) (Conn, error) {
+	if c.recordFragment {
+		conn = tf.NewConn(conn, c.ctx, c.fragment, c.recordFragment, c.fragmentFallbackDelay)
 	}
-	return &utlsALPNWrapper{utlsConnWrapper{utls.UClient(conn, e.config.Clone(), e.id)}, e.config.NextProtos}, nil
+	return &utlsALPNWrapper{utlsConnWrapper{utls.UClient(conn, c.config.Clone(), c.id)}, c.config.NextProtos}, nil
 }
 
-func (e *UTLSClientConfig) SetSessionIDGenerator(generator func(clientHello []byte, sessionID []byte) error) {
-	e.config.SessionIDGenerator = generator
+func (c *UTLSClientConfig) SetSessionIDGenerator(generator func(clientHello []byte, sessionID []byte) error) {
+	c.config.SessionIDGenerator = generator
 }
 
-func (e *UTLSClientConfig) Clone() Config {
+func (c *UTLSClientConfig) Clone() Config {
 	return &UTLSClientConfig{
-		e.ctx, e.config.Clone(), e.id, e.fragment, e.fragmentFallbackDelay, e.recordFragment,
+		c.ctx, c.config.Clone(), c.id, c.fragment, c.fragmentFallbackDelay, c.recordFragment,
 	}
 }
 
+func (c *UTLSClientConfig) ECHConfigList() []byte {
+	return c.config.EncryptedClientHelloConfigList
+}
+
+func (c *UTLSClientConfig) SetECHConfigList(EncryptedClientHelloConfigList []byte) {
+	c.config.EncryptedClientHelloConfigList = EncryptedClientHelloConfigList
+}
+
 type utlsConnWrapper struct {
 	*utls.UConn
 }
@@ -124,14 +131,12 @@ func (c *utlsALPNWrapper) HandshakeContext(ctx context.Context) error {
 	return c.UConn.HandshakeContext(ctx)
 }
 
-func NewUTLSClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (*UTLSClientConfig, error) {
+func NewUTLSClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
 	var serverName string
 	if options.ServerName != "" {
 		serverName = options.ServerName
 	} else if serverAddress != "" {
-		if _, err := netip.ParseAddr(serverName); err != nil {
-			serverName = serverAddress
-		}
+		serverName = serverAddress
 	}
 	if serverName == "" && !options.Insecure {
 		return nil, E.New("missing server_name or insecure=true")
@@ -140,11 +145,7 @@ func NewUTLSClient(ctx context.Context, serverAddress string, options option.Out
 	var tlsConfig utls.Config
 	tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
 	tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
-	if options.DisableSNI {
-		tlsConfig.ServerName = "127.0.0.1"
-	} else {
-		tlsConfig.ServerName = serverName
-	}
+	tlsConfig.ServerName = serverName
 	if options.Insecure {
 		tlsConfig.InsecureSkipVerify = options.Insecure
 	} else if options.DisableSNI {
@@ -200,7 +201,15 @@ func NewUTLSClient(ctx context.Context, serverAddress string, options option.Out
 	if err != nil {
 		return nil, err
 	}
-	return &UTLSClientConfig{ctx, &tlsConfig, id, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment}, nil
+	uConfig := &UTLSClientConfig{ctx, &tlsConfig, id, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment}
+	if options.ECH != nil && options.ECH.Enabled {
+		if options.Reality != nil && options.Reality.Enabled {
+			return nil, E.New("Reality is conflict with ECH")
+		}
+		return parseECHClientConfig(ctx, uConfig, options)
+	} else {
+		return uConfig, nil
+	}
 }
 
 var (
@@ -228,7 +237,7 @@ func init() {
 
 func uTLSClientHelloID(name string) (utls.ClientHelloID, error) {
 	switch name {
-	case "chrome_psk", "chrome_psk_shuffle", "chrome_padding_psk_shuffle", "chrome_pq":
+	case "chrome_psk", "chrome_psk_shuffle", "chrome_padding_psk_shuffle", "chrome_pq", "chrome_pq_psk":
 		fallthrough
 	case "chrome", "":
 		return utls.HelloChrome_Auto, nil