tlsutils.go 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. package common
  2. import (
  3. "crypto/tls"
  4. "crypto/x509"
  5. "fmt"
  6. "io/ioutil"
  7. "path/filepath"
  8. "sync"
  9. "github.com/drakkan/sftpgo/logger"
  10. "github.com/drakkan/sftpgo/utils"
  11. )
  // CertManager holds a TLS key pair loaded from disk together with an
  // optional pool of root certificate authorities.
  //
  // The embedded RWMutex is taken by LoadCertificate (write lock) and by the
  // callback returned from GetCertificateFunc (read lock) around accesses to
  // the cert field.
  type CertManager struct {
  	// certPath and keyPath are the files handed to tls.LoadX509KeyPair.
  	certPath, keyPath string
  	sync.RWMutex
  	// cert is the most recently loaded key pair.
  	cert *tls.Certificate
  	// rootCAs is the pool built by LoadRootCAs; nil until that succeeds.
  	rootCAs *x509.CertPool
  }
  20. // LoadCertificate loads the configured x509 key pair
  21. func (m *CertManager) LoadCertificate(logSender string) error {
  22. newCert, err := tls.LoadX509KeyPair(m.certPath, m.keyPath)
  23. if err != nil {
  24. logger.Warn(logSender, "", "unable to load X509 key pair, cert file %#v key file %#v error: %v",
  25. m.certPath, m.keyPath, err)
  26. return err
  27. }
  28. logger.Debug(logSender, "", "TLS certificate %#v successfully loaded", m.certPath)
  29. m.Lock()
  30. defer m.Unlock()
  31. m.cert = &newCert
  32. return nil
  33. }
  34. // GetCertificateFunc returns the loaded certificate
  35. func (m *CertManager) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
  36. return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
  37. m.RLock()
  38. defer m.RUnlock()
  39. return m.cert, nil
  40. }
  41. }
  42. // GetRootCAs returns the set of root certificate authorities that servers
  43. // use if required to verify a client certificate
  44. func (m *CertManager) GetRootCAs() *x509.CertPool {
  45. return m.rootCAs
  46. }
  47. // LoadRootCAs tries to load root CA certificate authorities from the given paths
  48. func (m *CertManager) LoadRootCAs(caCertificates []string, configDir string) error {
  49. if len(caCertificates) == 0 {
  50. return nil
  51. }
  52. rootCAs := x509.NewCertPool()
  53. for _, rootCA := range caCertificates {
  54. if !utils.IsFileInputValid(rootCA) {
  55. return fmt.Errorf("invalid root CA certificate %#v", rootCA)
  56. }
  57. if rootCA != "" && !filepath.IsAbs(rootCA) {
  58. rootCA = filepath.Join(configDir, rootCA)
  59. }
  60. crt, err := ioutil.ReadFile(rootCA)
  61. if err != nil {
  62. return err
  63. }
  64. if rootCAs.AppendCertsFromPEM(crt) {
  65. logger.Debug(logSender, "", "TLS certificate authority %#v successfully loaded", rootCA)
  66. } else {
  67. err := fmt.Errorf("unable to load TLS certificate authority %#v", rootCA)
  68. logger.Debug(logSender, "", "%v", err)
  69. return err
  70. }
  71. }
  72. m.rootCAs = rootCAs
  73. return nil
  74. }
  75. // NewCertManager creates a new certificate manager
  76. func NewCertManager(certificateFile, certificateKeyFile, logSender string) (*CertManager, error) {
  77. manager := &CertManager{
  78. cert: nil,
  79. certPath: certificateFile,
  80. keyPath: certificateKeyFile,
  81. }
  82. err := manager.LoadCertificate(logSender)
  83. if err != nil {
  84. return nil, err
  85. }
  86. return manager, nil
  87. }