Procházet zdrojové kódy

lib/syncthing: Clean up / refactor LoadOrGenerateCertificate() utility function. (#8025)

LoadOrGenerateCertificate() takes two file path arguments, but then
uses the locations package to determine the actual path.  Fix that
with a minimally invasive change, by using the arguments instead.
Factor out GenerateCertificate().

The only caller of this function is cmd/syncthing, which passes the
same values, so this is technically a no-op.

* lib/tlsutil: Make storing generated certificate optional.  Avoid
  temporary cert and key files in tests, keep cert in memory.
André Colomb před 4 roky
rodič
revize
ec8a748514

+ 2 - 5
cmd/syncthing/main.go

@@ -49,16 +49,13 @@ import (
 	"github.com/syncthing/syncthing/lib/protocol"
 	"github.com/syncthing/syncthing/lib/svcutil"
 	"github.com/syncthing/syncthing/lib/syncthing"
-	"github.com/syncthing/syncthing/lib/tlsutil"
 	"github.com/syncthing/syncthing/lib/upgrade"
 
 	"github.com/pkg/errors"
 )
 
 const (
-	tlsDefaultCommonName   = "syncthing"
-	deviceCertLifetimeDays = 20 * 365
-	sigTerm                = syscall.Signal(15)
+	sigTerm = syscall.Signal(15)
 )
 
 const (
@@ -442,7 +439,7 @@ func generate(generateDir string, noDefaultFolder bool) error {
 	if err == nil {
 		l.Warnln("Key exists; will not overwrite.")
 	} else {
-		cert, err = tlsutil.NewCertificate(certFile, keyFile, tlsDefaultCommonName, deviceCertLifetimeDays)
+		cert, err = syncthing.GenerateCertificate(certFile, keyFile)
 		if err != nil {
 			return errors.Wrap(err, "create certificate")
 		}

+ 3 - 9
lib/api/api_test.go

@@ -1209,15 +1209,9 @@ func TestPrefixMatch(t *testing.T) {
 }
 
 func TestShouldRegenerateCertificate(t *testing.T) {
-	dir, err := ioutil.TempDir("", "syncthing-test")
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer os.RemoveAll(dir)
-
 	// Self signed certificates expiring in less than a month are errored so we
 	// can regenerate in time.
-	crt, err := tlsutil.NewCertificate(filepath.Join(dir, "crt"), filepath.Join(dir, "key"), "foo.example.com", 29)
+	crt, err := tlsutil.NewCertificateInMemory("foo.example.com", 29)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -1226,7 +1220,7 @@ func TestShouldRegenerateCertificate(t *testing.T) {
 	}
 
 	// Certificates with at least 31 days of life left are fine.
-	crt, err = tlsutil.NewCertificate(filepath.Join(dir, "crt"), filepath.Join(dir, "key"), "foo.example.com", 31)
+	crt, err = tlsutil.NewCertificateInMemory("foo.example.com", 31)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -1236,7 +1230,7 @@ func TestShouldRegenerateCertificate(t *testing.T) {
 
 	if runtime.GOOS == "darwin" {
 		// Certificates with too long an expiry time are not allowed on macOS
-		crt, err = tlsutil.NewCertificate(filepath.Join(dir, "crt"), filepath.Join(dir, "key"), "foo.example.com", 1000)
+		crt, err = tlsutil.NewCertificateInMemory("foo.example.com", 1000)
 		if err != nil {
 			t.Fatal(err)
 		}

+ 1 - 15
lib/connections/connections_test.go

@@ -11,11 +11,9 @@ import (
 	"crypto/tls"
 	"errors"
 	"fmt"
-	"io/ioutil"
 	"math/rand"
 	"net"
 	"net/url"
-	"os"
 	"strings"
 	"testing"
 	"time"
@@ -470,21 +468,9 @@ func withConnectionPair(b *testing.B, connUri string, h func(client, server inte
 }
 
 func mustGetCert(b *testing.B) tls.Certificate {
-	f1, err := ioutil.TempFile("", "")
+	cert, err := tlsutil.NewCertificateInMemory("bench", 10)
 	if err != nil {
 		b.Fatal(err)
 	}
-	f1.Close()
-	f2, err := ioutil.TempFile("", "")
-	if err != nil {
-		b.Fatal(err)
-	}
-	f2.Close()
-	cert, err := tlsutil.NewCertificate(f1.Name(), f2.Name(), "bench", 10)
-	if err != nil {
-		b.Fatal(err)
-	}
-	_ = os.Remove(f1.Name())
-	_ = os.Remove(f2.Name())
 	return cert
 }

+ 2 - 12
lib/discover/global_test.go

@@ -107,13 +107,8 @@ func TestGlobalOverHTTP(t *testing.T) {
 }
 
 func TestGlobalOverHTTPS(t *testing.T) {
-	dir, err := ioutil.TempDir("", "syncthing")
-	if err != nil {
-		t.Fatal(err)
-	}
-
 	// Generate a server certificate.
-	cert, err := tlsutil.NewCertificate(dir+"/cert.pem", dir+"/key.pem", "syncthing", 30)
+	cert, err := tlsutil.NewCertificateInMemory("syncthing", 30)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -172,13 +167,8 @@ func TestGlobalOverHTTPS(t *testing.T) {
 }
 
 func TestGlobalAnnounce(t *testing.T) {
-	dir, err := ioutil.TempDir("", "syncthing")
-	if err != nil {
-		t.Fatal(err)
-	}
-
 	// Generate a server certificate.
-	cert, err := tlsutil.NewCertificate(dir+"/cert.pem", dir+"/key.pem", "syncthing", 30)
+	cert, err := tlsutil.NewCertificateInMemory("syncthing", 30)
 	if err != nil {
 		t.Fatal(err)
 	}

+ 1 - 8
lib/syncthing/syncthing_test.go

@@ -9,7 +9,6 @@ package syncthing
 import (
 	"io/ioutil"
 	"os"
-	"path/filepath"
 	"testing"
 	"time"
 
@@ -57,13 +56,7 @@ func TestShortIDCheck(t *testing.T) {
 }
 
 func TestStartupFail(t *testing.T) {
-	tmpDir, err := ioutil.TempDir("", "syncthing-TestStartupFail-")
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer os.RemoveAll(tmpDir)
-
-	cert, err := tlsutil.NewCertificate(filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key"), "syncthing", 365)
+	cert, err := tlsutil.NewCertificateInMemory("syncthing", 365)
 	if err != nil {
 		t.Fatal(err)
 	}

+ 7 - 11
lib/syncthing/utils.go

@@ -25,22 +25,18 @@ import (
 )
 
 func LoadOrGenerateCertificate(certFile, keyFile string) (tls.Certificate, error) {
-	cert, err := tls.LoadX509KeyPair(
-		locations.Get(locations.CertFile),
-		locations.Get(locations.KeyFile),
-	)
+	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
 	if err != nil {
-		l.Infof("Generating ECDSA key and certificate for %s...", tlsDefaultCommonName)
-		return tlsutil.NewCertificate(
-			locations.Get(locations.CertFile),
-			locations.Get(locations.KeyFile),
-			tlsDefaultCommonName,
-			deviceCertLifetimeDays,
-		)
+		return GenerateCertificate(certFile, keyFile)
 	}
 	return cert, nil
 }
 
+func GenerateCertificate(certFile, keyFile string) (tls.Certificate, error) {
+	l.Infof("Generating ECDSA key and certificate for %s...", tlsDefaultCommonName)
+	return tlsutil.NewCertificate(certFile, keyFile, tlsDefaultCommonName, deviceCertLifetimeDays)
+}
+
 func DefaultConfig(path string, myID protocol.DeviceID, evLogger events.Logger, noDefaultFolder bool) (config.Wrapper, error) {
 	newCfg, err := config.NewWithFreePorts(myID)
 	if err != nil {

+ 33 - 17
lib/tlsutil/tlsutil.go

@@ -86,11 +86,11 @@ func SecureDefaultWithTLS12() *tls.Config {
 	}
 }
 
-// NewCertificate generates and returns a new TLS certificate.
-func NewCertificate(certFile, keyFile, commonName string, lifetimeDays int) (tls.Certificate, error) {
+// generateCertificate generates a PEM formatted key pair and self-signed certificate in memory.
+func generateCertificate(commonName string, lifetimeDays int) (*pem.Block, *pem.Block, error) {
 	priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
 	if err != nil {
-		return tls.Certificate{}, errors.Wrap(err, "generate key")
+		return nil, nil, errors.Wrap(err, "generate key")
 	}
 
 	notBefore := time.Now().Truncate(24 * time.Hour)
@@ -117,19 +117,33 @@ func NewCertificate(certFile, keyFile, commonName string, lifetimeDays int) (tls
 
 	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv)
 	if err != nil {
-		return tls.Certificate{}, errors.Wrap(err, "create cert")
+		return nil, nil, errors.Wrap(err, "create cert")
+	}
+
+	certBlock := &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}
+	keyBlock, err := pemBlockForKey(priv)
+	if err != nil {
+		return nil, nil, errors.Wrap(err, "save key")
+	}
+
+	return certBlock, keyBlock, nil
+}
+
+// NewCertificate generates and returns a new TLS certificate, saved to the given PEM files.
+func NewCertificate(certFile, keyFile string, commonName string, lifetimeDays int) (tls.Certificate, error) {
+	certBlock, keyBlock, err := generateCertificate(commonName, lifetimeDays)
+	if err != nil {
+		return tls.Certificate{}, err
 	}
 
 	certOut, err := os.Create(certFile)
 	if err != nil {
 		return tls.Certificate{}, errors.Wrap(err, "save cert")
 	}
-	err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
-	if err != nil {
+	if err = pem.Encode(certOut, certBlock); err != nil {
 		return tls.Certificate{}, errors.Wrap(err, "save cert")
 	}
-	err = certOut.Close()
-	if err != nil {
+	if err = certOut.Close(); err != nil {
 		return tls.Certificate{}, errors.Wrap(err, "save cert")
 	}
 
@@ -137,22 +151,24 @@ func NewCertificate(certFile, keyFile, commonName string, lifetimeDays int) (tls
 	if err != nil {
 		return tls.Certificate{}, errors.Wrap(err, "save key")
 	}
-
-	block, err := pemBlockForKey(priv)
-	if err != nil {
+	if err = pem.Encode(keyOut, keyBlock); err != nil {
 		return tls.Certificate{}, errors.Wrap(err, "save key")
 	}
-
-	err = pem.Encode(keyOut, block)
-	if err != nil {
+	if err = keyOut.Close(); err != nil {
 		return tls.Certificate{}, errors.Wrap(err, "save key")
 	}
-	err = keyOut.Close()
+
+	return tls.X509KeyPair(pem.EncodeToMemory(certBlock), pem.EncodeToMemory(keyBlock))
+}
+
+// NewCertificateInMemory generates and returns a new TLS certificate, kept only in memory.
+func NewCertificateInMemory(commonName string, lifetimeDays int) (tls.Certificate, error) {
+	certBlock, keyBlock, err := generateCertificate(commonName, lifetimeDays)
 	if err != nil {
-		return tls.Certificate{}, errors.Wrap(err, "save key")
+		return tls.Certificate{}, err
 	}
 
-	return tls.LoadX509KeyPair(certFile, keyFile)
+	return tls.X509KeyPair(pem.EncodeToMemory(certBlock), pem.EncodeToMemory(keyBlock))
 }
 
 type DowngradingListener struct {