瀏覽代碼

Fix DNS exchange

世界 1 月之前
父節點
當前提交
e81a76fdf9
共有 7 個文件被更改,包括 58 次插入27 次删除
  1. 6 0
      common/certificate/store.go
  2. 5 14
      common/tls/ech.go
  3. 43 10
      common/tls/std_server.go
  4. 1 1
      dns/client.go
  5. 1 0
      dns/rcode.go
  6. 1 1
      dns/transport/dhcp/dhcp_shared.go
  7. 1 1
      dns/transport/local/local.go

+ 6 - 0
common/certificate/store.go

@@ -7,6 +7,7 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"sync"
 
 	"github.com/sagernet/fswatch"
 	"github.com/sagernet/sing-box/adapter"
@@ -21,6 +22,7 @@ import (
 var _ adapter.CertificateStore = (*Store)(nil)
 
 type Store struct {
+	access                    sync.RWMutex
 	systemPool                *x509.CertPool
 	currentPool               *x509.CertPool
 	certificate               string
@@ -115,10 +117,14 @@ func (s *Store) Close() error {
 }
 
 func (s *Store) Pool() *x509.CertPool {
+	s.access.RLock()
+	defer s.access.RUnlock()
 	return s.currentPool
 }
 
 func (s *Store) update() error {
+	s.access.Lock()
+	defer s.access.Unlock()
 	var currentPool *x509.CertPool
 	if s.systemPool == nil {
 		currentPool = x509.NewCertPool()

+ 5 - 14
common/tls/ech.go

@@ -69,11 +69,7 @@ func parseECHServerConfig(ctx context.Context, options option.InboundTLSOptions,
 	} else {
 		return E.New("missing ECH keys")
 	}
-	block, rest := pem.Decode(echKey)
-	if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
-		return E.New("invalid ECH keys pem")
-	}
-	echKeys, err := UnmarshalECHKeys(block.Bytes)
+	echKeys, err := parseECHKeys(echKey)
 	if err != nil {
 		return E.Cause(err, "parse ECH keys")
 	}
@@ -85,21 +81,16 @@ func parseECHServerConfig(ctx context.Context, options option.InboundTLSOptions,
 	return nil
 }
 
-func reloadECHKeys(echKeyPath string, tlsConfig *tls.Config) error {
-	echKey, err := os.ReadFile(echKeyPath)
-	if err != nil {
-		return E.Cause(err, "reload ECH keys from ", echKeyPath)
-	}
+func parseECHKeys(echKey []byte) ([]tls.EncryptedClientHelloKey, error) {
 	block, _ := pem.Decode(echKey)
 	if block == nil || block.Type != "ECH KEYS" {
-		return E.New("invalid ECH keys pem")
+		return nil, E.New("invalid ECH keys pem")
 	}
 	echKeys, err := UnmarshalECHKeys(block.Bytes)
 	if err != nil {
-		return E.Cause(err, "parse ECH keys")
+		return nil, E.Cause(err, "parse ECH keys")
 	}
-	tlsConfig.EncryptedClientHelloKeys = echKeys
-	return nil
+	return echKeys, nil
 }
 
 type ECHClientConfig struct {

+ 43 - 10
common/tls/std_server.go

@@ -6,6 +6,7 @@ import (
 	"net"
 	"os"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/sagernet/fswatch"
@@ -20,6 +21,7 @@ import (
 var errInsecureUnused = E.New("tls: insecure unused")
 
 type STDServerConfig struct {
+	access          sync.RWMutex
 	config          *tls.Config
 	logger          log.Logger
 	acmeService     adapter.SimpleLifecycle
@@ -32,14 +34,22 @@ type STDServerConfig struct {
 }
 
 func (c *STDServerConfig) ServerName() string {
+	c.access.RLock()
+	defer c.access.RUnlock()
 	return c.config.ServerName
 }
 
 func (c *STDServerConfig) SetServerName(serverName string) {
-	c.config.ServerName = serverName
+	c.access.Lock()
+	defer c.access.Unlock()
+	config := c.config.Clone()
+	config.ServerName = serverName
+	c.config = config
 }
 
 func (c *STDServerConfig) NextProtos() []string {
+	c.access.RLock()
+	defer c.access.RUnlock()
 	if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol {
 		return c.config.NextProtos[1:]
 	} else {
@@ -48,11 +58,15 @@ func (c *STDServerConfig) NextProtos() []string {
 }
 
 func (c *STDServerConfig) SetNextProtos(nextProto []string) {
+	c.access.Lock()
+	defer c.access.Unlock()
+	config := c.config.Clone()
 	if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol {
-		c.config.NextProtos = append(c.config.NextProtos[:1], nextProto...)
+		config.NextProtos = append(c.config.NextProtos[:1], nextProto...)
 	} else {
-		c.config.NextProtos = nextProto
+		config.NextProtos = nextProto
 	}
+	c.config = config
 }
 
 func (c *STDServerConfig) Config() (*STDConfig, error) {
@@ -77,9 +91,6 @@ func (c *STDServerConfig) Start() error {
 	if c.acmeService != nil {
 		return c.acmeService.Start()
 	} else {
-		if c.certificatePath == "" && c.keyPath == "" {
-			return nil
-		}
 		err := c.startWatcher()
 		if err != nil {
 			c.logger.Warn("create fsnotify watcher: ", err)
@@ -99,6 +110,9 @@ func (c *STDServerConfig) startWatcher() error {
 	if c.echKeyPath != "" {
 		watchPath = append(watchPath, c.echKeyPath)
 	}
+	if len(watchPath) == 0 {
+		return nil
+	}
 	watcher, err := fswatch.NewWatcher(fswatch.Options{
 		Path: watchPath,
 		Callback: func(path string) {
@@ -138,13 +152,26 @@ func (c *STDServerConfig) certificateUpdated(path string) error {
 		if err != nil {
 			return E.Cause(err, "reload key pair")
 		}
-		c.config.Certificates = []tls.Certificate{keyPair}
+		c.access.Lock()
+		config := c.config.Clone()
+		config.Certificates = []tls.Certificate{keyPair}
+		c.config = config
+		c.access.Unlock()
 		c.logger.Info("reloaded TLS certificate")
 	} else if path == c.echKeyPath {
-		err := reloadECHKeys(c.echKeyPath, c.config)
+		echKey, err := os.ReadFile(c.echKeyPath)
+		if err != nil {
+			return E.Cause(err, "reload ECH keys from ", c.echKeyPath)
+		}
+		echKeys, err := parseECHKeys(echKey)
 		if err != nil {
 			return err
 		}
+		c.access.Lock()
+		config := c.config.Clone()
+		config.EncryptedClientHelloKeys = echKeys
+		c.config = config
+		c.access.Unlock()
 		c.logger.Info("reloaded ECH keys")
 	}
 	return nil
@@ -262,7 +289,7 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound
 			return nil, err
 		}
 	}
-	return &STDServerConfig{
+	serverConfig := &STDServerConfig{
 		config:          tlsConfig,
 		logger:          logger,
 		acmeService:     acmeService,
@@ -271,5 +298,11 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound
 		certificatePath: options.CertificatePath,
 		keyPath:         options.KeyPath,
 		echKeyPath:      echKeyPath,
-	}, nil
+	}
+	serverConfig.config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
+		serverConfig.access.Lock()
+		defer serverConfig.access.Unlock()
+		return serverConfig.config, nil
+	}
+	return serverConfig, nil
 }

+ 1 - 1
dns/client.go

@@ -280,7 +280,7 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
 		}
 	}
 	logExchangedResponse(c.logger, ctx, response, timeToLive)
-	return response, err
+	return response, nil
 }
 
 func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {

+ 1 - 0
dns/rcode.go

@@ -5,6 +5,7 @@ import (
 )
 
 const (
+	RcodeSuccess     RcodeError = mDNS.RcodeSuccess
 	RcodeFormatError RcodeError = mDNS.RcodeFormatError
 	RcodeNameError   RcodeError = mDNS.RcodeNameError
 	RcodeRefused     RcodeError = mDNS.RcodeRefused

+ 1 - 1
dns/transport/dhcp/dhcp_shared.go

@@ -43,7 +43,7 @@ func (t *Transport) exchangeParallel(ctx context.Context, servers []M.Socksaddr,
 			if response.Rcode != mDNS.RcodeSuccess {
 				err = dns.RcodeError(response.Rcode)
 			} else if len(dns.MessageToAddresses(response)) == 0 {
-				err = E.New(fqdn, ": empty result")
+				err = dns.RcodeSuccess
 			}
 		}
 		select {

+ 1 - 1
dns/transport/local/local.go

@@ -95,7 +95,7 @@ func (t *Transport) exchangeParallel(ctx context.Context, systemConfig *dnsConfi
 			if response.Rcode != mDNS.RcodeSuccess {
 				err = dns.RcodeError(response.Rcode)
 			} else if len(dns.MessageToAddresses(response)) == 0 {
-				err = E.New(fqdn, ": empty result")
+				err = dns.RcodeSuccess
 			}
 		}
 		select {