|
@@ -6,6 +6,7 @@ import (
|
|
|
"net"
|
|
|
"os"
|
|
|
"strings"
|
|
|
+ "sync"
|
|
|
"time"
|
|
|
|
|
|
"github.com/sagernet/fswatch"
|
|
@@ -20,6 +21,7 @@ import (
|
|
|
var errInsecureUnused = E.New("tls: insecure unused")
|
|
|
|
|
|
type STDServerConfig struct {
|
|
|
+ access sync.RWMutex
|
|
|
config *tls.Config
|
|
|
logger log.Logger
|
|
|
acmeService adapter.SimpleLifecycle
|
|
@@ -32,14 +34,22 @@ type STDServerConfig struct {
|
|
|
}
|
|
|
|
|
|
func (c *STDServerConfig) ServerName() string {
|
|
|
+ c.access.RLock()
|
|
|
+ defer c.access.RUnlock()
|
|
|
return c.config.ServerName
|
|
|
}
|
|
|
|
|
|
func (c *STDServerConfig) SetServerName(serverName string) {
|
|
|
- c.config.ServerName = serverName
|
|
|
+ c.access.Lock()
|
|
|
+ defer c.access.Unlock()
|
|
|
+ config := c.config.Clone()
|
|
|
+ config.ServerName = serverName
|
|
|
+ c.config = config
|
|
|
}
|
|
|
|
|
|
func (c *STDServerConfig) NextProtos() []string {
|
|
|
+ c.access.RLock()
|
|
|
+ defer c.access.RUnlock()
|
|
|
if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol {
|
|
|
return c.config.NextProtos[1:]
|
|
|
} else {
|
|
@@ -48,11 +58,15 @@ func (c *STDServerConfig) NextProtos() []string {
|
|
|
}
|
|
|
|
|
|
func (c *STDServerConfig) SetNextProtos(nextProto []string) {
|
|
|
+ c.access.Lock()
|
|
|
+ defer c.access.Unlock()
|
|
|
+ config := c.config.Clone()
|
|
|
if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol {
|
|
|
- c.config.NextProtos = append(c.config.NextProtos[:1], nextProto...)
|
|
|
+ config.NextProtos = append(c.config.NextProtos[:1], nextProto...)
|
|
|
} else {
|
|
|
- c.config.NextProtos = nextProto
|
|
|
+ config.NextProtos = nextProto
|
|
|
}
|
|
|
+ c.config = config
|
|
|
}
|
|
|
|
|
|
func (c *STDServerConfig) Config() (*STDConfig, error) {
|
|
@@ -77,9 +91,6 @@ func (c *STDServerConfig) Start() error {
|
|
|
if c.acmeService != nil {
|
|
|
return c.acmeService.Start()
|
|
|
} else {
|
|
|
- if c.certificatePath == "" && c.keyPath == "" {
|
|
|
- return nil
|
|
|
- }
|
|
|
err := c.startWatcher()
|
|
|
if err != nil {
|
|
|
c.logger.Warn("create fsnotify watcher: ", err)
|
|
@@ -99,6 +110,9 @@ func (c *STDServerConfig) startWatcher() error {
|
|
|
if c.echKeyPath != "" {
|
|
|
watchPath = append(watchPath, c.echKeyPath)
|
|
|
}
|
|
|
+ if len(watchPath) == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
watcher, err := fswatch.NewWatcher(fswatch.Options{
|
|
|
Path: watchPath,
|
|
|
Callback: func(path string) {
|
|
@@ -138,13 +152,26 @@ func (c *STDServerConfig) certificateUpdated(path string) error {
|
|
|
if err != nil {
|
|
|
return E.Cause(err, "reload key pair")
|
|
|
}
|
|
|
- c.config.Certificates = []tls.Certificate{keyPair}
|
|
|
+ c.access.Lock()
|
|
|
+ config := c.config.Clone()
|
|
|
+ config.Certificates = []tls.Certificate{keyPair}
|
|
|
+ c.config = config
|
|
|
+ c.access.Unlock()
|
|
|
c.logger.Info("reloaded TLS certificate")
|
|
|
} else if path == c.echKeyPath {
|
|
|
- err := reloadECHKeys(c.echKeyPath, c.config)
|
|
|
+ echKey, err := os.ReadFile(c.echKeyPath)
|
|
|
+ if err != nil {
|
|
|
+ return E.Cause(err, "reload ECH keys from ", c.echKeyPath)
|
|
|
+ }
|
|
|
+ echKeys, err := parseECHKeys(echKey)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
+ c.access.Lock()
|
|
|
+ config := c.config.Clone()
|
|
|
+ config.EncryptedClientHelloKeys = echKeys
|
|
|
+ c.config = config
|
|
|
+ c.access.Unlock()
|
|
|
c.logger.Info("reloaded ECH keys")
|
|
|
}
|
|
|
return nil
|
|
@@ -262,7 +289,7 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound
|
|
|
return nil, err
|
|
|
}
|
|
|
}
|
|
|
- return &STDServerConfig{
|
|
|
+ serverConfig := &STDServerConfig{
|
|
|
config: tlsConfig,
|
|
|
logger: logger,
|
|
|
acmeService: acmeService,
|
|
@@ -271,5 +298,11 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound
|
|
|
certificatePath: options.CertificatePath,
|
|
|
keyPath: options.KeyPath,
|
|
|
echKeyPath: echKeyPath,
|
|
|
- }, nil
|
|
|
+ }
|
|
|
+ serverConfig.config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
|
+ serverConfig.access.Lock()
|
|
|
+ defer serverConfig.access.Unlock()
|
|
|
+ return serverConfig.config, nil
|
|
|
+ }
|
|
|
+ return serverConfig, nil
|
|
|
}
|