Explorar o código

Add TLS certificate reload

世界 %!s(int64=3) %!d(string=hai) anos
pai
achega
70b4577dbe
Modificáronse 9 ficheiros con 186 adicións e 29 borrados
  1. 16 14
      common/process/searcher_android.go
  2. 5 1
      docs/configuration/shared/tls.md
  3. 1 1
      go.mod
  4. 2 2
      go.sum
  5. 21 3
      inbound/http.go
  6. 116 2
      inbound/tls.go
  7. 22 3
      inbound/vmess.go
  8. 1 1
      test/go.mod
  9. 2 2
      test/go.sum

+ 16 - 14
common/process/searcher_android.go

@@ -36,7 +36,7 @@ func (s *androidSearcher) Start() error {
 	}
 	err = s.startWatcher()
 	if err != nil {
-		s.logger.Debug("create fsnotify watcher: ", err)
+		s.logger.Warn("create fsnotify watcher: ", err)
 	}
 	return nil
 }
@@ -56,20 +56,22 @@ func (s *androidSearcher) startWatcher() error {
 }
 
 func (s *androidSearcher) loopUpdate() {
-	select {
-	case _, ok := <-s.watcher.Events:
-		if !ok {
-			return
-		}
-		err := s.updatePackages()
-		if err != nil {
-			s.logger.Error(E.Cause(err, "update packages list"))
-		}
-	case err, ok := <-s.watcher.Errors:
-		if !ok {
-			return
+	for {
+		select {
+		case _, ok := <-s.watcher.Events:
+			if !ok {
+				return
+			}
+			err := s.updatePackages()
+			if err != nil {
+				s.logger.Error(E.Cause(err, "update packages list"))
+			}
+		case err, ok := <-s.watcher.Errors:
+			if !ok {
+				return
+			}
+			s.logger.Error(E.Cause(err, "fsnotify error"))
 		}
-		s.logger.Error(E.Cause(err, "fsnotify error"))
 	}
 }
 

+ 5 - 1
docs/configuration/shared/tls.md

@@ -133,4 +133,8 @@ The server private key, in PEM format.
 
 ==Server only==
 
-The path to the server private key, in PEM format.
+The path to the server private key, in PEM format.
+
+### Reload
+
+For server configuration, certificate and key will be automatically reloaded if modified.

+ 1 - 1
go.mod

@@ -23,7 +23,7 @@ require (
 	go.uber.org/atomic v1.9.0
 	golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
 	golang.org/x/net v0.0.0-20220728211354-c7608f3a8462
-	golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10
+	golang.org/x/sys v0.0.0-20220730100132-1609e554cd39
 )
 
 require (

+ 2 - 2
go.sum

@@ -279,8 +279,8 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg=
-golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220730100132-1609e554cd39 h1:aNCnH+Fiqs7ZDTFH6oEFjIfbX2HvgQXJ6uQuUbTobjk=
+golang.org/x/sys v0.0.0-20220730100132-1609e554cd39/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

+ 21 - 3
inbound/http.go

@@ -12,6 +12,7 @@ import (
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/auth"
+	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 	"github.com/sagernet/sing/protocol/http"
@@ -22,7 +23,7 @@ var _ adapter.Inbound = (*HTTP)(nil)
 type HTTP struct {
 	myInboundAdapter
 	authenticator auth.Authenticator
-	tlsConfig     *tls.Config
+	tlsConfig     *TLSConfig
 }
 
 func NewHTTP(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPMixedInboundOptions) (*HTTP, error) {
@@ -40,7 +41,7 @@ func NewHTTP(ctx context.Context, router adapter.Router, logger log.ContextLogge
 		authenticator: auth.NewAuthenticator(options.Users),
 	}
 	if options.TLS != nil {
-		tlsConfig, err := NewTLSConfig(common.PtrValueOrDefault(options.TLS))
+		tlsConfig, err := NewTLSConfig(logger, common.PtrValueOrDefault(options.TLS))
 		if err != nil {
 			return nil, err
 		}
@@ -50,9 +51,26 @@ func NewHTTP(ctx context.Context, router adapter.Router, logger log.ContextLogge
 	return inbound, nil
 }
 
+func (h *HTTP) Start() error {
+	if h.tlsConfig != nil {
+		err := h.tlsConfig.Start()
+		if err != nil {
+			return E.Cause(err, "create TLS config")
+		}
+	}
+	return h.myInboundAdapter.Start()
+}
+
+func (h *HTTP) Close() error {
+	return common.Close(
+		&h.myInboundAdapter,
+		common.PtrOrNil(h.tlsConfig),
+	)
+}
+
 func (h *HTTP) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
 	if h.tlsConfig != nil {
-		conn = tls.Server(conn, h.tlsConfig)
+		conn = tls.Server(conn, h.tlsConfig.Config())
 	}
 	return http.HandleConnection(ctx, conn, std_bufio.NewReader(conn), h.authenticator, h.upstreamUserHandler(metadata), M.Metadata{})
 }

+ 116 - 2
inbound/tls.go

@@ -4,11 +4,118 @@ import (
 	"crypto/tls"
 	"os"
 
+	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
 	E "github.com/sagernet/sing/common/exceptions"
+
+	"github.com/fsnotify/fsnotify"
 )
 
-func NewTLSConfig(options option.InboundTLSOptions) (*tls.Config, error) {
+var _ adapter.Service = (*TLSConfig)(nil)
+
+type TLSConfig struct {
+	config          *tls.Config
+	logger          log.Logger
+	certificate     []byte
+	key             []byte
+	certificatePath string
+	keyPath         string
+	watcher         *fsnotify.Watcher
+}
+
+func (c *TLSConfig) Config() *tls.Config {
+	return c.config
+}
+
+func (c *TLSConfig) Start() error {
+	if c.certificatePath == "" && c.keyPath == "" {
+		return nil
+	}
+	err := c.startWatcher()
+	if err != nil {
+		c.logger.Warn("create fsnotify watcher: ", err)
+	}
+	return nil
+}
+
+func (c *TLSConfig) startWatcher() error {
+	watcher, err := fsnotify.NewWatcher()
+	if err != nil {
+		return err
+	}
+	if c.certificatePath != "" {
+		err = watcher.Add(c.certificatePath)
+		if err != nil {
+			return err
+		}
+	}
+	if c.keyPath != "" {
+		err = watcher.Add(c.keyPath)
+		if err != nil {
+			return err
+		}
+	}
+	c.watcher = watcher
+	go c.loopUpdate()
+	return nil
+}
+
+func (c *TLSConfig) loopUpdate() {
+	for {
+		select {
+		case event, ok := <-c.watcher.Events:
+			if !ok {
+				return
+			}
+			if event.Op&fsnotify.Write != fsnotify.Write {
+				continue
+			}
+			err := c.reloadKeyPair()
+			if err != nil {
+				c.logger.Error(E.Cause(err, "reload TLS key pair"))
+			}
+		case err, ok := <-c.watcher.Errors:
+			if !ok {
+				return
+			}
+			c.logger.Error(E.Cause(err, "fsnotify error"))
+		}
+	}
+}
+
+func (c *TLSConfig) reloadKeyPair() error {
+	if c.certificatePath != "" {
+		certificate, err := os.ReadFile(c.certificatePath)
+		if err != nil {
+			return E.Cause(err, "reload certificate from ", c.certificatePath)
+		}
+		c.certificate = certificate
+	}
+	if c.keyPath != "" {
+		key, err := os.ReadFile(c.keyPath)
+		if err != nil {
+			return E.Cause(err, "reload key from ", c.keyPath)
+		}
+		c.key = key
+	}
+	keyPair, err := tls.X509KeyPair(c.certificate, c.key)
+	if err != nil {
+		return E.Cause(err, "reload key pair")
+	}
+	c.config.Certificates = []tls.Certificate{keyPair}
+	c.logger.Info("reloaded TLS certificate")
+	return nil
+}
+
+func (c *TLSConfig) Close() error {
+	if c.watcher != nil {
+		return c.watcher.Close()
+	}
+	return nil
+}
+
+func NewTLSConfig(logger log.Logger, options option.InboundTLSOptions) (*TLSConfig, error) {
 	if !options.Enabled {
 		return nil, nil
 	}
@@ -76,5 +183,12 @@ func NewTLSConfig(options option.InboundTLSOptions) (*tls.Config, error) {
 		return nil, E.Cause(err, "parse x509 key pair")
 	}
 	tlsConfig.Certificates = []tls.Certificate{keyPair}
-	return &tlsConfig, nil
+	return &TLSConfig{
+		config:          &tlsConfig,
+		logger:          logger,
+		certificate:     certificate,
+		key:             key,
+		certificatePath: options.CertificatePath,
+		keyPath:         options.KeyPath,
+	}, nil
 }

+ 22 - 3
inbound/vmess.go

@@ -13,6 +13,7 @@ import (
 	"github.com/sagernet/sing-vmess"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/auth"
+	E "github.com/sagernet/sing/common/exceptions"
 	F "github.com/sagernet/sing/common/format"
 	N "github.com/sagernet/sing/common/network"
 )
@@ -23,7 +24,7 @@ type VMess struct {
 	myInboundAdapter
 	service   *vmess.Service[int]
 	users     []option.VMessUser
-	tlsConfig *tls.Config
+	tlsConfig *TLSConfig
 }
 
 func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VMessInboundOptions) (*VMess, error) {
@@ -49,19 +50,37 @@ func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogg
 		return nil, err
 	}
 	if options.TLS != nil {
-		inbound.tlsConfig, err = NewTLSConfig(common.PtrValueOrDefault(options.TLS))
+		tlsConfig, err := NewTLSConfig(logger, common.PtrValueOrDefault(options.TLS))
 		if err != nil {
 			return nil, err
 		}
+		inbound.tlsConfig = tlsConfig
 	}
 	inbound.service = service
 	inbound.connHandler = inbound
 	return inbound, nil
 }
 
+func (h *VMess) Start() error {
+	if h.tlsConfig != nil {
+		err := h.tlsConfig.Start()
+		if err != nil {
+			return E.Cause(err, "create TLS config")
+		}
+	}
+	return h.myInboundAdapter.Start()
+}
+
+func (h *VMess) Close() error {
+	return common.Close(
+		&h.myInboundAdapter,
+		common.PtrOrNil(h.tlsConfig),
+	)
+}
+
 func (h *VMess) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
 	if h.tlsConfig != nil {
-		conn = tls.Server(conn, h.tlsConfig)
+		conn = tls.Server(conn, h.tlsConfig.Config())
 	}
 	return h.service.NewConnection(adapter.WithContext(log.ContextWithNewID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata))
 }

+ 1 - 1
test/go.mod

@@ -61,7 +61,7 @@ require (
 	go.uber.org/atomic v1.9.0 // indirect
 	golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect
 	golang.org/x/mod v0.5.1 // indirect
-	golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect
+	golang.org/x/sys v0.0.0-20220730100132-1609e554cd39 // indirect
 	golang.org/x/text v0.3.7 // indirect
 	golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
 	golang.org/x/tools v0.1.9 // indirect

+ 2 - 2
test/go.sum

@@ -314,8 +314,8 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg=
-golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220730100132-1609e554cd39 h1:aNCnH+Fiqs7ZDTFH6oEFjIfbX2HvgQXJ6uQuUbTobjk=
+golang.org/x/sys v0.0.0-20220730100132-1609e554cd39/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=