Browse Source

TLS ECH client: `echForceQuery` "full" / "half" / "none" (default) (#4973)

https://github.com/XTLS/Xray-core/pull/4971#issuecomment-3148113203
风扇滑翔翼 3 months ago
parent
commit
7cbf5b004c

+ 7 - 1
infra/conf/transport_internet.go

@@ -414,7 +414,7 @@ type TLSConfig struct {
 	VerifyPeerCertInNames                []string         `json:"verifyPeerCertInNames"`
 	ECHServerKeys                        string           `json:"echServerKeys"`
 	ECHConfigList                        string           `json:"echConfigList"`
-	ECHForceQuery                        bool             `json:"echForceQuery"`
+	ECHForceQuery                        string           `json:"echForceQuery"`
 	ECHSocketSettings                    *SocketConfig    `json:"echSockopt"`
 }
 
@@ -494,6 +494,12 @@ func (c *TLSConfig) Build() (proto.Message, error) {
 		}
 		config.EchServerKeys = EchPrivateKey
 	}
+	switch c.ECHForceQuery {
+	case "none", "half", "full", "":
+		config.EchForceQuery = c.ECHForceQuery
+	default:
+		return nil, errors.New(`invalid "echForceQuery": `, c.ECHForceQuery)
+	}
 	config.EchForceQuery = c.ECHForceQuery
 	config.EchConfigList = c.ECHConfigList
 	if c.ECHSocketSettings != nil {

+ 1 - 2
transport/internet/tls/config.go

@@ -8,7 +8,6 @@ import (
 	"crypto/tls"
 	"crypto/x509"
 	"encoding/base64"
-	"github.com/xtls/xray-core/features/dns"
 	"os"
 	"slices"
 	"strings"
@@ -451,7 +450,7 @@ func (c *Config) GetTLSConfig(opts ...Option) *tls.Config {
 	if len(c.EchConfigList) > 0 || len(c.EchServerKeys) > 0 {
 		err := ApplyECH(c, config)
 		if err != nil {
-			if c.EchForceQuery || errors.Cause(err) != dns.ErrEmptyResponse {
+			if c.EchForceQuery == "full" {
 				errors.LogError(context.Background(), err)
 			} else {
 				errors.LogInfo(context.Background(), err)

+ 4 - 4
transport/internet/tls/config.pb.go

@@ -220,7 +220,7 @@ type Config struct {
 	VerifyPeerCertInNames []string               `protobuf:"bytes,17,rep,name=verify_peer_cert_in_names,json=verifyPeerCertInNames,proto3" json:"verify_peer_cert_in_names,omitempty"`
 	EchServerKeys         []byte                 `protobuf:"bytes,18,opt,name=ech_server_keys,json=echServerKeys,proto3" json:"ech_server_keys,omitempty"`
 	EchConfigList         string                 `protobuf:"bytes,19,opt,name=ech_config_list,json=echConfigList,proto3" json:"ech_config_list,omitempty"`
-	EchForceQuery         bool                   `protobuf:"varint,20,opt,name=ech_force_query,json=echForceQuery,proto3" json:"ech_force_query,omitempty"`
+	EchForceQuery         string                 `protobuf:"bytes,20,opt,name=ech_force_query,json=echForceQuery,proto3" json:"ech_force_query,omitempty"`
 	EchSocketSettings     *internet.SocketConfig `protobuf:"bytes,21,opt,name=ech_socket_settings,json=echSocketSettings,proto3" json:"ech_socket_settings,omitempty"`
 }
 
@@ -380,11 +380,11 @@ func (x *Config) GetEchConfigList() string {
 	return ""
 }
 
-func (x *Config) GetEchForceQuery() bool {
+func (x *Config) GetEchForceQuery() string {
 	if x != nil {
 		return x.EchForceQuery
 	}
-	return false
+	return ""
 }
 
 func (x *Config) GetEchSocketSettings() *internet.SocketConfig {
@@ -483,7 +483,7 @@ var file_transport_internet_tls_config_proto_rawDesc = []byte{
 	0x0a, 0x0f, 0x65, 0x63, 0x68, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x5f, 0x6c, 0x69, 0x73,
 	0x74, 0x18, 0x13, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x65, 0x63, 0x68, 0x43, 0x6f, 0x6e, 0x66,
 	0x69, 0x67, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x0f, 0x65, 0x63, 0x68, 0x5f, 0x66, 0x6f,
-	0x72, 0x63, 0x65, 0x5f, 0x71, 0x75, 0x65, 0x72, 0x79, 0x18, 0x14, 0x20, 0x01, 0x28, 0x08, 0x52,
+	0x72, 0x63, 0x65, 0x5f, 0x71, 0x75, 0x65, 0x72, 0x79, 0x18, 0x14, 0x20, 0x01, 0x28, 0x09, 0x52,
 	0x0d, 0x65, 0x63, 0x68, 0x46, 0x6f, 0x72, 0x63, 0x65, 0x51, 0x75, 0x65, 0x72, 0x79, 0x12, 0x55,
 	0x0a, 0x13, 0x65, 0x63, 0x68, 0x5f, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x5f, 0x73, 0x65, 0x74,
 	0x74, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x15, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x78, 0x72,

+ 1 - 1
transport/internet/tls/config.proto

@@ -98,7 +98,7 @@ message Config {
 
   string ech_config_list = 19;
 
-  bool ech_force_query = 20;
+  string ech_force_query = 20;
 
   SocketConfig ech_socket_settings = 21;
 }

+ 36 - 18
transport/internet/tls/ech.go

@@ -9,10 +9,6 @@ import (
 	"encoding/base64"
 	"encoding/binary"
 	"fmt"
-	utls "github.com/refraction-networking/utls"
-	"github.com/xtls/xray-core/common/crypto"
-	dns2 "github.com/xtls/xray-core/features/dns"
-	"golang.org/x/net/http2"
 	"io"
 	"net/http"
 	"net/url"
@@ -21,6 +17,11 @@ import (
 	"sync/atomic"
 	"time"
 
+	utls "github.com/refraction-networking/utls"
+	"github.com/xtls/xray-core/common/crypto"
+	dns2 "github.com/xtls/xray-core/features/dns"
+	"golang.org/x/net/http2"
+
 	"github.com/miekg/dns"
 	"github.com/xtls/reality"
 	"github.com/xtls/reality/hpke"
@@ -52,10 +53,18 @@ func ApplyECH(c *Config, config *tls.Config) error {
 
 	// for client
 	if len(c.EchConfigList) != 0 {
+		ECHForceQuery := c.EchForceQuery
+		switch ECHForceQuery {
+		case "none", "half", "full":
+		case "":
+			ECHForceQuery = "none" // default to none
+		default:
+			panic("Invalid ECHForceQuery: " + c.EchForceQuery)
+		}
 		defer func() {
 			// if failed to get ECHConfig, use an invalid one to make connection fail
-			if err != nil {
-				if c.EchForceQuery {
+			if err != nil || len(ECHConfig) == 0 {
+				if ECHForceQuery == "full" {
 					ECHConfig = []byte{1, 1, 4, 5, 1, 4}
 				}
 			}
@@ -106,32 +115,40 @@ type echConfigRecord struct {
 }
 
 var (
-	// key value must be like this: "example.com|udp://1.1.1.1"
+	// The keys for both maps must be generated by ECHCacheKey().
 	GlobalECHConfigCache = utils.NewTypedSyncMap[string, *ECHConfigCache]()
 	clientForECHDOH      = utils.NewTypedSyncMap[string, *http.Client]()
 )
 
+// sockopt can be nil if not specified.
+// if for clientForECHDOH, domain can be empty.
+func ECHCacheKey(server, domain string, sockopt *internet.SocketConfig) string {
+	return server + "|" + domain + "|" + fmt.Sprintf("%p", sockopt)
+}
+
 // Update updates the ECH config for given domain and server.
 // this method is concurrent safe, only one update request will be sent, others get the cache.
 // if isLockedUpdate is true, it will not try to acquire the lock.
-func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool, forceQuery bool, sockopt *internet.SocketConfig) ([]byte, error) {
+func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool, forceQuery string, sockopt *internet.SocketConfig) ([]byte, error) {
 	if !isLockedUpdate {
 		c.UpdateLock.Lock()
 		defer c.UpdateLock.Unlock()
 	}
 	// Double check cache after acquiring lock
 	configRecord := c.configRecord.Load()
-	if configRecord.expire.After(time.Now()) {
+	if configRecord.expire.After(time.Now()) && configRecord.err == nil {
 		errors.LogDebug(context.Background(), "Cache hit for domain after double check: ", domain)
 		return configRecord.config, configRecord.err
 	}
 	// Query ECH config from DNS server
 	errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server)
 	echConfig, ttl, err := dnsQuery(server, domain, sockopt)
-	if err != nil {
-		if forceQuery || ttl == 0 {
-			return nil, err
-		}
+	// if in "full", directly return
+	if err != nil && forceQuery == "full" {
+		return nil, err
+	}
+	if ttl == 0 {
+		ttl = dns2.DefaultTTL
 	}
 	configRecord = &echConfigRecord{
 		config: echConfig,
@@ -144,8 +161,8 @@ func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate boo
 
 // QueryRecord returns the ECH config for given domain.
 // If the record is not in cache or expired, it will query the DNS server and update the cache.
-func QueryRecord(domain string, server string, forceQuery bool, sockopt *internet.SocketConfig) ([]byte, error) {
-	GlobalECHConfigCacheKey := domain + "|" + server + "|" + fmt.Sprintf("%p", sockopt)
+func QueryRecord(domain string, server string, forceQuery string, sockopt *internet.SocketConfig) ([]byte, error) {
+	GlobalECHConfigCacheKey := ECHCacheKey(server, domain, sockopt)
 	echConfigCache, ok := GlobalECHConfigCache.Load(GlobalECHConfigCacheKey)
 	if !ok {
 		echConfigCache = &ECHConfigCache{}
@@ -153,7 +170,7 @@ func QueryRecord(domain string, server string, forceQuery bool, sockopt *interne
 		echConfigCache, _ = GlobalECHConfigCache.LoadOrStore(GlobalECHConfigCacheKey, echConfigCache)
 	}
 	configRecord := echConfigCache.configRecord.Load()
-	if configRecord.expire.After(time.Now()) {
+	if configRecord.expire.After(time.Now()) && (configRecord.err == nil || forceQuery == "none") {
 		errors.LogDebug(context.Background(), "Cache hit for domain: ", domain)
 		return configRecord.config, configRecord.err
 	}
@@ -196,7 +213,7 @@ func dnsQuery(server string, domain string, sockopt *internet.SocketConfig) ([]b
 			return nil, 0, err
 		}
 		var client *http.Client
-		serverKey := server + "|" + fmt.Sprintf("%p", sockopt)
+		serverKey := ECHCacheKey(server, "", sockopt)
 		if client, _ = clientForECHDOH.Load(serverKey); client == nil {
 			// All traffic sent by core should via xray's internet.DialSystem
 			// This involves the behavior of some Android VPN GUI clients
@@ -307,7 +324,8 @@ func dnsQuery(server string, domain string, sockopt *internet.SocketConfig) ([]b
 			}
 		}
 	}
-	return nil, dns2.DefaultTTL, dns2.ErrEmptyResponse
+	// empty is valid, means no ECH config found
+	return nil, dns2.DefaultTTL, nil
 }
 
 // reference github.com/OmarTariq612/goech

+ 5 - 16
transport/internet/tls/ech_test.go

@@ -1,7 +1,6 @@
 package tls
 
 import (
-	"fmt"
 	"io"
 	"net/http"
 	"strings"
@@ -41,7 +40,7 @@ func TestECHDial(t *testing.T) {
 	}
 	wg.Wait()
 	// check cache
-	echConfigCache, ok := GlobalECHConfigCache.Load("encryptedsni.com|udp://1.1.1.1" + "|" + fmt.Sprintf("%p", config.EchSocketSettings))
+	echConfigCache, ok := GlobalECHConfigCache.Load(ECHCacheKey("udp://1.1.1.1", "encryptedsni.com", nil))
 	if !ok {
 		t.Error("ECH config cache not found")
 
@@ -60,22 +59,12 @@ func TestECHDial(t *testing.T) {
 func TestECHDialFail(t *testing.T) {
 	config := &Config{
 		ServerName:    "cloudflare.com",
-		EchConfigList: "udp://1.1.1.1",
+		EchConfigList: "udp://127.0.0.1",
+		EchForceQuery: "half",
 	}
-	TLSConfig := config.GetTLSConfig()
-	TLSConfig.NextProtos = []string{"http/1.1"}
-	client := &http.Client{
-		Transport: &http.Transport{
-			TLSClientConfig: TLSConfig,
-		},
-	}
-	resp, err := client.Get("https://cloudflare.com/cdn-cgi/trace")
-	common.Must(err)
-	defer resp.Body.Close()
-	_, err = io.ReadAll(resp.Body)
-	common.Must(err)
+	config.GetTLSConfig()
 	// check cache
-	echConfigCache, ok := GlobalECHConfigCache.Load("cloudflare.com|udp://1.1.1.1" + "|" + fmt.Sprintf("%p", config.EchSocketSettings))
+	echConfigCache, ok := GlobalECHConfigCache.Load(ECHCacheKey("udp://127.0.0.1", "cloudflare.com", nil))
 	if !ok {
 		t.Error("ECH config cache not found")
 	}