std_server.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  1. package tls
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "net"
  7. "os"
  8. "strings"
  9. "sync"
  10. "time"
  11. "github.com/sagernet/fswatch"
  12. "github.com/sagernet/sing-box/adapter"
  13. C "github.com/sagernet/sing-box/constant"
  14. "github.com/sagernet/sing-box/experimental/deprecated"
  15. "github.com/sagernet/sing-box/log"
  16. "github.com/sagernet/sing-box/option"
  17. "github.com/sagernet/sing/common"
  18. E "github.com/sagernet/sing/common/exceptions"
  19. "github.com/sagernet/sing/common/ntp"
  20. "github.com/sagernet/sing/service"
  21. )
  22. var errInsecureUnused = E.New("tls: insecure unused")
  23. type managedCertificateProvider interface {
  24. adapter.CertificateProvider
  25. adapter.SimpleLifecycle
  26. }
  27. type sharedCertificateProvider struct {
  28. tag string
  29. manager adapter.CertificateProviderManager
  30. provider adapter.CertificateProviderService
  31. }
  32. func (p *sharedCertificateProvider) Start() error {
  33. provider, found := p.manager.Get(p.tag)
  34. if !found {
  35. return E.New("certificate provider not found: ", p.tag)
  36. }
  37. p.provider = provider
  38. return nil
  39. }
  40. func (p *sharedCertificateProvider) Close() error {
  41. return nil
  42. }
  43. func (p *sharedCertificateProvider) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
  44. return p.provider.GetCertificate(hello)
  45. }
  46. func (p *sharedCertificateProvider) GetACMENextProtos() []string {
  47. return getACMENextProtos(p.provider)
  48. }
  49. type inlineCertificateProvider struct {
  50. provider adapter.CertificateProviderService
  51. }
  52. func (p *inlineCertificateProvider) Start() error {
  53. for _, stage := range adapter.ListStartStages {
  54. err := adapter.LegacyStart(p.provider, stage)
  55. if err != nil {
  56. return err
  57. }
  58. }
  59. return nil
  60. }
  61. func (p *inlineCertificateProvider) Close() error {
  62. return p.provider.Close()
  63. }
  64. func (p *inlineCertificateProvider) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
  65. return p.provider.GetCertificate(hello)
  66. }
  67. func (p *inlineCertificateProvider) GetACMENextProtos() []string {
  68. return getACMENextProtos(p.provider)
  69. }
  70. func getACMENextProtos(provider adapter.CertificateProvider) []string {
  71. if acmeProvider, isACME := provider.(adapter.ACMECertificateProvider); isACME {
  72. return acmeProvider.GetACMENextProtos()
  73. }
  74. return nil
  75. }
  76. type STDServerConfig struct {
  77. access sync.RWMutex
  78. config *tls.Config
  79. handshakeTimeout time.Duration
  80. logger log.Logger
  81. certificateProvider managedCertificateProvider
  82. acmeService adapter.SimpleLifecycle
  83. certificate []byte
  84. key []byte
  85. certificatePath string
  86. keyPath string
  87. clientCertificatePath []string
  88. echKeyPath string
  89. watcher *fswatch.Watcher
  90. }
  91. func (c *STDServerConfig) ServerName() string {
  92. c.access.RLock()
  93. defer c.access.RUnlock()
  94. return c.config.ServerName
  95. }
  96. func (c *STDServerConfig) SetServerName(serverName string) {
  97. c.access.Lock()
  98. defer c.access.Unlock()
  99. config := c.config.Clone()
  100. config.ServerName = serverName
  101. c.config = config
  102. }
  103. func (c *STDServerConfig) NextProtos() []string {
  104. c.access.RLock()
  105. defer c.access.RUnlock()
  106. if c.hasACMEALPN() && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == C.ACMETLS1Protocol {
  107. return c.config.NextProtos[1:]
  108. }
  109. return c.config.NextProtos
  110. }
  111. func (c *STDServerConfig) SetNextProtos(nextProto []string) {
  112. c.access.Lock()
  113. defer c.access.Unlock()
  114. config := c.config.Clone()
  115. if c.hasACMEALPN() && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == C.ACMETLS1Protocol {
  116. config.NextProtos = append(c.config.NextProtos[:1], nextProto...)
  117. } else {
  118. config.NextProtos = nextProto
  119. }
  120. c.config = config
  121. }
  122. func (c *STDServerConfig) HandshakeTimeout() time.Duration {
  123. c.access.RLock()
  124. defer c.access.RUnlock()
  125. return c.handshakeTimeout
  126. }
  127. func (c *STDServerConfig) SetHandshakeTimeout(timeout time.Duration) {
  128. c.access.Lock()
  129. defer c.access.Unlock()
  130. c.handshakeTimeout = timeout
  131. }
  132. func (c *STDServerConfig) hasACMEALPN() bool {
  133. if c.acmeService != nil {
  134. return true
  135. }
  136. if c.certificateProvider != nil {
  137. if acmeProvider, isACME := c.certificateProvider.(adapter.ACMECertificateProvider); isACME {
  138. return len(acmeProvider.GetACMENextProtos()) > 0
  139. }
  140. }
  141. return false
  142. }
  143. func (c *STDServerConfig) STDConfig() (*STDConfig, error) {
  144. return c.config, nil
  145. }
  146. func (c *STDServerConfig) Client(conn net.Conn) (Conn, error) {
  147. return tls.Client(conn, c.config), nil
  148. }
  149. func (c *STDServerConfig) Server(conn net.Conn) (Conn, error) {
  150. return tls.Server(conn, c.config), nil
  151. }
  152. func (c *STDServerConfig) Clone() Config {
  153. return &STDServerConfig{
  154. config: c.config.Clone(),
  155. handshakeTimeout: c.handshakeTimeout,
  156. }
  157. }
  158. func (c *STDServerConfig) Start() error {
  159. if c.certificateProvider != nil {
  160. err := c.certificateProvider.Start()
  161. if err != nil {
  162. return err
  163. }
  164. if acmeProvider, isACME := c.certificateProvider.(adapter.ACMECertificateProvider); isACME {
  165. nextProtos := acmeProvider.GetACMENextProtos()
  166. if len(nextProtos) > 0 {
  167. c.access.Lock()
  168. config := c.config.Clone()
  169. mergedNextProtos := append([]string{}, nextProtos...)
  170. for _, nextProto := range config.NextProtos {
  171. if !common.Contains(mergedNextProtos, nextProto) {
  172. mergedNextProtos = append(mergedNextProtos, nextProto)
  173. }
  174. }
  175. config.NextProtos = mergedNextProtos
  176. c.config = config
  177. c.access.Unlock()
  178. }
  179. }
  180. }
  181. if c.acmeService != nil {
  182. err := c.acmeService.Start()
  183. if err != nil {
  184. return err
  185. }
  186. }
  187. err := c.startWatcher()
  188. if err != nil {
  189. c.logger.Warn("create fsnotify watcher: ", err)
  190. }
  191. return nil
  192. }
  193. func (c *STDServerConfig) startWatcher() error {
  194. var watchPath []string
  195. if c.certificatePath != "" {
  196. watchPath = append(watchPath, c.certificatePath)
  197. }
  198. if c.keyPath != "" {
  199. watchPath = append(watchPath, c.keyPath)
  200. }
  201. if c.echKeyPath != "" {
  202. watchPath = append(watchPath, c.echKeyPath)
  203. }
  204. if len(c.clientCertificatePath) > 0 {
  205. watchPath = append(watchPath, c.clientCertificatePath...)
  206. }
  207. if len(watchPath) == 0 {
  208. return nil
  209. }
  210. watcher, err := fswatch.NewWatcher(fswatch.Options{
  211. Path: watchPath,
  212. Callback: func(path string) {
  213. err := c.certificateUpdated(path)
  214. if err != nil {
  215. c.logger.Error(E.Cause(err, "reload certificate"))
  216. }
  217. },
  218. })
  219. if err != nil {
  220. return err
  221. }
  222. err = watcher.Start()
  223. if err != nil {
  224. return err
  225. }
  226. c.watcher = watcher
  227. return nil
  228. }
  229. func (c *STDServerConfig) certificateUpdated(path string) error {
  230. if path == c.certificatePath || path == c.keyPath {
  231. if path == c.certificatePath {
  232. certificate, err := os.ReadFile(c.certificatePath)
  233. if err != nil {
  234. return E.Cause(err, "reload certificate from ", c.certificatePath)
  235. }
  236. c.certificate = certificate
  237. } else if path == c.keyPath {
  238. key, err := os.ReadFile(c.keyPath)
  239. if err != nil {
  240. return E.Cause(err, "reload key from ", c.keyPath)
  241. }
  242. c.key = key
  243. }
  244. keyPair, err := tls.X509KeyPair(c.certificate, c.key)
  245. if err != nil {
  246. return E.Cause(err, "reload key pair")
  247. }
  248. c.access.Lock()
  249. config := c.config.Clone()
  250. config.Certificates = []tls.Certificate{keyPair}
  251. c.config = config
  252. c.access.Unlock()
  253. c.logger.Info("reloaded TLS certificate")
  254. } else if common.Contains(c.clientCertificatePath, path) {
  255. clientCertificateCA := x509.NewCertPool()
  256. var reloaded bool
  257. for _, certPath := range c.clientCertificatePath {
  258. content, err := os.ReadFile(certPath)
  259. if err != nil {
  260. c.logger.Error(E.Cause(err, "reload certificate from ", c.clientCertificatePath))
  261. continue
  262. }
  263. if !clientCertificateCA.AppendCertsFromPEM(content) {
  264. c.logger.Error(E.New("invalid client certificate file: ", certPath))
  265. continue
  266. }
  267. reloaded = true
  268. }
  269. if !reloaded {
  270. return E.New("client certificates is empty")
  271. }
  272. c.access.Lock()
  273. config := c.config.Clone()
  274. config.ClientCAs = clientCertificateCA
  275. c.config = config
  276. c.access.Unlock()
  277. c.logger.Info("reloaded client certificates")
  278. } else if path == c.echKeyPath {
  279. echKey, err := os.ReadFile(c.echKeyPath)
  280. if err != nil {
  281. return E.Cause(err, "reload ECH keys from ", c.echKeyPath)
  282. }
  283. err = c.setECHServerConfig(echKey)
  284. if err != nil {
  285. return err
  286. }
  287. c.logger.Info("reloaded ECH keys")
  288. }
  289. return nil
  290. }
  291. func (c *STDServerConfig) Close() error {
  292. return common.Close(c.certificateProvider, c.acmeService, c.watcher)
  293. }
  294. func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) {
  295. if !options.Enabled {
  296. return nil, nil
  297. }
  298. //nolint:staticcheck
  299. if options.CertificateProvider != nil && options.ACME != nil {
  300. return nil, E.New("certificate_provider and acme are mutually exclusive")
  301. }
  302. var tlsConfig *tls.Config
  303. var certificateProvider managedCertificateProvider
  304. var acmeService adapter.SimpleLifecycle
  305. var err error
  306. if options.CertificateProvider != nil {
  307. certificateProvider, err = newCertificateProvider(ctx, logger, options.CertificateProvider)
  308. if err != nil {
  309. return nil, err
  310. }
  311. tlsConfig = &tls.Config{
  312. GetCertificate: certificateProvider.GetCertificate,
  313. }
  314. if options.Insecure {
  315. return nil, errInsecureUnused
  316. }
  317. } else if options.ACME != nil && len(options.ACME.Domain) > 0 { //nolint:staticcheck
  318. deprecated.Report(ctx, deprecated.OptionInlineACME)
  319. //nolint:staticcheck
  320. tlsConfig, acmeService, err = startACME(ctx, logger, common.PtrValueOrDefault(options.ACME))
  321. if err != nil {
  322. return nil, err
  323. }
  324. if options.Insecure {
  325. return nil, errInsecureUnused
  326. }
  327. } else {
  328. tlsConfig = &tls.Config{}
  329. }
  330. tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
  331. if options.ServerName != "" {
  332. tlsConfig.ServerName = options.ServerName
  333. }
  334. if len(options.ALPN) > 0 {
  335. tlsConfig.NextProtos = append(options.ALPN, tlsConfig.NextProtos...)
  336. }
  337. if options.MinVersion != "" {
  338. minVersion, err := ParseTLSVersion(options.MinVersion)
  339. if err != nil {
  340. return nil, E.Cause(err, "parse min_version")
  341. }
  342. tlsConfig.MinVersion = minVersion
  343. }
  344. if options.MaxVersion != "" {
  345. maxVersion, err := ParseTLSVersion(options.MaxVersion)
  346. if err != nil {
  347. return nil, E.Cause(err, "parse max_version")
  348. }
  349. tlsConfig.MaxVersion = maxVersion
  350. }
  351. if options.CipherSuites != nil {
  352. find:
  353. for _, cipherSuite := range options.CipherSuites {
  354. for _, tlsCipherSuite := range tls.CipherSuites() {
  355. if cipherSuite == tlsCipherSuite.Name {
  356. tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID)
  357. continue find
  358. }
  359. }
  360. return nil, E.New("unknown cipher_suite: ", cipherSuite)
  361. }
  362. }
  363. for _, curveID := range options.CurvePreferences {
  364. tlsConfig.CurvePreferences = append(tlsConfig.CurvePreferences, tls.CurveID(curveID))
  365. }
  366. tlsConfig.ClientAuth = tls.ClientAuthType(options.ClientAuthentication)
  367. var (
  368. certificate []byte
  369. key []byte
  370. )
  371. if certificateProvider == nil && acmeService == nil {
  372. if len(options.Certificate) > 0 {
  373. certificate = []byte(strings.Join(options.Certificate, "\n"))
  374. } else if options.CertificatePath != "" {
  375. content, err := os.ReadFile(options.CertificatePath)
  376. if err != nil {
  377. return nil, E.Cause(err, "read certificate")
  378. }
  379. certificate = content
  380. }
  381. if len(options.Key) > 0 {
  382. key = []byte(strings.Join(options.Key, "\n"))
  383. } else if options.KeyPath != "" {
  384. content, err := os.ReadFile(options.KeyPath)
  385. if err != nil {
  386. return nil, E.Cause(err, "read key")
  387. }
  388. key = content
  389. }
  390. if certificate == nil && key == nil && options.Insecure {
  391. timeFunc := ntp.TimeFuncFromContext(ctx)
  392. if timeFunc == nil {
  393. timeFunc = time.Now
  394. }
  395. tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
  396. return GenerateKeyPair(nil, nil, timeFunc, info.ServerName)
  397. }
  398. } else {
  399. if certificate == nil {
  400. return nil, E.New("missing certificate")
  401. } else if key == nil {
  402. return nil, E.New("missing key")
  403. }
  404. keyPair, err := tls.X509KeyPair(certificate, key)
  405. if err != nil {
  406. return nil, E.Cause(err, "parse x509 key pair")
  407. }
  408. tlsConfig.Certificates = []tls.Certificate{keyPair}
  409. }
  410. }
  411. if len(options.ClientCertificate) > 0 || len(options.ClientCertificatePath) > 0 {
  412. if tlsConfig.ClientAuth == tls.NoClientCert {
  413. tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
  414. }
  415. }
  416. if tlsConfig.ClientAuth == tls.VerifyClientCertIfGiven || tlsConfig.ClientAuth == tls.RequireAndVerifyClientCert {
  417. if len(options.ClientCertificate) > 0 {
  418. clientCertificateCA := x509.NewCertPool()
  419. if !clientCertificateCA.AppendCertsFromPEM([]byte(strings.Join(options.ClientCertificate, "\n"))) {
  420. return nil, E.New("invalid client certificate strings")
  421. }
  422. tlsConfig.ClientCAs = clientCertificateCA
  423. } else if len(options.ClientCertificatePath) > 0 {
  424. clientCertificateCA := x509.NewCertPool()
  425. for _, path := range options.ClientCertificatePath {
  426. content, err := os.ReadFile(path)
  427. if err != nil {
  428. return nil, E.Cause(err, "read client certificate from ", path)
  429. }
  430. if !clientCertificateCA.AppendCertsFromPEM(content) {
  431. return nil, E.New("invalid client certificate file: ", path)
  432. }
  433. }
  434. tlsConfig.ClientCAs = clientCertificateCA
  435. } else if len(options.ClientCertificatePublicKeySHA256) > 0 {
  436. if tlsConfig.ClientAuth == tls.RequireAndVerifyClientCert {
  437. tlsConfig.ClientAuth = tls.RequireAnyClientCert
  438. } else if tlsConfig.ClientAuth == tls.VerifyClientCertIfGiven {
  439. tlsConfig.ClientAuth = tls.RequestClientCert
  440. }
  441. tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
  442. return VerifyPublicKeySHA256(options.ClientCertificatePublicKeySHA256, rawCerts)
  443. }
  444. } else {
  445. return nil, E.New("missing client_certificate, client_certificate_path or client_certificate_public_key_sha256 for client authentication")
  446. }
  447. }
  448. var echKeyPath string
  449. if options.ECH != nil && options.ECH.Enabled {
  450. err = parseECHServerConfig(ctx, options, tlsConfig, &echKeyPath)
  451. if err != nil {
  452. return nil, err
  453. }
  454. }
  455. var handshakeTimeout time.Duration
  456. if options.HandshakeTimeout > 0 {
  457. handshakeTimeout = options.HandshakeTimeout.Build()
  458. } else {
  459. handshakeTimeout = C.TCPTimeout
  460. }
  461. serverConfig := &STDServerConfig{
  462. config: tlsConfig,
  463. handshakeTimeout: handshakeTimeout,
  464. logger: logger,
  465. certificateProvider: certificateProvider,
  466. acmeService: acmeService,
  467. certificate: certificate,
  468. key: key,
  469. certificatePath: options.CertificatePath,
  470. clientCertificatePath: options.ClientCertificatePath,
  471. keyPath: options.KeyPath,
  472. echKeyPath: echKeyPath,
  473. }
  474. serverConfig.config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
  475. serverConfig.access.RLock()
  476. defer serverConfig.access.RUnlock()
  477. return serverConfig.config, nil
  478. }
  479. var config ServerConfig = serverConfig
  480. if options.KernelTx || options.KernelRx {
  481. if !C.IsLinux {
  482. return nil, E.New("kTLS is only supported on Linux")
  483. }
  484. config = &KTlSServerConfig{
  485. ServerConfig: config,
  486. logger: logger,
  487. kernelTx: options.KernelTx,
  488. kernelRx: options.KernelRx,
  489. }
  490. }
  491. return config, nil
  492. }
  493. func newCertificateProvider(ctx context.Context, logger log.ContextLogger, options *option.CertificateProviderOptions) (managedCertificateProvider, error) {
  494. if options.IsShared() {
  495. manager := service.FromContext[adapter.CertificateProviderManager](ctx)
  496. if manager == nil {
  497. return nil, E.New("missing certificate provider manager in context")
  498. }
  499. return &sharedCertificateProvider{
  500. tag: options.Tag,
  501. manager: manager,
  502. }, nil
  503. }
  504. registry := service.FromContext[adapter.CertificateProviderRegistry](ctx)
  505. if registry == nil {
  506. return nil, E.New("missing certificate provider registry in context")
  507. }
  508. provider, err := registry.Create(ctx, logger, "", options.Type, options.Options)
  509. if err != nil {
  510. return nil, E.Cause(err, "create inline certificate provider")
  511. }
  512. return &inlineCertificateProvider{
  513. provider: provider,
  514. }, nil
  515. }