ech.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. package tls
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/ecdh"
  6. "crypto/rand"
  7. "crypto/tls"
  8. "encoding/base64"
  9. "encoding/binary"
  10. "fmt"
  11. "io"
  12. "net/http"
  13. "net/url"
  14. "strings"
  15. "sync"
  16. "sync/atomic"
  17. "time"
  18. utls "github.com/refraction-networking/utls"
  19. "github.com/xtls/xray-core/common/crypto"
  20. dns2 "github.com/xtls/xray-core/features/dns"
  21. "golang.org/x/net/http2"
  22. "github.com/miekg/dns"
  23. "github.com/xtls/reality"
  24. "github.com/xtls/reality/hpke"
  25. "github.com/xtls/xray-core/common/errors"
  26. "github.com/xtls/xray-core/common/net"
  27. "github.com/xtls/xray-core/common/utils"
  28. "github.com/xtls/xray-core/transport/internet"
  29. "golang.org/x/crypto/cryptobyte"
  30. )
  31. func ApplyECH(c *Config, config *tls.Config) error {
  32. var ECHConfig []byte
  33. var err error
  34. var nameToQuery string
  35. if net.ParseAddress(config.ServerName).Family().IsDomain() {
  36. nameToQuery = config.ServerName
  37. }
  38. var DNSServer string
  39. // for server
  40. if len(c.EchServerKeys) != 0 {
  41. KeySets, err := ConvertToGoECHKeys(c.EchServerKeys)
  42. if err != nil {
  43. return errors.New("Failed to unmarshal ECHKeySetList: ", err)
  44. }
  45. config.EncryptedClientHelloKeys = KeySets
  46. }
  47. // for client
  48. if len(c.EchConfigList) != 0 {
  49. ECHForceQuery := c.EchForceQuery
  50. switch ECHForceQuery {
  51. case "none", "half", "full":
  52. case "":
  53. ECHForceQuery = "none" // default to none
  54. default:
  55. panic("Invalid ECHForceQuery: " + c.EchForceQuery)
  56. }
  57. defer func() {
  58. // if failed to get ECHConfig, use an invalid one to make connection fail
  59. if err != nil || len(ECHConfig) == 0 {
  60. if ECHForceQuery == "full" {
  61. ECHConfig = []byte{1, 1, 4, 5, 1, 4}
  62. }
  63. }
  64. config.EncryptedClientHelloConfigList = ECHConfig
  65. }()
  66. // direct base64 config
  67. if strings.Contains(c.EchConfigList, "://") {
  68. // query config from dns
  69. parts := strings.Split(c.EchConfigList, "+")
  70. if len(parts) == 2 {
  71. // parse ECH DNS server in format of "example.com+https://1.1.1.1/dns-query"
  72. nameToQuery = parts[0]
  73. DNSServer = parts[1]
  74. } else if len(parts) == 1 {
  75. // normal format
  76. DNSServer = parts[0]
  77. } else {
  78. return errors.New("Invalid ECH DNS server format: ", c.EchConfigList)
  79. }
  80. if nameToQuery == "" {
  81. return errors.New("Using DNS for ECH Config needs serverName or use Server format example.com+https://1.1.1.1/dns-query")
  82. }
  83. ECHConfig, err = QueryRecord(nameToQuery, DNSServer, c.EchForceQuery, c.EchSocketSettings)
  84. if err != nil {
  85. return errors.New("Failed to query ECH DNS record for domain: ", nameToQuery, " at server: ", DNSServer).Base(err)
  86. }
  87. } else {
  88. ECHConfig, err = base64.StdEncoding.DecodeString(c.EchConfigList)
  89. if err != nil {
  90. return errors.New("Failed to unmarshal ECHConfigList: ", err)
  91. }
  92. }
  93. }
  94. return nil
  95. }
  // ECHConfigCache holds the latest ECH config lookup result for one
  // (server, domain, sockopt) combination.
  //
  // Readers load configRecord atomically and never take a lock. UpdateLock does
  // not protect those reads; it only serializes refreshes so that a single DNS
  // query is in flight at a time while other callers reuse the cached record.
  type ECHConfigCache struct {
  	configRecord atomic.Pointer[echConfigRecord]
  	UpdateLock   sync.Mutex
  }
  
  // echConfigRecord is one lookup result; a new record is stored on every
  // refresh rather than mutating an existing one.
  type echConfigRecord struct {
  	config []byte    // ECH config list bytes (may be nil)
  	expire time.Time // freshness deadline; the zero value marks "never queried"
  	err    error     // lookup error, cached alongside the result
  }
  106. var (
  107. // The keys for both maps must be generated by ECHCacheKey().
  108. GlobalECHConfigCache = utils.NewTypedSyncMap[string, *ECHConfigCache]()
  109. clientForECHDOH = utils.NewTypedSyncMap[string, *http.Client]()
  110. )
  111. // sockopt can be nil if not specified.
  112. // if for clientForECHDOH, domain can be empty.
  113. func ECHCacheKey(server, domain string, sockopt *internet.SocketConfig) string {
  114. return server + "|" + domain + "|" + fmt.Sprintf("%p", sockopt)
  115. }
  116. // Update updates the ECH config for given domain and server.
  117. // this method is concurrent safe, only one update request will be sent, others get the cache.
  118. // if isLockedUpdate is true, it will not try to acquire the lock.
  119. func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool, forceQuery string, sockopt *internet.SocketConfig) ([]byte, error) {
  120. if !isLockedUpdate {
  121. c.UpdateLock.Lock()
  122. defer c.UpdateLock.Unlock()
  123. }
  124. // Double check cache after acquiring lock
  125. configRecord := c.configRecord.Load()
  126. if configRecord.expire.After(time.Now()) && configRecord.err == nil {
  127. errors.LogDebug(context.Background(), "Cache hit for domain after double check: ", domain)
  128. return configRecord.config, configRecord.err
  129. }
  130. // Query ECH config from DNS server
  131. errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server)
  132. echConfig, ttl, err := dnsQuery(server, domain, sockopt)
  133. // if in "full", directly return
  134. if err != nil && forceQuery == "full" {
  135. return nil, err
  136. }
  137. if ttl == 0 {
  138. ttl = dns2.DefaultTTL
  139. }
  140. configRecord = &echConfigRecord{
  141. config: echConfig,
  142. expire: time.Now().Add(time.Duration(ttl) * time.Second),
  143. err: err,
  144. }
  145. c.configRecord.Store(configRecord)
  146. return configRecord.config, configRecord.err
  147. }
  148. // QueryRecord returns the ECH config for given domain.
  149. // If the record is not in cache or expired, it will query the DNS server and update the cache.
  150. func QueryRecord(domain string, server string, forceQuery string, sockopt *internet.SocketConfig) ([]byte, error) {
  151. GlobalECHConfigCacheKey := ECHCacheKey(server, domain, sockopt)
  152. echConfigCache, ok := GlobalECHConfigCache.Load(GlobalECHConfigCacheKey)
  153. if !ok {
  154. echConfigCache = &ECHConfigCache{}
  155. echConfigCache.configRecord.Store(&echConfigRecord{})
  156. echConfigCache, _ = GlobalECHConfigCache.LoadOrStore(GlobalECHConfigCacheKey, echConfigCache)
  157. }
  158. configRecord := echConfigCache.configRecord.Load()
  159. if configRecord.expire.After(time.Now()) && (configRecord.err == nil || forceQuery == "none") {
  160. errors.LogDebug(context.Background(), "Cache hit for domain: ", domain)
  161. return configRecord.config, configRecord.err
  162. }
  163. // If expire is zero value, it means we are in initial state, wait for the query to finish
  164. // otherwise return old value immediately and update in a goroutine
  165. // but if the cache is too old, wait for update
  166. if configRecord.expire == (time.Time{}) || configRecord.expire.Add(time.Hour*6).Before(time.Now()) {
  167. return echConfigCache.Update(domain, server, false, forceQuery, sockopt)
  168. } else {
  169. // If someone already acquired the lock, it means it is updating, do not start another update goroutine
  170. if echConfigCache.UpdateLock.TryLock() {
  171. go func() {
  172. defer echConfigCache.UpdateLock.Unlock()
  173. echConfigCache.Update(domain, server, true, forceQuery, sockopt)
  174. }()
  175. }
  176. return configRecord.config, configRecord.err
  177. }
  178. }
  179. // dnsQuery is the real func for sending type65 query for given domain to given DNS server.
  180. // return ECH config, TTL and error
  181. func dnsQuery(server string, domain string, sockopt *internet.SocketConfig) ([]byte, uint32, error) {
  182. m := new(dns.Msg)
  183. var dnsResolve []byte
  184. m.SetQuestion(dns.Fqdn(domain), dns.TypeHTTPS)
  185. // for DOH server
  186. if strings.HasPrefix(server, "https://") || strings.HasPrefix(server, "h2c://") {
  187. h2c := strings.HasPrefix(server, "h2c://")
  188. m.SetEdns0(4096, false) // 4096 is the buffer size, false means no DNSSEC
  189. padding := &dns.EDNS0_PADDING{Padding: make([]byte, int(crypto.RandBetween(100, 300)))}
  190. if opt := m.IsEdns0(); opt != nil {
  191. opt.Option = append(opt.Option, padding)
  192. }
  193. // always 0 in DOH
  194. m.Id = 0
  195. msg, err := m.Pack()
  196. if err != nil {
  197. return nil, 0, err
  198. }
  199. var client *http.Client
  200. serverKey := ECHCacheKey(server, "", sockopt)
  201. if client, _ = clientForECHDOH.Load(serverKey); client == nil {
  202. // All traffic sent by core should via xray's internet.DialSystem
  203. // This involves the behavior of some Android VPN GUI clients
  204. tr := &http2.Transport{
  205. IdleConnTimeout: net.ConnIdleTimeout,
  206. ReadIdleTimeout: net.ChromeH2KeepAlivePeriod,
  207. DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
  208. dest, err := net.ParseDestination(network + ":" + addr)
  209. if err != nil {
  210. return nil, err
  211. }
  212. var conn net.Conn
  213. conn, err = internet.DialSystem(ctx, dest, sockopt)
  214. if err != nil {
  215. return nil, err
  216. }
  217. if !h2c {
  218. u, err := url.Parse(server)
  219. if err != nil {
  220. return nil, err
  221. }
  222. conn = utls.UClient(conn, &utls.Config{ServerName: u.Hostname()}, utls.HelloChrome_Auto)
  223. if err := conn.(*utls.UConn).HandshakeContext(ctx); err != nil {
  224. return nil, err
  225. }
  226. }
  227. return conn, nil
  228. },
  229. }
  230. c := &http.Client{
  231. Timeout: 5 * time.Second,
  232. Transport: tr,
  233. }
  234. client, _ = clientForECHDOH.LoadOrStore(serverKey, c)
  235. }
  236. req, err := http.NewRequest("POST", server, bytes.NewReader(msg))
  237. if err != nil {
  238. return nil, 0, err
  239. }
  240. req.Header.Set("Accept", "application/dns-message")
  241. req.Header.Set("Content-Type", "application/dns-message")
  242. req.Header.Set("X-Padding", strings.Repeat("X", int(crypto.RandBetween(100, 1000))))
  243. resp, err := client.Do(req)
  244. if err != nil {
  245. return nil, 0, err
  246. }
  247. defer resp.Body.Close()
  248. respBody, err := io.ReadAll(resp.Body)
  249. if err != nil {
  250. return nil, 0, err
  251. }
  252. if resp.StatusCode != http.StatusOK {
  253. return nil, 0, errors.New("query failed with response code:", resp.StatusCode)
  254. }
  255. dnsResolve = respBody
  256. } else if strings.HasPrefix(server, "udp://") { // for classic udp dns server
  257. udpServerAddr := server[len("udp://"):]
  258. // default port 53 if not specified
  259. if !strings.Contains(udpServerAddr, ":") {
  260. udpServerAddr = udpServerAddr + ":53"
  261. }
  262. dest, err := net.ParseDestination("udp" + ":" + udpServerAddr)
  263. if err != nil {
  264. return nil, 0, errors.New("failed to parse udp dns server ", udpServerAddr, " for ECH: ", err)
  265. }
  266. dnsTimeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  267. defer cancel()
  268. // use xray's internet.DialSystem as mentioned above
  269. conn, err := internet.DialSystem(dnsTimeoutCtx, dest, sockopt)
  270. if err != nil {
  271. return nil, 0, err
  272. }
  273. defer func() {
  274. err := conn.Close()
  275. if err != nil {
  276. errors.LogDebug(context.Background(), "Failed to close connection: ", err)
  277. }
  278. }()
  279. msg, err := m.Pack()
  280. if err != nil {
  281. return nil, 0, err
  282. }
  283. conn.Write(msg)
  284. udpResponse := make([]byte, 512)
  285. conn.SetReadDeadline(time.Now().Add(5 * time.Second))
  286. _, err = conn.Read(udpResponse)
  287. if err != nil {
  288. return nil, 0, err
  289. }
  290. dnsResolve = udpResponse
  291. }
  292. respMsg := new(dns.Msg)
  293. err := respMsg.Unpack(dnsResolve)
  294. if err != nil {
  295. return nil, 0, errors.New("failed to unpack dns response for ECH: ", err)
  296. }
  297. if len(respMsg.Answer) > 0 {
  298. for _, answer := range respMsg.Answer {
  299. if https, ok := answer.(*dns.HTTPS); ok && https.Hdr.Name == dns.Fqdn(domain) {
  300. for _, v := range https.Value {
  301. if echConfig, ok := v.(*dns.SVCBECHConfig); ok {
  302. errors.LogDebug(context.Background(), "Get ECH config:", echConfig.String(), " TTL:", respMsg.Answer[0].Header().Ttl)
  303. return echConfig.ECH, answer.Header().Ttl, nil
  304. }
  305. }
  306. }
  307. }
  308. }
  309. // empty is valid, means no ECH config found
  310. return nil, dns2.DefaultTTL, nil
  311. }
  312. // reference github.com/OmarTariq612/goech
  313. func MarshalBinary(ech reality.EchConfig) ([]byte, error) {
  314. var b cryptobyte.Builder
  315. b.AddUint16(ech.Version)
  316. b.AddUint16LengthPrefixed(func(child *cryptobyte.Builder) {
  317. child.AddUint8(ech.ConfigID)
  318. child.AddUint16(ech.KemID)
  319. child.AddUint16(uint16(len(ech.PublicKey)))
  320. child.AddBytes(ech.PublicKey)
  321. child.AddUint16LengthPrefixed(func(child *cryptobyte.Builder) {
  322. for _, cipherSuite := range ech.SymmetricCipherSuite {
  323. child.AddUint16(cipherSuite.KDFID)
  324. child.AddUint16(cipherSuite.AEADID)
  325. }
  326. })
  327. child.AddUint8(ech.MaxNameLength)
  328. child.AddUint8(uint8(len(ech.PublicName)))
  329. child.AddBytes(ech.PublicName)
  330. child.AddUint16LengthPrefixed(func(child *cryptobyte.Builder) {
  331. for _, extention := range ech.Extensions {
  332. child.AddUint16(extention.Type)
  333. child.AddBytes(extention.Data)
  334. }
  335. })
  336. })
  337. return b.Bytes()
  338. }
  // ErrInvalidLen is returned by ConvertToGoECHKeys when the input is truncated
  // or a length prefix points past the end of the data.
  var ErrInvalidLen = errors.New("goech: invalid length")
  
  // ConvertToGoECHKeys decodes a serialized ECH key set list into crypto/tls
  // server keys. data is a sequence of entries, each laid out as
  //
  //	uint16 length | private key | uint16 length | ECH config
  //
  // Empty data yields no keys and no error. On malformed input the keys decoded
  // before the bad entry are returned together with ErrInvalidLen. The returned
  // byte slices alias data.
  func ConvertToGoECHKeys(data []byte) ([]tls.EncryptedClientHelloKey, error) {
  	var keys []tls.EncryptedClientHelloKey
  	rest := data
  
  	// next pops one uint16-length-prefixed field off the front of rest.
  	next := func() ([]byte, bool) {
  		if len(rest) < 2 {
  			return nil, false
  		}
  		end := 2 + int(binary.BigEndian.Uint16(rest))
  		if len(rest) < end {
  			return nil, false
  		}
  		field := rest[2:end]
  		rest = rest[end:]
  		return field, true
  	}
  
  	for len(rest) > 0 {
  		privateKey, ok := next()
  		if !ok {
  			return keys, ErrInvalidLen
  		}
  		echConfig, ok := next()
  		if !ok {
  			return keys, ErrInvalidLen
  		}
  		keys = append(keys, tls.EncryptedClientHelloKey{
  			Config:     echConfig,
  			PrivateKey: privateKey,
  		})
  	}
  	return keys, nil
  }
  const (
  	// ExtensionEncryptedClientHello is the ECH TLS extension codepoint.
  	// GenerateECHKeySet also uses it as the ECHConfig version.
  	ExtensionEncryptedClientHello = 0xfe0d
  
  	// HPKE KDF identifiers (RFC 9180, KDF IDs table) used next to
  	// hpke.KDF_HKDF_SHA256 when advertising cipher suites.
  	KDF_HKDF_SHA384 = 0x0002
  	KDF_HKDF_SHA512 = 0x0003
  )
  375. func GenerateECHKeySet(configID uint8, domain string, kem uint16) (reality.EchConfig, []byte, error) {
  376. config := reality.EchConfig{
  377. Version: ExtensionEncryptedClientHello,
  378. ConfigID: configID,
  379. PublicName: []byte(domain),
  380. KemID: kem,
  381. SymmetricCipherSuite: []reality.EchCipher{
  382. {KDFID: hpke.KDF_HKDF_SHA256, AEADID: hpke.AEAD_AES_128_GCM},
  383. {KDFID: hpke.KDF_HKDF_SHA256, AEADID: hpke.AEAD_AES_256_GCM},
  384. {KDFID: hpke.KDF_HKDF_SHA256, AEADID: hpke.AEAD_ChaCha20Poly1305},
  385. {KDFID: KDF_HKDF_SHA384, AEADID: hpke.AEAD_AES_128_GCM},
  386. {KDFID: KDF_HKDF_SHA384, AEADID: hpke.AEAD_AES_256_GCM},
  387. {KDFID: KDF_HKDF_SHA384, AEADID: hpke.AEAD_ChaCha20Poly1305},
  388. {KDFID: KDF_HKDF_SHA512, AEADID: hpke.AEAD_AES_128_GCM},
  389. {KDFID: KDF_HKDF_SHA512, AEADID: hpke.AEAD_AES_256_GCM},
  390. {KDFID: KDF_HKDF_SHA512, AEADID: hpke.AEAD_ChaCha20Poly1305},
  391. },
  392. MaxNameLength: 0,
  393. Extensions: nil,
  394. }
  395. // if kem == hpke.DHKEM_X25519_HKDF_SHA256 {
  396. curve := ecdh.X25519()
  397. priv := make([]byte, 32) //x25519
  398. _, err := io.ReadFull(rand.Reader, priv)
  399. if err != nil {
  400. return config, nil, err
  401. }
  402. privKey, _ := curve.NewPrivateKey(priv)
  403. config.PublicKey = privKey.PublicKey().Bytes()
  404. return config, priv, nil
  405. // }
  406. // TODO: add mlkem768 (former kyber768 draft00). The golang mlkem private key is 64 bytes seed?
  407. }