浏览代码

ftpd: advertise TLS support only if really enabled

if we don't have a global TLS configuration, advertise TLS only on the
bindings where it is configured instead of failing at runtime

Signed-off-by: Nicola Murino <[email protected]>
Nicola Murino 2 年之前
父节点
当前提交
a592e388cd

+ 2 - 2
docs/full-configuration.md

@@ -163,8 +163,8 @@ The configuration file contains the following sections:
     - `port`, integer. The port used for serving FTP requests. 0 means disabled. Default: 0.
     - `port`, integer. The port used for serving FTP requests. 0 means disabled. Default: 0.
     - `address`, string. Leave blank to listen on all available network interfaces. Default: "".
     - `address`, string. Leave blank to listen on all available network interfaces. Default: "".
     - `apply_proxy_config`, boolean. If enabled the common proxy configuration, if any, will be applied. Please note that we expect the proxy header on control and data connections. Default `true`.
     - `apply_proxy_config`, boolean. If enabled the common proxy configuration, if any, will be applied. Please note that we expect the proxy header on control and data connections. Default `true`.
-    - `tls_mode`, integer. 0 means accept both cleartext and encrypted sessions. 1 means TLS is required for both control and data connection. 2 means implicit TLS. Do not enable this blindly, please check that a proper TLS config is in place if you set `tls_mode` is different from 0.
-    - `tls_session_reuse`, integer. 0 means session reuse is not checked, clients may or may not resume TLS sessions. 1 means TLS session reuse is required for explicit FTPS. Not supported for implicit TLS. Default: `0`.
+    - `tls_mode`, integer. 0 means accept both cleartext and encrypted sessions. 1 means TLS is required for both control and data connection. 2 means implicit TLS.Please check that a proper TLS config is in place if you set `tls_mode` is different from 0.
+    - `tls_session_reuse`, integer. 0 means session reuse is not checked, clients may or may not reuse TLS sessions. 1 means TLS session reuse is required for explicit FTPS. Legacy reuse method based on session IDs is not supported, clients must use session tickets. Session reuse is not supported for implicit TLS. Default: `0`.
     - `certificate_file`, string. Binding specific TLS certificate. This can be an absolute path or a path relative to the config dir.
     - `certificate_file`, string. Binding specific TLS certificate. This can be an absolute path or a path relative to the config dir.
     - `certificate_key_file`, string. Binding specific private key matching the above certificate. This can be an absolute path or a path relative to the config dir. If not set the global ones will be used, if any.
     - `certificate_key_file`, string. Binding specific private key matching the above certificate. This can be an absolute path or a path relative to the config dir. If not set the global ones will be used, if any.
     - `min_tls_version`, integer. Defines the minimum version of TLS to be enabled. `12` means TLS 1.2 (and therefore TLS 1.2 and TLS 1.3 will be enabled),`13` means TLS 1.3. Default: `12`.
     - `min_tls_version`, integer. Defines the minimum version of TLS to be enabled. `12` means TLS 1.2 (and therefore TLS 1.2 and TLS 1.3 will be enabled),`13` means TLS 1.3. Default: `12`.

+ 10 - 1
internal/common/tlsutils.go

@@ -108,6 +108,15 @@ func (m *CertManager) loadCertificates() error {
 	return nil
 	return nil
 }
 }
 
 
+// HasCertificate returns true if there is a certificate for the specified certID
+func (m *CertManager) HasCertificate(certID string) bool {
+	m.RLock()
+	defer m.RUnlock()
+
+	_, ok := m.certs[certID]
+	return ok
+}
+
 // GetCertificateFunc returns the loaded certificate
 // GetCertificateFunc returns the loaded certificate
 func (m *CertManager) GetCertificateFunc(certID string) func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
 func (m *CertManager) GetCertificateFunc(certID string) func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
 	return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
 	return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
@@ -136,7 +145,7 @@ func (m *CertManager) IsRevoked(crt *x509.Certificate, caCrt *x509.Certificate)
 
 
 	for _, crl := range m.crls {
 	for _, crl := range m.crls {
 		if crl.CheckSignatureFrom(caCrt) == nil {
 		if crl.CheckSignatureFrom(caCrt) == nil {
-			for _, rc := range crl.RevokedCertificates {
+			for _, rc := range crl.RevokedCertificateEntries {
 				if rc.SerialNumber.Cmp(crt.SerialNumber) == 0 {
 				if rc.SerialNumber.Cmp(crt.SerialNumber) == 0 {
 					return true
 					return true
 				}
 				}

+ 2 - 0
internal/common/tlsutils_test.go

@@ -325,6 +325,8 @@ func TestLoadCertificate(t *testing.T) {
 
 
 	certManager, err = NewCertManager(keyPairs, configDir, logSenderTest)
 	certManager, err = NewCertManager(keyPairs, configDir, logSenderTest)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
+	assert.True(t, certManager.HasCertificate(DefaultTLSKeyPaidID))
+	assert.False(t, certManager.HasCertificate("unknownID"))
 	certFunc := certManager.GetCertificateFunc(DefaultTLSKeyPaidID)
 	certFunc := certManager.GetCertificateFunc(DefaultTLSKeyPaidID)
 	if assert.NotNil(t, certFunc) {
 	if assert.NotNil(t, certFunc) {
 		hello := &tls.ClientHelloInfo{
 		hello := &tls.ClientHelloInfo{

+ 4 - 1
internal/ftpd/ftpd.go

@@ -240,7 +240,10 @@ func (b *Binding) GetTLSDescription() string {
 		return "Implicit"
 		return "Implicit"
 	}
 	}
 
 
-	return "Plain and explicit"
+	if certMgr.HasCertificate(common.DefaultTLSKeyPaidID) || certMgr.HasCertificate(b.GetAddress()) {
+		return "Plain and explicit"
+	}
+	return "Disabled"
 }
 }
 
 
 // PortRange defines a port range
 // PortRange defines a port range

+ 35 - 11
internal/ftpd/internal_test.go

@@ -480,20 +480,37 @@ func TestInitialization(t *testing.T) {
 	err = ReloadCertificateMgr()
 	err = ReloadCertificateMgr()
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
-	certMgr = oldMgr
-
 	binding = Binding{
 	binding = Binding{
 		Port:           2121,
 		Port:           2121,
 		ClientAuthType: 1,
 		ClientAuthType: 1,
 	}
 	}
+	assert.Equal(t, "Disabled", binding.GetTLSDescription())
+	certPath := filepath.Join(os.TempDir(), "test_ftpd.crt")
+	keyPath := filepath.Join(os.TempDir(), "test_ftpd.key")
+	binding.CertificateFile = certPath
+	binding.CertificateKeyFile = keyPath
+	keyPairs := []common.TLSKeyPair{
+		{
+			Cert: certPath,
+			Key:  keyPath,
+			ID:   binding.GetAddress(),
+		},
+	}
+	certMgr, err = common.NewCertManager(keyPairs, configDir, "")
+	require.NoError(t, err)
+
+	assert.Equal(t, "Plain and explicit", binding.GetTLSDescription())
 	server = NewServer(c, configDir, binding, 0)
 	server = NewServer(c, configDir, binding, 0)
 	cfg, err := server.GetTLSConfig()
 	cfg, err := server.GetTLSConfig()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, tls.RequireAndVerifyClientCert, cfg.ClientAuth)
 	assert.Equal(t, tls.RequireAndVerifyClientCert, cfg.ClientAuth)
+
+	certMgr = oldMgr
 }
 }
 
 
 func TestServerGetSettings(t *testing.T) {
 func TestServerGetSettings(t *testing.T) {
 	oldConfig := common.Config
 	oldConfig := common.Config
+	oldMgr := certMgr
 
 
 	binding := Binding{
 	binding := Binding{
 		Port:             2121,
 		Port:             2121,
@@ -518,7 +535,9 @@ func TestServerGetSettings(t *testing.T) {
 	assert.Error(t, err)
 	assert.Error(t, err)
 	server.binding.Port = 8021
 	server.binding.Port = 8021
 
 
-	assert.Equal(t, "Plain and explicit", binding.GetTLSDescription())
+	assert.Equal(t, "Disabled", binding.GetTLSDescription())
+	_, err = server.GetTLSConfig()
+	assert.Error(t, err) // TLS configured but cert manager has no certificate
 
 
 	binding.TLSMode = 1
 	binding.TLSMode = 1
 	assert.Equal(t, "Explicit required", binding.GetTLSDescription())
 	assert.Equal(t, "Explicit required", binding.GetTLSDescription())
@@ -526,13 +545,22 @@ func TestServerGetSettings(t *testing.T) {
 	binding.TLSMode = 2
 	binding.TLSMode = 2
 	assert.Equal(t, "Implicit", binding.GetTLSDescription())
 	assert.Equal(t, "Implicit", binding.GetTLSDescription())
 
 
-	certPath := filepath.Join(os.TempDir(), "test.crt")
-	keyPath := filepath.Join(os.TempDir(), "test.key")
+	certPath := filepath.Join(os.TempDir(), "test_ftpd.crt")
+	keyPath := filepath.Join(os.TempDir(), "test_ftpd.key")
 	err = os.WriteFile(certPath, []byte(ftpsCert), os.ModePerm)
 	err = os.WriteFile(certPath, []byte(ftpsCert), os.ModePerm)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	err = os.WriteFile(keyPath, []byte(ftpsKey), os.ModePerm)
 	err = os.WriteFile(keyPath, []byte(ftpsKey), os.ModePerm)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
+	keyPairs := []common.TLSKeyPair{
+		{
+			Cert: certPath,
+			Key:  keyPath,
+			ID:   common.DefaultTLSKeyPaidID,
+		},
+	}
+	certMgr, err = common.NewCertManager(keyPairs, configDir, "")
+	require.NoError(t, err)
 	common.Config.ProxyAllowed = nil
 	common.Config.ProxyAllowed = nil
 	c.CertificateFile = certPath
 	c.CertificateFile = certPath
 	c.CertificateKeyFile = keyPath
 	c.CertificateKeyFile = keyPath
@@ -550,12 +578,8 @@ func TestServerGetSettings(t *testing.T) {
 	_, ok := listener.(*proxyproto.Listener)
 	_, ok := listener.(*proxyproto.Listener)
 	assert.True(t, ok)
 	assert.True(t, ok)
 
 
-	err = os.Remove(certPath)
-	assert.NoError(t, err)
-	err = os.Remove(keyPath)
-	assert.NoError(t, err)
-
 	common.Config = oldConfig
 	common.Config = oldConfig
+	certMgr = oldMgr
 }
 }
 
 
 func TestUserInvalidParams(t *testing.T) {
 func TestUserInvalidParams(t *testing.T) {

+ 3 - 0
internal/ftpd/server.go

@@ -280,6 +280,9 @@ func (s *Server) buildTLSConfig() {
 		if getConfigPath(s.binding.CertificateFile, "") != "" && getConfigPath(s.binding.CertificateKeyFile, "") != "" {
 		if getConfigPath(s.binding.CertificateFile, "") != "" && getConfigPath(s.binding.CertificateKeyFile, "") != "" {
 			certID = s.binding.GetAddress()
 			certID = s.binding.GetAddress()
 		}
 		}
+		if !certMgr.HasCertificate(certID) {
+			return
+		}
 		s.tlsConfig = &tls.Config{
 		s.tlsConfig = &tls.Config{
 			GetCertificate: certMgr.GetCertificateFunc(certID),
 			GetCertificate: certMgr.GetCertificateFunc(certID),
 			MinVersion:     util.GetTLSVersion(s.binding.MinTLSVersion),
 			MinVersion:     util.GetTLSVersion(s.binding.MinTLSVersion),