store.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. package certificate
  import (
	"context"
	"crypto/x509"
	"encoding/base64"
	"io/fs"
	"os"
	"path/filepath"
	"strings"
	"sync"

	"github.com/sagernet/fswatch"
	"github.com/sagernet/sing-box/adapter"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing-box/experimental/libbox/platform"
	"github.com/sagernet/sing-box/option"
	E "github.com/sagernet/sing/common/exceptions"
	"github.com/sagernet/sing/common/logger"
	"github.com/sagernet/sing/service"
	"software.sslmate.com/src/go-pkcs12"
)
  20. var _ adapter.CertificateStore = (*Store)(nil)
  21. type Store struct {
  22. systemPool *x509.CertPool
  23. currentPool *x509.CertPool
  24. certificate string
  25. certificatePaths []string
  26. certificateDirectoryPaths []string
  27. watcher *fswatch.Watcher
  28. tlsDecryptionEnabled bool
  29. tlsDecryptionPrivateKey any
  30. tlsDecryptionCertificate *x509.Certificate
  31. }
  32. func NewStore(ctx context.Context, logger logger.Logger, options option.CertificateOptions) (*Store, error) {
  33. var systemPool *x509.CertPool
  34. switch options.Store {
  35. case C.CertificateStoreSystem, "":
  36. systemPool = x509.NewCertPool()
  37. platformInterface := service.FromContext[platform.Interface](ctx)
  38. var systemValid bool
  39. if platformInterface != nil {
  40. for _, cert := range platformInterface.SystemCertificates() {
  41. if systemPool.AppendCertsFromPEM([]byte(cert)) {
  42. systemValid = true
  43. }
  44. }
  45. }
  46. if !systemValid {
  47. certPool, err := x509.SystemCertPool()
  48. if err != nil {
  49. return nil, err
  50. }
  51. systemPool = certPool
  52. }
  53. case C.CertificateStoreMozilla:
  54. systemPool = mozillaIncluded
  55. case C.CertificateStoreNone:
  56. systemPool = nil
  57. default:
  58. return nil, E.New("unknown certificate store: ", options.Store)
  59. }
  60. store := &Store{
  61. systemPool: systemPool,
  62. certificate: strings.Join(options.Certificate, "\n"),
  63. certificatePaths: options.CertificatePath,
  64. certificateDirectoryPaths: options.CertificateDirectoryPath,
  65. }
  66. var watchPaths []string
  67. for _, target := range options.CertificatePath {
  68. watchPaths = append(watchPaths, target)
  69. }
  70. for _, target := range options.CertificateDirectoryPath {
  71. watchPaths = append(watchPaths, target)
  72. }
  73. if len(watchPaths) > 0 {
  74. watcher, err := fswatch.NewWatcher(fswatch.Options{
  75. Path: watchPaths,
  76. Logger: logger,
  77. Callback: func(_ string) {
  78. err := store.update()
  79. if err != nil {
  80. logger.Error(E.Cause(err, "reload certificates"))
  81. }
  82. },
  83. })
  84. if err != nil {
  85. return nil, E.Cause(err, "fswatch: create fsnotify watcher")
  86. }
  87. store.watcher = watcher
  88. }
  89. err := store.update()
  90. if err != nil {
  91. return nil, E.Cause(err, "initializing certificate store")
  92. }
  93. if options.TLSDecryption != nil && options.TLSDecryption.Enabled {
  94. pfxBytes, err := base64.StdEncoding.DecodeString(options.TLSDecryption.KeyPair)
  95. if err != nil {
  96. return nil, E.Cause(err, "decode key pair base64 bytes")
  97. }
  98. privateKey, certificate, err := pkcs12.Decode(pfxBytes, options.TLSDecryption.KeyPairPassword)
  99. if err != nil {
  100. return nil, E.Cause(err, "decode key pair")
  101. }
  102. store.tlsDecryptionEnabled = true
  103. store.tlsDecryptionPrivateKey = privateKey
  104. store.tlsDecryptionCertificate = certificate
  105. }
  106. return store, nil
  107. }
  108. func (s *Store) Name() string {
  109. return "certificate"
  110. }
  111. func (s *Store) Start(stage adapter.StartStage) error {
  112. if stage != adapter.StartStateStart {
  113. return nil
  114. }
  115. if s.watcher != nil {
  116. return s.watcher.Start()
  117. }
  118. return nil
  119. }
  120. func (s *Store) Close() error {
  121. if s.watcher != nil {
  122. return s.watcher.Close()
  123. }
  124. return nil
  125. }
  126. func (s *Store) Pool() *x509.CertPool {
  127. return s.currentPool
  128. }
  129. func (s *Store) update() error {
  130. var currentPool *x509.CertPool
  131. if s.systemPool == nil {
  132. currentPool = x509.NewCertPool()
  133. } else {
  134. currentPool = s.systemPool.Clone()
  135. }
  136. if s.certificate != "" {
  137. if !currentPool.AppendCertsFromPEM([]byte(s.certificate)) {
  138. return E.New("invalid certificate PEM strings")
  139. }
  140. }
  141. for _, path := range s.certificatePaths {
  142. pemContent, err := os.ReadFile(path)
  143. if err != nil {
  144. return err
  145. }
  146. if !currentPool.AppendCertsFromPEM(pemContent) {
  147. return E.New("invalid certificate PEM file: ", path)
  148. }
  149. }
  150. var firstErr error
  151. for _, directoryPath := range s.certificateDirectoryPaths {
  152. directoryEntries, err := readUniqueDirectoryEntries(directoryPath)
  153. if err != nil {
  154. if firstErr == nil && !os.IsNotExist(err) {
  155. firstErr = E.Cause(err, "invalid certificate directory: ", directoryPath)
  156. }
  157. continue
  158. }
  159. for _, directoryEntry := range directoryEntries {
  160. pemContent, err := os.ReadFile(filepath.Join(directoryPath, directoryEntry.Name()))
  161. if err == nil {
  162. currentPool.AppendCertsFromPEM(pemContent)
  163. }
  164. }
  165. }
  166. if firstErr != nil {
  167. return firstErr
  168. }
  169. s.currentPool = currentPool
  170. return nil
  171. }
  // readUniqueDirectoryEntries lists dir the way os.ReadDir does, but drops
// symlinks whose target is a bare file name. Such links alias another entry in
// the same directory and would otherwise be loaded twice. On a read error it
// returns a nil slice together with the error.
func readUniqueDirectoryEntries(dir string) ([]fs.DirEntry, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, err
	}
	// Compact in place, reusing the backing array returned by os.ReadDir.
	kept := 0
	for _, entry := range entries {
		if isSameDirSymlink(entry, dir) {
			continue
		}
		entries[kept] = entry
		kept++
	}
	return entries[:kept], nil
}

// isSameDirSymlink reports whether f, an entry of dir, is a symlink whose
// target contains no '/'. In other words, the target names a file in dir
// itself. An unreadable link reports false.
func isSameDirSymlink(f fs.DirEntry, dir string) bool {
	if f.Type()&fs.ModeSymlink != fs.ModeSymlink {
		return false
	}
	target, err := os.Readlink(filepath.Join(dir, f.Name()))
	if err != nil {
		return false
	}
	return strings.IndexByte(target, '/') < 0
}
  192. func (s *Store) TLSDecryptionEnabled() bool {
  193. return s.tlsDecryptionEnabled
  194. }
  195. func (s *Store) TLSDecryptionCertificate() *x509.Certificate {
  196. return s.tlsDecryptionCertificate
  197. }
  198. func (s *Store) TLSDecryptionPrivateKey() any {
  199. return s.tlsDecryptionPrivateKey
  200. }