Browse Source

web hooks: add mutual TLS support

Nicola Murino 4 years ago
parent
commit
a21ccad174

+ 4 - 1
cmd/startsubsys.go

@@ -90,7 +90,10 @@ Command-line flags should be specified in the Subsystem declaration.
 				os.Exit(1)
 			}
 			httpConfig := config.GetHTTPConfig()
-			httpConfig.Initialize(configDir)
+			if err := httpConfig.Initialize(configDir); err != nil {
+				logger.Error(logSender, connectionID, "unable to initialize http client: %v", err)
+				os.Exit(1)
+			}
 			user, err := dataprovider.UserExists(username)
 			if err == nil {
 				if user.HomeDir != filepath.Clean(homedir) && !preserveHomeDir {

+ 1 - 1
common/common_test.go

@@ -109,7 +109,7 @@ func TestMain(m *testing.M) {
 	httpConfig := httpclient.Config{
 		Timeout: 5,
 	}
-	httpConfig.Initialize(configDir)
+	httpConfig.Initialize(configDir) //nolint:errcheck
 
 	go func() {
 		// start a test HTTP server to receive action notifications

+ 24 - 0
config/config.go

@@ -222,6 +222,7 @@ func Init() {
 			RetryWaitMax:   30,
 			RetryMax:       3,
 			CACertificates: nil,
+			Certificates:   nil,
 			SkipTLSVerify:  false,
 		},
 		KMSConfig: kms.Configuration{
@@ -577,6 +578,7 @@ func loadBindingsFromEnv() {
 		getFTPDBindingFromEnv(idx)
 		getWebDAVDBindingFromEnv(idx)
 		getHTTPDBindingFromEnv(idx)
+		getHTTPClientCertificatesFromEnv(idx)
 	}
 }
 
@@ -756,6 +758,28 @@ func getHTTPDBindingFromEnv(idx int) {
 	}
 }
 
+func getHTTPClientCertificatesFromEnv(idx int) {
+	tlsCert := httpclient.TLSKeyPair{}
+
+	cert, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__CERTIFICATES__%v__CERT", idx))
+	if ok {
+		tlsCert.Cert = cert
+	}
+
+	key, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__CERTIFICATES__%v__KEY", idx))
+	if ok {
+		tlsCert.Key = key
+	}
+
+	if tlsCert.Cert != "" && tlsCert.Key != "" {
+		if len(globalConf.HTTPConfig.Certificates) > idx {
+			globalConf.HTTPConfig.Certificates[idx] = tlsCert
+		} else {
+			globalConf.HTTPConfig.Certificates = append(globalConf.HTTPConfig.Certificates, tlsCert)
+		}
+	}
+}
+
 func setViperDefaults() {
 	viper.SetDefault("common.idle_timeout", globalConf.Common.IdleTimeout)
 	viper.SetDefault("common.upload_mode", globalConf.Common.UploadMode)

+ 61 - 0
config/config_test.go

@@ -667,6 +667,67 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
 	require.Equal(t, 1, bindings[2].ClientAuthType)
 }
 
+func TestHTTPClientCertificatesFromEnv(t *testing.T) {
+	reset()
+
+	configDir := ".."
+	confName := tempConfigName + ".json"
+	configFilePath := filepath.Join(configDir, confName)
+	err := config.LoadConfig(configDir, "")
+	assert.NoError(t, err)
+	httpConf := config.GetHTTPConfig()
+	httpConf.Certificates = append(httpConf.Certificates, httpclient.TLSKeyPair{
+		Cert: "cert",
+		Key:  "key",
+	})
+	c := make(map[string]httpclient.Config)
+	c["http"] = httpConf
+	jsonConf, err := json.Marshal(c)
+	require.NoError(t, err)
+	err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
+	require.NoError(t, err)
+	err = config.LoadConfig(configDir, confName)
+	require.NoError(t, err)
+	require.Len(t, config.GetHTTPConfig().Certificates, 1)
+	require.Equal(t, "cert", config.GetHTTPConfig().Certificates[0].Cert)
+	require.Equal(t, "key", config.GetHTTPConfig().Certificates[0].Key)
+
+	os.Setenv("SFTPGO_HTTP__CERTIFICATES__0__CERT", "cert0")
+	os.Setenv("SFTPGO_HTTP__CERTIFICATES__0__KEY", "key0")
+	os.Setenv("SFTPGO_HTTP__CERTIFICATES__8__CERT", "cert8")
+	os.Setenv("SFTPGO_HTTP__CERTIFICATES__9__CERT", "cert9")
+	os.Setenv("SFTPGO_HTTP__CERTIFICATES__9__KEY", "key9")
+
+	t.Cleanup(func() {
+		os.Unsetenv("SFTPGO_HTTP__CERTIFICATES__0__CERT")
+		os.Unsetenv("SFTPGO_HTTP__CERTIFICATES__0__KEY")
+		os.Unsetenv("SFTPGO_HTTP__CERTIFICATES__8__CERT")
+		os.Unsetenv("SFTPGO_HTTP__CERTIFICATES__9__CERT")
+		os.Unsetenv("SFTPGO_HTTP__CERTIFICATES__9__KEY")
+	})
+
+	err = config.LoadConfig(configDir, confName)
+	require.NoError(t, err)
+	require.Len(t, config.GetHTTPConfig().Certificates, 2)
+	require.Equal(t, "cert0", config.GetHTTPConfig().Certificates[0].Cert)
+	require.Equal(t, "key0", config.GetHTTPConfig().Certificates[0].Key)
+	require.Equal(t, "cert9", config.GetHTTPConfig().Certificates[1].Cert)
+	require.Equal(t, "key9", config.GetHTTPConfig().Certificates[1].Key)
+
+	err = os.Remove(configFilePath)
+	assert.NoError(t, err)
+
+	config.Init()
+
+	err = config.LoadConfig(configDir, "")
+	require.NoError(t, err)
+	require.Len(t, config.GetHTTPConfig().Certificates, 2)
+	require.Equal(t, "cert0", config.GetHTTPConfig().Certificates[0].Cert)
+	require.Equal(t, "key0", config.GetHTTPConfig().Certificates[0].Key)
+	require.Equal(t, "cert9", config.GetHTTPConfig().Certificates[1].Cert)
+	require.Equal(t, "key9", config.GetHTTPConfig().Certificates[1].Key)
+}
+
 func TestConfigFromEnv(t *testing.T) {
 	reset()
 

+ 3 - 0
docs/full-configuration.md

@@ -215,6 +215,9 @@ The configuration file contains the following sections:
   - `retry_wait_max`, integer. Defines the maximum waiting time between attempts in seconds. The backoff algorithm will perform exponential backoff based on the attempt number and limited by the provided minimum and maximum durations.
   - `retry_max`, integer. Defines the maximum number of retries if the first request fails.
   - `ca_certificates`, list of strings. List of paths to extra CA certificates to trust. The paths can be absolute or relative to the config dir. Adding trusted CA certificates is a convenient way to use self-signed certificates without defeating the purpose of using TLS.
+  - `certificates`, list of certificate for mutual TLS. Each certificate is a struct with the following fields:
+    - `cert`, string. Path to the certificate file. The path can be absolute or relative to the config dir.
+    - `key`, string. Path to the key file. The path can be absolute or relative to the config dir.
   - `skip_tls_verify`, boolean. if enabled the HTTP client accepts any TLS certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing.
 - **kms**, configuration for the Key Management Service, more details can be found [here](./kms.md)
   - `secrets`

+ 1 - 1
ftpd/ftpd_test.go

@@ -140,7 +140,7 @@ func TestMain(m *testing.M) {
 	}
 
 	httpConfig := config.GetHTTPConfig()
-	httpConfig.Initialize(configDir)
+	httpConfig.Initialize(configDir) //nolint:errcheck
 
 	kmsConfig := config.GetKMSConfig()
 	err = kmsConfig.Initialize()

+ 60 - 15
httpclient/httpclient.go

@@ -3,6 +3,7 @@ package httpclient
 import (
 	"crypto/tls"
 	"crypto/x509"
+	"fmt"
 	"io/ioutil"
 	"net/http"
 	"path/filepath"
@@ -14,6 +15,12 @@ import (
 	"github.com/drakkan/sftpgo/utils"
 )
 
+// TLSKeyPair defines the paths for a TLS key pair
+type TLSKeyPair struct {
+	Cert string `json:"cert" mapstructure:"cert"`
+	Key  string `json:"key" mapstructure:"key"`
+}
+
 // Config defines the configuration for HTTP clients.
 // HTTP clients are used for executing hooks such as the ones used for
 // custom actions, external authentication and pre-login user modifications
@@ -31,6 +38,8 @@ type Config struct {
 	// Adding trusted CA certificates is a convenient way to use self-signed
 	// certificates without defeating the purpose of using TLS
 	CACertificates []string `json:"ca_certificates" mapstructure:"ca_certificates"`
+	// Certificates defines the certificates to use for mutual TLS
+	Certificates []TLSKeyPair `json:"certificates" mapstructure:"certificates"`
 	// if enabled the HTTP client accepts any TLS certificate presented by
 	// the server and any host name in that certificate.
 	// In this mode, TLS is susceptible to man-in-the-middle attacks.
@@ -45,25 +54,35 @@ const logSender = "httpclient"
 var httpConfig Config
 
 // Initialize configures HTTP clients
-func (c Config) Initialize(configDir string) {
-	httpConfig = c
-	rootCAs := c.loadCACerts(configDir)
+func (c *Config) Initialize(configDir string) error {
+	rootCAs, err := c.loadCACerts(configDir)
+	if err != nil {
+		return err
+	}
 	customTransport := http.DefaultTransport.(*http.Transport).Clone()
 	if customTransport.TLSClientConfig != nil {
 		customTransport.TLSClientConfig.RootCAs = rootCAs
 	} else {
 		customTransport.TLSClientConfig = &tls.Config{
-			RootCAs: rootCAs,
+			RootCAs:    rootCAs,
+			NextProtos: []string{"h2", "http/1.1"},
 		}
 	}
 	customTransport.TLSClientConfig.InsecureSkipVerify = c.SkipTLSVerify
-	httpConfig.customTransport = customTransport
-	httpConfig.tlsConfig = customTransport.TLSClientConfig
+	c.customTransport = customTransport
+	c.tlsConfig = customTransport.TLSClientConfig
+
+	err = c.loadCertificates(configDir)
+	if err != nil {
+		return err
+	}
+	httpConfig = *c
+	return nil
 }
 
 // loadCACerts returns system cert pools and try to add the configured
 // CA certificates to it
-func (c Config) loadCACerts(configDir string) *x509.CertPool {
+func (c *Config) loadCACerts(configDir string) (*x509.CertPool, error) {
 	rootCAs, err := x509.SystemCertPool()
 	if err != nil {
 		rootCAs = x509.NewCertPool()
@@ -71,26 +90,52 @@ func (c Config) loadCACerts(configDir string) *x509.CertPool {
 
 	for _, ca := range c.CACertificates {
 		if !utils.IsFileInputValid(ca) {
-			logger.Warn(logSender, "", "unable to load invalid CA certificate: %#v", ca)
-			logger.WarnToConsole("unable to load invalid CA certificate: %#v", ca)
-			continue
+			return nil, fmt.Errorf("unable to load invalid CA certificate: %#v", ca)
 		}
 		if !filepath.IsAbs(ca) {
 			ca = filepath.Join(configDir, ca)
 		}
 		certs, err := ioutil.ReadFile(ca)
 		if err != nil {
-			logger.Warn(logSender, "", "unable to load CA certificate: %v", err)
-			logger.WarnToConsole("unable to load CA certificate: %#v", err)
+			return nil, fmt.Errorf("unable to load CA certificate: %v", err)
 		}
 		if rootCAs.AppendCertsFromPEM(certs) {
 			logger.Debug(logSender, "", "CA certificate %#v added to the trusted certificates", ca)
 		} else {
-			logger.Warn(logSender, "", "unable to add CA certificate %#v to the trusted cetificates", ca)
-			logger.WarnToConsole("unable to add CA certificate %#v to the trusted cetificates", ca)
+			return nil, fmt.Errorf("unable to add CA certificate %#v to the trusted cetificates", ca)
+		}
+	}
+	return rootCAs, nil
+}
+
+func (c *Config) loadCertificates(configDir string) error {
+	if len(c.Certificates) == 0 {
+		return nil
+	}
+
+	for _, keyPair := range c.Certificates {
+		cert := keyPair.Cert
+		key := keyPair.Key
+		if !utils.IsFileInputValid(cert) {
+			return fmt.Errorf("unable to load invalid certificate: %#v", cert)
+		}
+		if !utils.IsFileInputValid(key) {
+			return fmt.Errorf("unable to load invalid key: %#v", key)
+		}
+		if !filepath.IsAbs(cert) {
+			cert = filepath.Join(configDir, cert)
+		}
+		if !filepath.IsAbs(key) {
+			key = filepath.Join(configDir, key)
+		}
+		tlsCert, err := tls.LoadX509KeyPair(cert, key)
+		if err != nil {
+			return fmt.Errorf("unable to load key pair %#v, %#v: %v", cert, key, err)
 		}
+		logger.Debug(logSender, "", "client certificate %#v and key %#v successfully loaded", cert, key)
+		c.tlsConfig.Certificates = append(c.tlsConfig.Certificates, tlsCert)
 	}
-	return rootCAs
+	return nil
 }
 
 // GetHTTPClient returns an HTTP client with the configured parameters

+ 1 - 1
httpd/httpd_test.go

@@ -177,7 +177,7 @@ func TestMain(m *testing.M) {
 	}
 
 	httpConfig := config.GetHTTPConfig()
-	httpConfig.Initialize(configDir)
+	httpConfig.Initialize(configDir) //nolint:errcheck
 	kmsConfig := config.GetKMSConfig()
 	err = kmsConfig.Initialize()
 	if err != nil {

+ 12 - 3
service/service.go

@@ -47,8 +47,7 @@ type Service struct {
 	Error             error
 }
 
-// Start initializes the service
-func (s *Service) Start() error {
+func (s *Service) initLogger() {
 	logLevel := zerolog.DebugLevel
 	if !s.LogVerbose {
 		logLevel = zerolog.InfoLevel
@@ -63,6 +62,11 @@ func (s *Service) Start() error {
 			logger.DisableLogger()
 		}
 	}
+}
+
+// Start initializes the service
+func (s *Service) Start() error {
+	s.initLogger()
 	logger.Info(logSender, "", "starting SFTPGo %v, config dir: %v, config file: %v, log max size: %v log max backups: %v "+
 		"log max age: %v log verbose: %v, log compress: %v, load data from: %#v", version.GetAsString(), s.ConfigDir, s.ConfigFile,
 		s.LogMaxSize, s.LogMaxBackups, s.LogMaxAge, s.LogVerbose, s.LogCompress, s.LoadDataFrom)
@@ -120,7 +124,12 @@ func (s *Service) Start() error {
 	}
 
 	httpConfig := config.GetHTTPConfig()
-	httpConfig.Initialize(s.ConfigDir)
+	err = httpConfig.Initialize(s.ConfigDir)
+	if err != nil {
+		logger.Error(logSender, "", "error initializing http client: %v", err)
+		logger.ErrorToConsole("error initializing http client: %v", err)
+		return err
+	}
 
 	s.startServices()
 

+ 1 - 1
sftpd/sftpd_test.go

@@ -184,7 +184,7 @@ func TestMain(m *testing.M) {
 	}
 
 	httpConfig := config.GetHTTPConfig()
-	httpConfig.Initialize(configDir)
+	httpConfig.Initialize(configDir) //nolint:errcheck
 	kmsConfig := config.GetKMSConfig()
 	err = kmsConfig.Initialize()
 	if err != nil {

+ 1 - 0
sftpgo.json

@@ -178,6 +178,7 @@
     "retry_wait_max": 30,
     "retry_max": 3,
     "ca_certificates": [],
+    "certificates": [],
     "skip_tls_verify": false
   },
   "kms": {

+ 1 - 1
webdavd/webdavd_test.go

@@ -133,7 +133,7 @@ func TestMain(m *testing.M) {
 	}
 
 	httpConfig := config.GetHTTPConfig()
-	httpConfig.Initialize(configDir)
+	httpConfig.Initialize(configDir) //nolint:errcheck
 	kmsConfig := config.GetKMSConfig()
 	err = kmsConfig.Initialize()
 	if err != nil {