1
0

std_server.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. package tls
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "net"
  7. "os"
  8. "strings"
  9. "sync"
  10. "time"
  11. "github.com/sagernet/fswatch"
  12. "github.com/sagernet/sing-box/adapter"
  13. C "github.com/sagernet/sing-box/constant"
  14. "github.com/sagernet/sing-box/log"
  15. "github.com/sagernet/sing-box/option"
  16. "github.com/sagernet/sing/common"
  17. E "github.com/sagernet/sing/common/exceptions"
  18. "github.com/sagernet/sing/common/ntp"
  19. )
  20. var errInsecureUnused = E.New("tls: insecure unused")
  21. type STDServerConfig struct {
  22. access sync.RWMutex
  23. config *tls.Config
  24. logger log.Logger
  25. acmeService adapter.SimpleLifecycle
  26. certificate []byte
  27. key []byte
  28. certificatePath string
  29. keyPath string
  30. clientCertificatePath []string
  31. echKeyPath string
  32. watcher *fswatch.Watcher
  33. }
  34. func (c *STDServerConfig) ServerName() string {
  35. c.access.RLock()
  36. defer c.access.RUnlock()
  37. return c.config.ServerName
  38. }
  39. func (c *STDServerConfig) SetServerName(serverName string) {
  40. c.access.Lock()
  41. defer c.access.Unlock()
  42. config := c.config.Clone()
  43. config.ServerName = serverName
  44. c.config = config
  45. }
  46. func (c *STDServerConfig) NextProtos() []string {
  47. c.access.RLock()
  48. defer c.access.RUnlock()
  49. if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol {
  50. return c.config.NextProtos[1:]
  51. } else {
  52. return c.config.NextProtos
  53. }
  54. }
  55. func (c *STDServerConfig) SetNextProtos(nextProto []string) {
  56. c.access.Lock()
  57. defer c.access.Unlock()
  58. config := c.config.Clone()
  59. if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol {
  60. config.NextProtos = append(c.config.NextProtos[:1], nextProto...)
  61. } else {
  62. config.NextProtos = nextProto
  63. }
  64. c.config = config
  65. }
  66. func (c *STDServerConfig) STDConfig() (*STDConfig, error) {
  67. return c.config, nil
  68. }
  69. func (c *STDServerConfig) Client(conn net.Conn) (Conn, error) {
  70. return tls.Client(conn, c.config), nil
  71. }
  72. func (c *STDServerConfig) Server(conn net.Conn) (Conn, error) {
  73. return tls.Server(conn, c.config), nil
  74. }
  75. func (c *STDServerConfig) Clone() Config {
  76. return &STDServerConfig{
  77. config: c.config.Clone(),
  78. }
  79. }
  80. func (c *STDServerConfig) Start() error {
  81. if c.acmeService != nil {
  82. return c.acmeService.Start()
  83. } else {
  84. err := c.startWatcher()
  85. if err != nil {
  86. c.logger.Warn("create fsnotify watcher: ", err)
  87. }
  88. return nil
  89. }
  90. }
  91. func (c *STDServerConfig) startWatcher() error {
  92. var watchPath []string
  93. if c.certificatePath != "" {
  94. watchPath = append(watchPath, c.certificatePath)
  95. }
  96. if c.keyPath != "" {
  97. watchPath = append(watchPath, c.keyPath)
  98. }
  99. if c.echKeyPath != "" {
  100. watchPath = append(watchPath, c.echKeyPath)
  101. }
  102. if len(c.clientCertificatePath) > 0 {
  103. watchPath = append(watchPath, c.clientCertificatePath...)
  104. }
  105. if len(watchPath) == 0 {
  106. return nil
  107. }
  108. watcher, err := fswatch.NewWatcher(fswatch.Options{
  109. Path: watchPath,
  110. Callback: func(path string) {
  111. err := c.certificateUpdated(path)
  112. if err != nil {
  113. c.logger.Error(E.Cause(err, "reload certificate"))
  114. }
  115. },
  116. })
  117. if err != nil {
  118. return err
  119. }
  120. err = watcher.Start()
  121. if err != nil {
  122. return err
  123. }
  124. c.watcher = watcher
  125. return nil
  126. }
  127. func (c *STDServerConfig) certificateUpdated(path string) error {
  128. if path == c.certificatePath || path == c.keyPath {
  129. if path == c.certificatePath {
  130. certificate, err := os.ReadFile(c.certificatePath)
  131. if err != nil {
  132. return E.Cause(err, "reload certificate from ", c.certificatePath)
  133. }
  134. c.certificate = certificate
  135. } else if path == c.keyPath {
  136. key, err := os.ReadFile(c.keyPath)
  137. if err != nil {
  138. return E.Cause(err, "reload key from ", c.keyPath)
  139. }
  140. c.key = key
  141. }
  142. keyPair, err := tls.X509KeyPair(c.certificate, c.key)
  143. if err != nil {
  144. return E.Cause(err, "reload key pair")
  145. }
  146. c.access.Lock()
  147. config := c.config.Clone()
  148. config.Certificates = []tls.Certificate{keyPair}
  149. c.config = config
  150. c.access.Unlock()
  151. c.logger.Info("reloaded TLS certificate")
  152. } else if common.Contains(c.clientCertificatePath, path) {
  153. clientCertificateCA := x509.NewCertPool()
  154. var reloaded bool
  155. for _, certPath := range c.clientCertificatePath {
  156. content, err := os.ReadFile(certPath)
  157. if err != nil {
  158. c.logger.Error(E.Cause(err, "reload certificate from ", c.clientCertificatePath))
  159. continue
  160. }
  161. if !clientCertificateCA.AppendCertsFromPEM(content) {
  162. c.logger.Error(E.New("invalid client certificate file: ", certPath))
  163. continue
  164. }
  165. reloaded = true
  166. }
  167. if !reloaded {
  168. return E.New("client certificates is empty")
  169. }
  170. c.access.Lock()
  171. config := c.config.Clone()
  172. config.ClientCAs = clientCertificateCA
  173. c.config = config
  174. c.access.Unlock()
  175. c.logger.Info("reloaded client certificates")
  176. } else if path == c.echKeyPath {
  177. echKey, err := os.ReadFile(c.echKeyPath)
  178. if err != nil {
  179. return E.Cause(err, "reload ECH keys from ", c.echKeyPath)
  180. }
  181. err = c.setECHServerConfig(echKey)
  182. if err != nil {
  183. return err
  184. }
  185. c.logger.Info("reloaded ECH keys")
  186. }
  187. return nil
  188. }
  189. func (c *STDServerConfig) Close() error {
  190. if c.acmeService != nil {
  191. return c.acmeService.Close()
  192. }
  193. if c.watcher != nil {
  194. return c.watcher.Close()
  195. }
  196. return nil
  197. }
  198. func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) {
  199. if !options.Enabled {
  200. return nil, nil
  201. }
  202. var tlsConfig *tls.Config
  203. var acmeService adapter.SimpleLifecycle
  204. var err error
  205. if options.ACME != nil && len(options.ACME.Domain) > 0 {
  206. //nolint:staticcheck
  207. tlsConfig, acmeService, err = startACME(ctx, logger, common.PtrValueOrDefault(options.ACME))
  208. if err != nil {
  209. return nil, err
  210. }
  211. if options.Insecure {
  212. return nil, errInsecureUnused
  213. }
  214. } else {
  215. tlsConfig = &tls.Config{}
  216. }
  217. tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
  218. if options.ServerName != "" {
  219. tlsConfig.ServerName = options.ServerName
  220. }
  221. if len(options.ALPN) > 0 {
  222. tlsConfig.NextProtos = append(options.ALPN, tlsConfig.NextProtos...)
  223. }
  224. if options.MinVersion != "" {
  225. minVersion, err := ParseTLSVersion(options.MinVersion)
  226. if err != nil {
  227. return nil, E.Cause(err, "parse min_version")
  228. }
  229. tlsConfig.MinVersion = minVersion
  230. }
  231. if options.MaxVersion != "" {
  232. maxVersion, err := ParseTLSVersion(options.MaxVersion)
  233. if err != nil {
  234. return nil, E.Cause(err, "parse max_version")
  235. }
  236. tlsConfig.MaxVersion = maxVersion
  237. }
  238. if options.CipherSuites != nil {
  239. find:
  240. for _, cipherSuite := range options.CipherSuites {
  241. for _, tlsCipherSuite := range tls.CipherSuites() {
  242. if cipherSuite == tlsCipherSuite.Name {
  243. tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID)
  244. continue find
  245. }
  246. }
  247. return nil, E.New("unknown cipher_suite: ", cipherSuite)
  248. }
  249. }
  250. for _, curveID := range options.CurvePreferences {
  251. tlsConfig.CurvePreferences = append(tlsConfig.CurvePreferences, tls.CurveID(curveID))
  252. }
  253. tlsConfig.ClientAuth = tls.ClientAuthType(options.ClientAuthentication)
  254. var (
  255. certificate []byte
  256. key []byte
  257. )
  258. if acmeService == nil {
  259. if len(options.Certificate) > 0 {
  260. certificate = []byte(strings.Join(options.Certificate, "\n"))
  261. } else if options.CertificatePath != "" {
  262. content, err := os.ReadFile(options.CertificatePath)
  263. if err != nil {
  264. return nil, E.Cause(err, "read certificate")
  265. }
  266. certificate = content
  267. }
  268. if len(options.Key) > 0 {
  269. key = []byte(strings.Join(options.Key, "\n"))
  270. } else if options.KeyPath != "" {
  271. content, err := os.ReadFile(options.KeyPath)
  272. if err != nil {
  273. return nil, E.Cause(err, "read key")
  274. }
  275. key = content
  276. }
  277. if certificate == nil && key == nil && options.Insecure {
  278. timeFunc := ntp.TimeFuncFromContext(ctx)
  279. if timeFunc == nil {
  280. timeFunc = time.Now
  281. }
  282. tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
  283. return GenerateKeyPair(nil, nil, timeFunc, info.ServerName)
  284. }
  285. } else {
  286. if certificate == nil {
  287. return nil, E.New("missing certificate")
  288. } else if key == nil {
  289. return nil, E.New("missing key")
  290. }
  291. keyPair, err := tls.X509KeyPair(certificate, key)
  292. if err != nil {
  293. return nil, E.Cause(err, "parse x509 key pair")
  294. }
  295. tlsConfig.Certificates = []tls.Certificate{keyPair}
  296. }
  297. }
  298. if len(options.ClientCertificate) > 0 || len(options.ClientCertificatePath) > 0 {
  299. if tlsConfig.ClientAuth == tls.NoClientCert {
  300. tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
  301. }
  302. }
  303. if tlsConfig.ClientAuth == tls.VerifyClientCertIfGiven || tlsConfig.ClientAuth == tls.RequireAndVerifyClientCert {
  304. if len(options.ClientCertificate) > 0 {
  305. clientCertificateCA := x509.NewCertPool()
  306. if !clientCertificateCA.AppendCertsFromPEM([]byte(strings.Join(options.ClientCertificate, "\n"))) {
  307. return nil, E.New("invalid client certificate strings")
  308. }
  309. tlsConfig.ClientCAs = clientCertificateCA
  310. } else if len(options.ClientCertificatePath) > 0 {
  311. clientCertificateCA := x509.NewCertPool()
  312. for _, path := range options.ClientCertificatePath {
  313. content, err := os.ReadFile(path)
  314. if err != nil {
  315. return nil, E.Cause(err, "read client certificate from ", path)
  316. }
  317. if !clientCertificateCA.AppendCertsFromPEM(content) {
  318. return nil, E.New("invalid client certificate file: ", path)
  319. }
  320. }
  321. tlsConfig.ClientCAs = clientCertificateCA
  322. } else if len(options.ClientCertificatePublicKeySHA256) > 0 {
  323. if tlsConfig.ClientAuth == tls.RequireAndVerifyClientCert {
  324. tlsConfig.ClientAuth = tls.RequireAnyClientCert
  325. } else if tlsConfig.ClientAuth == tls.VerifyClientCertIfGiven {
  326. tlsConfig.ClientAuth = tls.RequestClientCert
  327. }
  328. tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
  329. return verifyPublicKeySHA256(options.ClientCertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
  330. }
  331. } else {
  332. return nil, E.New("missing client_certificate, client_certificate_path or client_certificate_public_key_sha256 for client authentication")
  333. }
  334. }
  335. var echKeyPath string
  336. if options.ECH != nil && options.ECH.Enabled {
  337. err = parseECHServerConfig(ctx, options, tlsConfig, &echKeyPath)
  338. if err != nil {
  339. return nil, err
  340. }
  341. }
  342. serverConfig := &STDServerConfig{
  343. config: tlsConfig,
  344. logger: logger,
  345. acmeService: acmeService,
  346. certificate: certificate,
  347. key: key,
  348. certificatePath: options.CertificatePath,
  349. clientCertificatePath: options.ClientCertificatePath,
  350. keyPath: options.KeyPath,
  351. echKeyPath: echKeyPath,
  352. }
  353. serverConfig.config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
  354. serverConfig.access.Lock()
  355. defer serverConfig.access.Unlock()
  356. return serverConfig.config, nil
  357. }
  358. var config ServerConfig = serverConfig
  359. if options.KernelTx || options.KernelRx {
  360. if !C.IsLinux {
  361. return nil, E.New("kTLS is only supported on Linux")
  362. }
  363. config = &KTlSServerConfig{
  364. ServerConfig: config,
  365. logger: logger,
  366. kernelTx: options.KernelTx,
  367. kernelRx: options.KernelRx,
  368. }
  369. }
  370. return config, nil
  371. }