tls.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. package inbound
  import (
  	"crypto/tls"
  	"os"
  	"sync"

  	"github.com/fsnotify/fsnotify"
  	"github.com/sagernet/sing-box/adapter"
  	"github.com/sagernet/sing-box/log"
  	"github.com/sagernet/sing-box/option"
  	E "github.com/sagernet/sing/common/exceptions"
  )
  11. var _ adapter.Service = (*TLSConfig)(nil)
  12. type TLSConfig struct {
  13. config *tls.Config
  14. logger log.Logger
  15. certificate []byte
  16. key []byte
  17. certificatePath string
  18. keyPath string
  19. watcher *fsnotify.Watcher
  20. }
  21. func (c *TLSConfig) Config() *tls.Config {
  22. return c.config
  23. }
  24. func (c *TLSConfig) Start() error {
  25. if c.certificatePath == "" && c.keyPath == "" {
  26. return nil
  27. }
  28. err := c.startWatcher()
  29. if err != nil {
  30. c.logger.Warn("create fsnotify watcher: ", err)
  31. }
  32. return nil
  33. }
  34. func (c *TLSConfig) startWatcher() error {
  35. watcher, err := fsnotify.NewWatcher()
  36. if err != nil {
  37. return err
  38. }
  39. if c.certificatePath != "" {
  40. err = watcher.Add(c.certificatePath)
  41. if err != nil {
  42. return err
  43. }
  44. }
  45. if c.keyPath != "" {
  46. err = watcher.Add(c.keyPath)
  47. if err != nil {
  48. return err
  49. }
  50. }
  51. c.watcher = watcher
  52. go c.loopUpdate()
  53. return nil
  54. }
  55. func (c *TLSConfig) loopUpdate() {
  56. for {
  57. select {
  58. case event, ok := <-c.watcher.Events:
  59. if !ok {
  60. return
  61. }
  62. if event.Op&fsnotify.Write != fsnotify.Write {
  63. continue
  64. }
  65. err := c.reloadKeyPair()
  66. if err != nil {
  67. c.logger.Error(E.Cause(err, "reload TLS key pair"))
  68. }
  69. case err, ok := <-c.watcher.Errors:
  70. if !ok {
  71. return
  72. }
  73. c.logger.Error(E.Cause(err, "fsnotify error"))
  74. }
  75. }
  76. }
  77. func (c *TLSConfig) reloadKeyPair() error {
  78. if c.certificatePath != "" {
  79. certificate, err := os.ReadFile(c.certificatePath)
  80. if err != nil {
  81. return E.Cause(err, "reload certificate from ", c.certificatePath)
  82. }
  83. c.certificate = certificate
  84. }
  85. if c.keyPath != "" {
  86. key, err := os.ReadFile(c.keyPath)
  87. if err != nil {
  88. return E.Cause(err, "reload key from ", c.keyPath)
  89. }
  90. c.key = key
  91. }
  92. keyPair, err := tls.X509KeyPair(c.certificate, c.key)
  93. if err != nil {
  94. return E.Cause(err, "reload key pair")
  95. }
  96. c.config.Certificates = []tls.Certificate{keyPair}
  97. c.logger.Info("reloaded TLS certificate")
  98. return nil
  99. }
  100. func (c *TLSConfig) Close() error {
  101. if c.watcher != nil {
  102. return c.watcher.Close()
  103. }
  104. return nil
  105. }
  106. func NewTLSConfig(logger log.Logger, options option.InboundTLSOptions) (*TLSConfig, error) {
  107. if !options.Enabled {
  108. return nil, nil
  109. }
  110. var tlsConfig tls.Config
  111. if options.ServerName != "" {
  112. tlsConfig.ServerName = options.ServerName
  113. }
  114. if len(options.ALPN) > 0 {
  115. tlsConfig.NextProtos = options.ALPN
  116. }
  117. if options.MinVersion != "" {
  118. minVersion, err := option.ParseTLSVersion(options.MinVersion)
  119. if err != nil {
  120. return nil, E.Cause(err, "parse min_version")
  121. }
  122. tlsConfig.MinVersion = minVersion
  123. }
  124. if options.MaxVersion != "" {
  125. maxVersion, err := option.ParseTLSVersion(options.MaxVersion)
  126. if err != nil {
  127. return nil, E.Cause(err, "parse max_version")
  128. }
  129. tlsConfig.MaxVersion = maxVersion
  130. }
  131. if options.CipherSuites != nil {
  132. find:
  133. for _, cipherSuite := range options.CipherSuites {
  134. for _, tlsCipherSuite := range tls.CipherSuites() {
  135. if cipherSuite == tlsCipherSuite.Name {
  136. tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID)
  137. continue find
  138. }
  139. }
  140. return nil, E.New("unknown cipher_suite: ", cipherSuite)
  141. }
  142. }
  143. var certificate []byte
  144. if options.Certificate != "" {
  145. certificate = []byte(options.Certificate)
  146. } else if options.CertificatePath != "" {
  147. content, err := os.ReadFile(options.CertificatePath)
  148. if err != nil {
  149. return nil, E.Cause(err, "read certificate")
  150. }
  151. certificate = content
  152. }
  153. var key []byte
  154. if options.Key != "" {
  155. key = []byte(options.Key)
  156. } else if options.KeyPath != "" {
  157. content, err := os.ReadFile(options.KeyPath)
  158. if err != nil {
  159. return nil, E.Cause(err, "read key")
  160. }
  161. key = content
  162. }
  163. if certificate == nil {
  164. return nil, E.New("missing certificate")
  165. }
  166. if key == nil {
  167. return nil, E.New("missing key")
  168. }
  169. keyPair, err := tls.X509KeyPair(certificate, key)
  170. if err != nil {
  171. return nil, E.Cause(err, "parse x509 key pair")
  172. }
  173. tlsConfig.Certificates = []tls.Certificate{keyPair}
  174. return &TLSConfig{
  175. config: &tlsConfig,
  176. logger: logger,
  177. certificate: certificate,
  178. key: key,
  179. certificatePath: options.CertificatePath,
  180. keyPath: options.KeyPath,
  181. }, nil
  182. }