std_server.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
  1. package tls
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "net"
  7. "os"
  8. "strings"
  9. "sync"
  10. "time"
  11. "github.com/sagernet/fswatch"
  12. "github.com/sagernet/sing-box/adapter"
  13. C "github.com/sagernet/sing-box/constant"
  14. "github.com/sagernet/sing-box/experimental/deprecated"
  15. "github.com/sagernet/sing-box/log"
  16. "github.com/sagernet/sing-box/option"
  17. "github.com/sagernet/sing/common"
  18. E "github.com/sagernet/sing/common/exceptions"
  19. "github.com/sagernet/sing/common/ntp"
  20. "github.com/sagernet/sing/service"
  21. )
// errInsecureUnused is returned by NewSTDServer when the insecure option is
// set together with a certificate provider or an inline ACME configuration.
var errInsecureUnused = E.New("tls: insecure unused")
  23. type managedCertificateProvider interface {
  24. adapter.CertificateProvider
  25. adapter.SimpleLifecycle
  26. }
// sharedCertificateProvider refers to a certificate provider registered in a
// CertificateProviderManager and looked up by tag when Start is called.
type sharedCertificateProvider struct {
	// tag identifies the provider inside the manager.
	tag string
	// manager is the registry the provider is resolved from.
	manager adapter.CertificateProviderManager
	// provider is the resolved provider; it is nil until Start succeeds.
	provider adapter.CertificateProviderService
}
  32. func (p *sharedCertificateProvider) Start() error {
  33. provider, found := p.manager.Get(p.tag)
  34. if !found {
  35. return E.New("certificate provider not found: ", p.tag)
  36. }
  37. p.provider = provider
  38. return nil
  39. }
// Close is a no-op and always returns nil.
// NOTE(review): presumably the shared provider is owned and closed by the
// manager rather than by each TLS config that references it — confirm.
func (p *sharedCertificateProvider) Close() error {
	return nil
}
  43. func (p *sharedCertificateProvider) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
  44. return p.provider.GetCertificate(hello)
  45. }
  46. func (p *sharedCertificateProvider) GetACMENextProtos() []string {
  47. return getACMENextProtos(p.provider)
  48. }
// inlineCertificateProvider wraps a certificate provider created for a single
// TLS config; the wrapper starts and closes the provider itself.
type inlineCertificateProvider struct {
	// provider is the wrapped provider service.
	provider adapter.CertificateProviderService
}
  52. func (p *inlineCertificateProvider) Start() error {
  53. for _, stage := range adapter.ListStartStages {
  54. err := adapter.LegacyStart(p.provider, stage)
  55. if err != nil {
  56. return err
  57. }
  58. }
  59. return nil
  60. }
  61. func (p *inlineCertificateProvider) Close() error {
  62. return p.provider.Close()
  63. }
  64. func (p *inlineCertificateProvider) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
  65. return p.provider.GetCertificate(hello)
  66. }
  67. func (p *inlineCertificateProvider) GetACMENextProtos() []string {
  68. return getACMENextProtos(p.provider)
  69. }
  70. func getACMENextProtos(provider adapter.CertificateProvider) []string {
  71. if acmeProvider, isACME := provider.(adapter.ACMECertificateProvider); isACME {
  72. return acmeProvider.GetACMENextProtos()
  73. }
  74. return nil
  75. }
  76. type STDServerConfig struct {
  77. access sync.RWMutex
  78. config *tls.Config
  79. logger log.Logger
  80. certificateProvider managedCertificateProvider
  81. acmeService adapter.SimpleLifecycle
  82. certificate []byte
  83. key []byte
  84. certificatePath string
  85. keyPath string
  86. clientCertificatePath []string
  87. echKeyPath string
  88. watcher *fswatch.Watcher
  89. }
  90. func (c *STDServerConfig) ServerName() string {
  91. c.access.RLock()
  92. defer c.access.RUnlock()
  93. return c.config.ServerName
  94. }
  95. func (c *STDServerConfig) SetServerName(serverName string) {
  96. c.access.Lock()
  97. defer c.access.Unlock()
  98. config := c.config.Clone()
  99. config.ServerName = serverName
  100. c.config = config
  101. }
  102. func (c *STDServerConfig) NextProtos() []string {
  103. c.access.RLock()
  104. defer c.access.RUnlock()
  105. if c.hasACMEALPN() && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == C.ACMETLS1Protocol {
  106. return c.config.NextProtos[1:]
  107. }
  108. return c.config.NextProtos
  109. }
  110. func (c *STDServerConfig) SetNextProtos(nextProto []string) {
  111. c.access.Lock()
  112. defer c.access.Unlock()
  113. config := c.config.Clone()
  114. if c.hasACMEALPN() && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == C.ACMETLS1Protocol {
  115. config.NextProtos = append(c.config.NextProtos[:1], nextProto...)
  116. } else {
  117. config.NextProtos = nextProto
  118. }
  119. c.config = config
  120. }
  121. func (c *STDServerConfig) hasACMEALPN() bool {
  122. if c.acmeService != nil {
  123. return true
  124. }
  125. if c.certificateProvider != nil {
  126. if acmeProvider, isACME := c.certificateProvider.(adapter.ACMECertificateProvider); isACME {
  127. return len(acmeProvider.GetACMENextProtos()) > 0
  128. }
  129. }
  130. return false
  131. }
  132. func (c *STDServerConfig) STDConfig() (*STDConfig, error) {
  133. return c.config, nil
  134. }
  135. func (c *STDServerConfig) Client(conn net.Conn) (Conn, error) {
  136. return tls.Client(conn, c.config), nil
  137. }
  138. func (c *STDServerConfig) Server(conn net.Conn) (Conn, error) {
  139. return tls.Server(conn, c.config), nil
  140. }
  141. func (c *STDServerConfig) Clone() Config {
  142. return &STDServerConfig{
  143. config: c.config.Clone(),
  144. }
  145. }
  146. func (c *STDServerConfig) Start() error {
  147. if c.certificateProvider != nil {
  148. err := c.certificateProvider.Start()
  149. if err != nil {
  150. return err
  151. }
  152. if acmeProvider, isACME := c.certificateProvider.(adapter.ACMECertificateProvider); isACME {
  153. nextProtos := acmeProvider.GetACMENextProtos()
  154. if len(nextProtos) > 0 {
  155. c.access.Lock()
  156. config := c.config.Clone()
  157. mergedNextProtos := append([]string{}, nextProtos...)
  158. for _, nextProto := range config.NextProtos {
  159. if !common.Contains(mergedNextProtos, nextProto) {
  160. mergedNextProtos = append(mergedNextProtos, nextProto)
  161. }
  162. }
  163. config.NextProtos = mergedNextProtos
  164. c.config = config
  165. c.access.Unlock()
  166. }
  167. }
  168. }
  169. if c.acmeService != nil {
  170. err := c.acmeService.Start()
  171. if err != nil {
  172. return err
  173. }
  174. }
  175. err := c.startWatcher()
  176. if err != nil {
  177. c.logger.Warn("create fsnotify watcher: ", err)
  178. }
  179. return nil
  180. }
  181. func (c *STDServerConfig) startWatcher() error {
  182. var watchPath []string
  183. if c.certificatePath != "" {
  184. watchPath = append(watchPath, c.certificatePath)
  185. }
  186. if c.keyPath != "" {
  187. watchPath = append(watchPath, c.keyPath)
  188. }
  189. if c.echKeyPath != "" {
  190. watchPath = append(watchPath, c.echKeyPath)
  191. }
  192. if len(c.clientCertificatePath) > 0 {
  193. watchPath = append(watchPath, c.clientCertificatePath...)
  194. }
  195. if len(watchPath) == 0 {
  196. return nil
  197. }
  198. watcher, err := fswatch.NewWatcher(fswatch.Options{
  199. Path: watchPath,
  200. Callback: func(path string) {
  201. err := c.certificateUpdated(path)
  202. if err != nil {
  203. c.logger.Error(E.Cause(err, "reload certificate"))
  204. }
  205. },
  206. })
  207. if err != nil {
  208. return err
  209. }
  210. err = watcher.Start()
  211. if err != nil {
  212. return err
  213. }
  214. c.watcher = watcher
  215. return nil
  216. }
  217. func (c *STDServerConfig) certificateUpdated(path string) error {
  218. if path == c.certificatePath || path == c.keyPath {
  219. if path == c.certificatePath {
  220. certificate, err := os.ReadFile(c.certificatePath)
  221. if err != nil {
  222. return E.Cause(err, "reload certificate from ", c.certificatePath)
  223. }
  224. c.certificate = certificate
  225. } else if path == c.keyPath {
  226. key, err := os.ReadFile(c.keyPath)
  227. if err != nil {
  228. return E.Cause(err, "reload key from ", c.keyPath)
  229. }
  230. c.key = key
  231. }
  232. keyPair, err := tls.X509KeyPair(c.certificate, c.key)
  233. if err != nil {
  234. return E.Cause(err, "reload key pair")
  235. }
  236. c.access.Lock()
  237. config := c.config.Clone()
  238. config.Certificates = []tls.Certificate{keyPair}
  239. c.config = config
  240. c.access.Unlock()
  241. c.logger.Info("reloaded TLS certificate")
  242. } else if common.Contains(c.clientCertificatePath, path) {
  243. clientCertificateCA := x509.NewCertPool()
  244. var reloaded bool
  245. for _, certPath := range c.clientCertificatePath {
  246. content, err := os.ReadFile(certPath)
  247. if err != nil {
  248. c.logger.Error(E.Cause(err, "reload certificate from ", c.clientCertificatePath))
  249. continue
  250. }
  251. if !clientCertificateCA.AppendCertsFromPEM(content) {
  252. c.logger.Error(E.New("invalid client certificate file: ", certPath))
  253. continue
  254. }
  255. reloaded = true
  256. }
  257. if !reloaded {
  258. return E.New("client certificates is empty")
  259. }
  260. c.access.Lock()
  261. config := c.config.Clone()
  262. config.ClientCAs = clientCertificateCA
  263. c.config = config
  264. c.access.Unlock()
  265. c.logger.Info("reloaded client certificates")
  266. } else if path == c.echKeyPath {
  267. echKey, err := os.ReadFile(c.echKeyPath)
  268. if err != nil {
  269. return E.Cause(err, "reload ECH keys from ", c.echKeyPath)
  270. }
  271. err = c.setECHServerConfig(echKey)
  272. if err != nil {
  273. return err
  274. }
  275. c.logger.Info("reloaded ECH keys")
  276. }
  277. return nil
  278. }
  279. func (c *STDServerConfig) Close() error {
  280. return common.Close(c.certificateProvider, c.acmeService, c.watcher)
  281. }
  282. func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) {
  283. if !options.Enabled {
  284. return nil, nil
  285. }
  286. //nolint:staticcheck
  287. if options.CertificateProvider != nil && options.ACME != nil {
  288. return nil, E.New("certificate_provider and acme are mutually exclusive")
  289. }
  290. var tlsConfig *tls.Config
  291. var certificateProvider managedCertificateProvider
  292. var acmeService adapter.SimpleLifecycle
  293. var err error
  294. if options.CertificateProvider != nil {
  295. certificateProvider, err = newCertificateProvider(ctx, logger, options.CertificateProvider)
  296. if err != nil {
  297. return nil, err
  298. }
  299. tlsConfig = &tls.Config{
  300. GetCertificate: certificateProvider.GetCertificate,
  301. }
  302. if options.Insecure {
  303. return nil, errInsecureUnused
  304. }
  305. } else if options.ACME != nil && len(options.ACME.Domain) > 0 { //nolint:staticcheck
  306. deprecated.Report(ctx, deprecated.OptionInlineACME)
  307. //nolint:staticcheck
  308. tlsConfig, acmeService, err = startACME(ctx, logger, common.PtrValueOrDefault(options.ACME))
  309. if err != nil {
  310. return nil, err
  311. }
  312. if options.Insecure {
  313. return nil, errInsecureUnused
  314. }
  315. } else {
  316. tlsConfig = &tls.Config{}
  317. }
  318. tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
  319. if options.ServerName != "" {
  320. tlsConfig.ServerName = options.ServerName
  321. }
  322. if len(options.ALPN) > 0 {
  323. tlsConfig.NextProtos = append(options.ALPN, tlsConfig.NextProtos...)
  324. }
  325. if options.MinVersion != "" {
  326. minVersion, err := ParseTLSVersion(options.MinVersion)
  327. if err != nil {
  328. return nil, E.Cause(err, "parse min_version")
  329. }
  330. tlsConfig.MinVersion = minVersion
  331. }
  332. if options.MaxVersion != "" {
  333. maxVersion, err := ParseTLSVersion(options.MaxVersion)
  334. if err != nil {
  335. return nil, E.Cause(err, "parse max_version")
  336. }
  337. tlsConfig.MaxVersion = maxVersion
  338. }
  339. if options.CipherSuites != nil {
  340. find:
  341. for _, cipherSuite := range options.CipherSuites {
  342. for _, tlsCipherSuite := range tls.CipherSuites() {
  343. if cipherSuite == tlsCipherSuite.Name {
  344. tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID)
  345. continue find
  346. }
  347. }
  348. return nil, E.New("unknown cipher_suite: ", cipherSuite)
  349. }
  350. }
  351. for _, curveID := range options.CurvePreferences {
  352. tlsConfig.CurvePreferences = append(tlsConfig.CurvePreferences, tls.CurveID(curveID))
  353. }
  354. tlsConfig.ClientAuth = tls.ClientAuthType(options.ClientAuthentication)
  355. var (
  356. certificate []byte
  357. key []byte
  358. )
  359. if certificateProvider == nil && acmeService == nil {
  360. if len(options.Certificate) > 0 {
  361. certificate = []byte(strings.Join(options.Certificate, "\n"))
  362. } else if options.CertificatePath != "" {
  363. content, err := os.ReadFile(options.CertificatePath)
  364. if err != nil {
  365. return nil, E.Cause(err, "read certificate")
  366. }
  367. certificate = content
  368. }
  369. if len(options.Key) > 0 {
  370. key = []byte(strings.Join(options.Key, "\n"))
  371. } else if options.KeyPath != "" {
  372. content, err := os.ReadFile(options.KeyPath)
  373. if err != nil {
  374. return nil, E.Cause(err, "read key")
  375. }
  376. key = content
  377. }
  378. if certificate == nil && key == nil && options.Insecure {
  379. timeFunc := ntp.TimeFuncFromContext(ctx)
  380. if timeFunc == nil {
  381. timeFunc = time.Now
  382. }
  383. tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
  384. return GenerateKeyPair(nil, nil, timeFunc, info.ServerName)
  385. }
  386. } else {
  387. if certificate == nil {
  388. return nil, E.New("missing certificate")
  389. } else if key == nil {
  390. return nil, E.New("missing key")
  391. }
  392. keyPair, err := tls.X509KeyPair(certificate, key)
  393. if err != nil {
  394. return nil, E.Cause(err, "parse x509 key pair")
  395. }
  396. tlsConfig.Certificates = []tls.Certificate{keyPair}
  397. }
  398. }
  399. if len(options.ClientCertificate) > 0 || len(options.ClientCertificatePath) > 0 {
  400. if tlsConfig.ClientAuth == tls.NoClientCert {
  401. tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
  402. }
  403. }
  404. if tlsConfig.ClientAuth == tls.VerifyClientCertIfGiven || tlsConfig.ClientAuth == tls.RequireAndVerifyClientCert {
  405. if len(options.ClientCertificate) > 0 {
  406. clientCertificateCA := x509.NewCertPool()
  407. if !clientCertificateCA.AppendCertsFromPEM([]byte(strings.Join(options.ClientCertificate, "\n"))) {
  408. return nil, E.New("invalid client certificate strings")
  409. }
  410. tlsConfig.ClientCAs = clientCertificateCA
  411. } else if len(options.ClientCertificatePath) > 0 {
  412. clientCertificateCA := x509.NewCertPool()
  413. for _, path := range options.ClientCertificatePath {
  414. content, err := os.ReadFile(path)
  415. if err != nil {
  416. return nil, E.Cause(err, "read client certificate from ", path)
  417. }
  418. if !clientCertificateCA.AppendCertsFromPEM(content) {
  419. return nil, E.New("invalid client certificate file: ", path)
  420. }
  421. }
  422. tlsConfig.ClientCAs = clientCertificateCA
  423. } else if len(options.ClientCertificatePublicKeySHA256) > 0 {
  424. if tlsConfig.ClientAuth == tls.RequireAndVerifyClientCert {
  425. tlsConfig.ClientAuth = tls.RequireAnyClientCert
  426. } else if tlsConfig.ClientAuth == tls.VerifyClientCertIfGiven {
  427. tlsConfig.ClientAuth = tls.RequestClientCert
  428. }
  429. tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
  430. return verifyPublicKeySHA256(options.ClientCertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
  431. }
  432. } else {
  433. return nil, E.New("missing client_certificate, client_certificate_path or client_certificate_public_key_sha256 for client authentication")
  434. }
  435. }
  436. var echKeyPath string
  437. if options.ECH != nil && options.ECH.Enabled {
  438. err = parseECHServerConfig(ctx, options, tlsConfig, &echKeyPath)
  439. if err != nil {
  440. return nil, err
  441. }
  442. }
  443. serverConfig := &STDServerConfig{
  444. config: tlsConfig,
  445. logger: logger,
  446. certificateProvider: certificateProvider,
  447. acmeService: acmeService,
  448. certificate: certificate,
  449. key: key,
  450. certificatePath: options.CertificatePath,
  451. clientCertificatePath: options.ClientCertificatePath,
  452. keyPath: options.KeyPath,
  453. echKeyPath: echKeyPath,
  454. }
  455. serverConfig.config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
  456. serverConfig.access.RLock()
  457. defer serverConfig.access.RUnlock()
  458. return serverConfig.config, nil
  459. }
  460. var config ServerConfig = serverConfig
  461. if options.KernelTx || options.KernelRx {
  462. if !C.IsLinux {
  463. return nil, E.New("kTLS is only supported on Linux")
  464. }
  465. config = &KTlSServerConfig{
  466. ServerConfig: config,
  467. logger: logger,
  468. kernelTx: options.KernelTx,
  469. kernelRx: options.KernelRx,
  470. }
  471. }
  472. return config, nil
  473. }
  474. func newCertificateProvider(ctx context.Context, logger log.ContextLogger, options *option.CertificateProviderOptions) (managedCertificateProvider, error) {
  475. if options.IsShared() {
  476. manager := service.FromContext[adapter.CertificateProviderManager](ctx)
  477. if manager == nil {
  478. return nil, E.New("missing certificate provider manager in context")
  479. }
  480. return &sharedCertificateProvider{
  481. tag: options.Tag,
  482. manager: manager,
  483. }, nil
  484. }
  485. registry := service.FromContext[adapter.CertificateProviderRegistry](ctx)
  486. if registry == nil {
  487. return nil, E.New("missing certificate provider registry in context")
  488. }
  489. provider, err := registry.Create(ctx, logger, "", options.Type, options.Options)
  490. if err != nil {
  491. return nil, E.Cause(err, "create inline certificate provider")
  492. }
  493. return &inlineCertificateProvider{
  494. provider: provider,
  495. }, nil
  496. }