ech.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. package tls
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/ecdh"
  6. "crypto/rand"
  7. "crypto/tls"
  8. "encoding/base64"
  9. "encoding/binary"
  10. "io"
  11. "net/http"
  12. "strings"
  13. "sync"
  14. "sync/atomic"
  15. "time"
  16. "github.com/miekg/dns"
  17. "github.com/xtls/reality"
  18. "github.com/xtls/reality/hpke"
  19. "github.com/xtls/xray-core/common/errors"
  20. "github.com/xtls/xray-core/common/net"
  21. "github.com/xtls/xray-core/common/utils"
  22. "github.com/xtls/xray-core/transport/internet"
  23. "golang.org/x/crypto/cryptobyte"
  24. )
  25. func ApplyECH(c *Config, config *tls.Config) error {
  26. var ECHConfig []byte
  27. var err error
  28. nameToQuery := c.ServerName
  29. var DNSServer string
  30. // for client
  31. if len(c.EchConfigList) != 0 {
  32. // direct base64 config
  33. if strings.Contains(c.EchConfigList, "://") {
  34. // query config from dns
  35. parts := strings.Split(c.EchConfigList, "+")
  36. if len(parts) == 2 {
  37. // parse ECH DNS server in format of "example.com+https://1.1.1.1/dns-query"
  38. nameToQuery = parts[0]
  39. DNSServer = parts[1]
  40. } else if len(parts) == 1 {
  41. // normal format
  42. DNSServer = parts[0]
  43. } else {
  44. return errors.New("Invalid ECH DNS server format: ", c.EchConfigList)
  45. }
  46. if nameToQuery == "" {
  47. return errors.New("Using DNS for ECH Config needs serverName or use Server format example.com+https://1.1.1.1/dns-query")
  48. }
  49. ECHConfig, err = QueryRecord(nameToQuery, DNSServer)
  50. if err != nil {
  51. return err
  52. }
  53. } else {
  54. ECHConfig, err = base64.StdEncoding.DecodeString(c.EchConfigList)
  55. if err != nil {
  56. return errors.New("Failed to unmarshal ECHConfigList: ", err)
  57. }
  58. }
  59. config.EncryptedClientHelloConfigList = ECHConfig
  60. }
  61. // for server
  62. if len(c.EchServerKeys) != 0 {
  63. KeySets, err := ConvertToGoECHKeys(c.EchServerKeys)
  64. if err != nil {
  65. return errors.New("Failed to unmarshal ECHKeySetList: ", err)
  66. }
  67. config.EncryptedClientHelloKeys = KeySets
  68. }
  69. return nil
  70. }
  71. type ECHConfigCache struct {
  72. configRecord atomic.Pointer[echConfigRecord]
  73. // updateLock is not for preventing concurrent read/write, but for preventing concurrent update
  74. UpdateLock sync.Mutex
  75. }
  76. type echConfigRecord struct {
  77. config []byte
  78. expire time.Time
  79. }
  80. var (
  81. GlobalECHConfigCache = utils.NewTypedSyncMap[string, *ECHConfigCache]()
  82. clientForECHDOH = utils.NewTypedSyncMap[string, *http.Client]()
  83. )
  84. // Update updates the ECH config for given domain and server.
  85. // this method is concurrent safe, only one update request will be sent, others get the cache.
  86. // if isLockedUpdate is true, it will not try to acquire the lock.
  87. func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool) ([]byte, error) {
  88. if !isLockedUpdate {
  89. c.UpdateLock.Lock()
  90. defer c.UpdateLock.Unlock()
  91. }
  92. // Double check cache after acquiring lock
  93. configRecord := c.configRecord.Load()
  94. if configRecord.expire.After(time.Now()) {
  95. errors.LogDebug(context.Background(), "Cache hit for domain after double check: ", domain)
  96. return configRecord.config, nil
  97. }
  98. // Query ECH config from DNS server
  99. errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server)
  100. echConfig, ttl, err := dnsQuery(server, domain)
  101. if err != nil {
  102. return nil, err
  103. }
  104. configRecord = &echConfigRecord{
  105. config: echConfig,
  106. expire: time.Now().Add(time.Duration(ttl) * time.Second),
  107. }
  108. c.configRecord.Store(configRecord)
  109. return configRecord.config, nil
  110. }
  111. // QueryRecord returns the ECH config for given domain.
  112. // If the record is not in cache or expired, it will query the DNS server and update the cache.
  113. func QueryRecord(domain string, server string) ([]byte, error) {
  114. echConfigCache, ok := GlobalECHConfigCache.Load(domain)
  115. if !ok {
  116. echConfigCache = &ECHConfigCache{}
  117. echConfigCache.configRecord.Store(&echConfigRecord{})
  118. echConfigCache, _ = GlobalECHConfigCache.LoadOrStore(domain, echConfigCache)
  119. }
  120. configRecord := echConfigCache.configRecord.Load()
  121. if configRecord.expire.After(time.Now()) {
  122. errors.LogDebug(context.Background(), "Cache hit for domain: ", domain)
  123. return configRecord.config, nil
  124. }
  125. // If expire is zero value, it means we are in initial state, wait for the query to finish
  126. // otherwise return old value immediately and update in a goroutine
  127. // but if the cache is too old, wait for update
  128. if configRecord.expire == (time.Time{}) || configRecord.expire.Add(time.Hour*6).Before(time.Now()) {
  129. return echConfigCache.Update(domain, server, false)
  130. } else {
  131. // If someone already acquired the lock, it means it is updating, do not start another update goroutine
  132. if echConfigCache.UpdateLock.TryLock() {
  133. go func() {
  134. defer echConfigCache.UpdateLock.Unlock()
  135. echConfigCache.Update(domain, server, true)
  136. }()
  137. }
  138. return configRecord.config, nil
  139. }
  140. }
  141. // dnsQuery is the real func for sending type65 query for given domain to given DNS server.
  142. // return ECH config, TTL and error
  143. func dnsQuery(server string, domain string) ([]byte, uint32, error) {
  144. m := new(dns.Msg)
  145. var dnsResolve []byte
  146. m.SetQuestion(dns.Fqdn(domain), dns.TypeHTTPS)
  147. // for DOH server
  148. if strings.HasPrefix(server, "https://") {
  149. // always 0 in DOH
  150. m.Id = 0
  151. msg, err := m.Pack()
  152. if err != nil {
  153. return []byte{}, 0, err
  154. }
  155. var client *http.Client
  156. if client, _ = clientForECHDOH.Load(server); client == nil {
  157. // All traffic sent by core should via xray's internet.DialSystem
  158. // This involves the behavior of some Android VPN GUI clients
  159. tr := &http.Transport{
  160. IdleConnTimeout: 90 * time.Second,
  161. ForceAttemptHTTP2: true,
  162. DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
  163. dest, err := net.ParseDestination(network + ":" + addr)
  164. if err != nil {
  165. return nil, err
  166. }
  167. conn, err := internet.DialSystem(ctx, dest, nil)
  168. if err != nil {
  169. return nil, err
  170. }
  171. return conn, nil
  172. },
  173. }
  174. c := &http.Client{
  175. Timeout: 5 * time.Second,
  176. Transport: tr,
  177. }
  178. client, _ = clientForECHDOH.LoadOrStore(server, c)
  179. }
  180. req, err := http.NewRequest("POST", server, bytes.NewReader(msg))
  181. if err != nil {
  182. return []byte{}, 0, err
  183. }
  184. req.Header.Set("Content-Type", "application/dns-message")
  185. resp, err := client.Do(req)
  186. if err != nil {
  187. return []byte{}, 0, err
  188. }
  189. defer resp.Body.Close()
  190. respBody, err := io.ReadAll(resp.Body)
  191. if err != nil {
  192. return []byte{}, 0, err
  193. }
  194. if resp.StatusCode != http.StatusOK {
  195. return []byte{}, 0, errors.New("query failed with response code:", resp.StatusCode)
  196. }
  197. dnsResolve = respBody
  198. } else if strings.HasPrefix(server, "udp://") { // for classic udp dns server
  199. udpServerAddr := server[len("udp://"):]
  200. // default port 53 if not specified
  201. if !strings.Contains(udpServerAddr, ":") {
  202. udpServerAddr = udpServerAddr + ":53"
  203. }
  204. dest, err := net.ParseDestination("udp" + ":" + udpServerAddr)
  205. if err != nil {
  206. return nil, 0, errors.New("failed to parse udp dns server ", udpServerAddr, " for ECH: ", err)
  207. }
  208. dnsTimeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  209. defer cancel()
  210. // use xray's internet.DialSystem as mentioned above
  211. conn, err := internet.DialSystem(dnsTimeoutCtx, dest, nil)
  212. defer func() {
  213. err := conn.Close()
  214. if err != nil {
  215. errors.LogDebug(context.Background(), "Failed to close connection: ", err)
  216. }
  217. }()
  218. if err != nil {
  219. return []byte{}, 0, err
  220. }
  221. msg, err := m.Pack()
  222. if err != nil {
  223. return []byte{}, 0, err
  224. }
  225. conn.Write(msg)
  226. udpResponse := make([]byte, 512)
  227. _, err = conn.Read(udpResponse)
  228. if err != nil {
  229. return []byte{}, 0, err
  230. }
  231. dnsResolve = udpResponse
  232. }
  233. respMsg := new(dns.Msg)
  234. err := respMsg.Unpack(dnsResolve)
  235. if err != nil {
  236. return []byte{}, 0, errors.New("failed to unpack dns response for ECH: ", err)
  237. }
  238. if len(respMsg.Answer) > 0 {
  239. for _, answer := range respMsg.Answer {
  240. if https, ok := answer.(*dns.HTTPS); ok && https.Hdr.Name == dns.Fqdn(domain) {
  241. for _, v := range https.Value {
  242. if echConfig, ok := v.(*dns.SVCBECHConfig); ok {
  243. errors.LogDebug(context.Background(), "Get ECH config:", echConfig.String(), " TTL:", respMsg.Answer[0].Header().Ttl)
  244. return echConfig.ECH, answer.Header().Ttl, nil
  245. }
  246. }
  247. }
  248. }
  249. }
  250. return []byte{}, 0, errors.New("no ech record found")
  251. }
  252. // reference github.com/OmarTariq612/goech
  253. func MarshalBinary(ech reality.EchConfig) ([]byte, error) {
  254. var b cryptobyte.Builder
  255. b.AddUint16(ech.Version)
  256. b.AddUint16LengthPrefixed(func(child *cryptobyte.Builder) {
  257. child.AddUint8(ech.ConfigID)
  258. child.AddUint16(ech.KemID)
  259. child.AddUint16(uint16(len(ech.PublicKey)))
  260. child.AddBytes(ech.PublicKey)
  261. child.AddUint16LengthPrefixed(func(child *cryptobyte.Builder) {
  262. for _, cipherSuite := range ech.SymmetricCipherSuite {
  263. child.AddUint16(cipherSuite.KDFID)
  264. child.AddUint16(cipherSuite.AEADID)
  265. }
  266. })
  267. child.AddUint8(ech.MaxNameLength)
  268. child.AddUint8(uint8(len(ech.PublicName)))
  269. child.AddBytes(ech.PublicName)
  270. child.AddUint16LengthPrefixed(func(child *cryptobyte.Builder) {
  271. for _, extention := range ech.Extensions {
  272. child.AddUint16(extention.Type)
  273. child.AddBytes(extention.Data)
  274. }
  275. })
  276. })
  277. return b.Bytes()
  278. }
  279. var ErrInvalidLen = errors.New("goech: invalid length")
  280. func ConvertToGoECHKeys(data []byte) ([]tls.EncryptedClientHelloKey, error) {
  281. var keys []tls.EncryptedClientHelloKey
  282. s := cryptobyte.String(data)
  283. for !s.Empty() {
  284. if len(s) < 2 {
  285. return keys, ErrInvalidLen
  286. }
  287. keyLength := int(binary.BigEndian.Uint16(s[:2]))
  288. if len(s) < keyLength+4 {
  289. return keys, ErrInvalidLen
  290. }
  291. configLength := int(binary.BigEndian.Uint16(s[keyLength+2 : keyLength+4]))
  292. if len(s) < 2+keyLength+2+configLength {
  293. return keys, ErrInvalidLen
  294. }
  295. child := cryptobyte.String(s[:2+keyLength+2+configLength])
  296. var (
  297. sk, config cryptobyte.String
  298. )
  299. if !child.ReadUint16LengthPrefixed(&sk) || !child.ReadUint16LengthPrefixed(&config) || !child.Empty() {
  300. return keys, ErrInvalidLen
  301. }
  302. if !s.Skip(2 + keyLength + 2 + configLength) {
  303. return keys, ErrInvalidLen
  304. }
  305. keys = append(keys, tls.EncryptedClientHelloKey{
  306. Config: config,
  307. PrivateKey: sk,
  308. })
  309. }
  310. return keys, nil
  311. }
  312. const ExtensionEncryptedClientHello = 0xfe0d
  313. const KDF_HKDF_SHA384 = 0x0002
  314. const KDF_HKDF_SHA512 = 0x0003
  315. func GenerateECHKeySet(configID uint8, domain string, kem uint16) (reality.EchConfig, []byte, error) {
  316. config := reality.EchConfig{
  317. Version: ExtensionEncryptedClientHello,
  318. ConfigID: configID,
  319. PublicName: []byte(domain),
  320. KemID: kem,
  321. SymmetricCipherSuite: []reality.EchCipher{
  322. {KDFID: hpke.KDF_HKDF_SHA256, AEADID: hpke.AEAD_AES_128_GCM},
  323. {KDFID: hpke.KDF_HKDF_SHA256, AEADID: hpke.AEAD_AES_256_GCM},
  324. {KDFID: hpke.KDF_HKDF_SHA256, AEADID: hpke.AEAD_ChaCha20Poly1305},
  325. {KDFID: KDF_HKDF_SHA384, AEADID: hpke.AEAD_AES_128_GCM},
  326. {KDFID: KDF_HKDF_SHA384, AEADID: hpke.AEAD_AES_256_GCM},
  327. {KDFID: KDF_HKDF_SHA384, AEADID: hpke.AEAD_ChaCha20Poly1305},
  328. {KDFID: KDF_HKDF_SHA512, AEADID: hpke.AEAD_AES_128_GCM},
  329. {KDFID: KDF_HKDF_SHA512, AEADID: hpke.AEAD_AES_256_GCM},
  330. {KDFID: KDF_HKDF_SHA512, AEADID: hpke.AEAD_ChaCha20Poly1305},
  331. },
  332. MaxNameLength: 0,
  333. Extensions: nil,
  334. }
  335. // if kem == hpke.DHKEM_X25519_HKDF_SHA256 {
  336. curve := ecdh.X25519()
  337. priv := make([]byte, 32) //x25519
  338. _, err := io.ReadFull(rand.Reader, priv)
  339. if err != nil {
  340. return config, nil, err
  341. }
  342. privKey, _ := curve.NewPrivateKey(priv)
  343. config.PublicKey = privKey.PublicKey().Bytes()
  344. return config, priv, nil
  345. // }
  346. // TODO: add mlkem768 (former kyber768 draft00). The golang mlkem private key is 64 bytes seed?
  347. }