소스 검색

Fix fetch ECH configs

世界 6 달 전
부모
커밋
1f3097da00
1개의 변경된 파일29개의 추가작업 그리고 13개의 파일을 삭제
  1. 29 13
      common/tls/ech.go

+ 29 - 13
common/tls/ech.go

@@ -10,6 +10,8 @@ import (
 	"net"
 	"net"
 	"os"
 	"os"
 	"strings"
 	"strings"
+	"sync"
+	"time"
 
 
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/dns"
 	"github.com/sagernet/sing-box/dns"
@@ -46,7 +48,10 @@ func parseECHClientConfig(ctx context.Context, options option.OutboundTLSOptions
 		tlsConfig.EncryptedClientHelloConfigList = block.Bytes
 		tlsConfig.EncryptedClientHelloConfigList = block.Bytes
 		return &STDClientConfig{tlsConfig}, nil
 		return &STDClientConfig{tlsConfig}, nil
 	} else {
 	} else {
-		return &STDECHClientConfig{STDClientConfig{tlsConfig}, service.FromContext[adapter.DNSRouter](ctx)}, nil
+		return &STDECHClientConfig{
+			STDClientConfig: STDClientConfig{tlsConfig},
+			dnsRouter:       service.FromContext[adapter.DNSRouter](ctx),
+		}, nil
 	}
 	}
 }
 }
 
 
@@ -99,11 +104,28 @@ func reloadECHKeys(echKeyPath string, tlsConfig *tls.Config) error {
 
 
 type STDECHClientConfig struct {
 type STDECHClientConfig struct {
 	STDClientConfig
 	STDClientConfig
-	dnsRouter adapter.DNSRouter
+	access     sync.Mutex
+	dnsRouter  adapter.DNSRouter
+	lastTTL    time.Duration
+	lastUpdate time.Time
 }
 }
 
 
 func (s *STDECHClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) {
 func (s *STDECHClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) {
-	if len(s.config.EncryptedClientHelloConfigList) == 0 {
+	tlsConn, err := s.fetchAndHandshake(ctx, conn)
+	if err != nil {
+		return nil, err
+	}
+	err = tlsConn.HandshakeContext(ctx)
+	if err != nil {
+		return nil, err
+	}
+	return tlsConn, nil
+}
+
+func (s *STDECHClientConfig) fetchAndHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) {
+	s.access.Lock()
+	defer s.access.Unlock()
+	if len(s.config.EncryptedClientHelloConfigList) == 0 || s.lastTTL == 0 || time.Now().Sub(s.lastUpdate) > s.lastTTL {
 		message := &mDNS.Msg{
 		message := &mDNS.Msg{
 			MsgHdr: mDNS.MsgHdr{
 			MsgHdr: mDNS.MsgHdr{
 				RecursionDesired: true,
 				RecursionDesired: true,
@@ -133,6 +155,8 @@ func (s *STDECHClientConfig) ClientHandshake(ctx context.Context, conn net.Conn)
 						if err != nil {
 						if err != nil {
 							return nil, E.Cause(err, "decode ECH config")
 							return nil, E.Cause(err, "decode ECH config")
 						}
 						}
+						s.lastTTL = time.Duration(rr.Header().Ttl) * time.Second
+						s.lastUpdate = time.Now()
 						s.config.EncryptedClientHelloConfigList = echConfigList
 						s.config.EncryptedClientHelloConfigList = echConfigList
 						break match
 						break match
 					}
 					}
@@ -143,19 +167,11 @@ func (s *STDECHClientConfig) ClientHandshake(ctx context.Context, conn net.Conn)
 			return nil, E.New("no ECH config found in DNS records")
 			return nil, E.New("no ECH config found in DNS records")
 		}
 		}
 	}
 	}
-	tlsConn, err := s.Client(conn)
-	if err != nil {
-		return nil, err
-	}
-	err = tlsConn.HandshakeContext(ctx)
-	if err != nil {
-		return nil, err
-	}
-	return tlsConn, nil
+	return s.Client(conn)
 }
 }
 
 
 func (s *STDECHClientConfig) Clone() Config {
 func (s *STDECHClientConfig) Clone() Config {
-	return &STDECHClientConfig{STDClientConfig{s.config.Clone()}, s.dnsRouter}
+	return &STDECHClientConfig{STDClientConfig: STDClientConfig{s.config.Clone()}, dnsRouter: s.dnsRouter, lastUpdate: s.lastUpdate}
 }
 }
 
 
 func UnmarshalECHKeys(raw []byte) ([]tls.EncryptedClientHelloKey, error) {
 func UnmarshalECHKeys(raw []byte) ([]tls.EncryptedClientHelloKey, error) {