
chore(stdiscosrv): clean up s3 handling

Jakob Borg, 1 year ago
commit 6505e123bb

+ 1 - 1
cmd/stdiscosrv/apisrv_test.go

@@ -107,7 +107,7 @@ func addr(host string, port int) *net.TCPAddr {
 }
 
 func BenchmarkAPIRequests(b *testing.B) {
-	db := newInMemoryStore(b.TempDir(), 0)
+	db := newInMemoryStore(b.TempDir(), 0, nil)
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 	go db.Serve(ctx)
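Note: the only change to the tests is the new third argument; a nil *s3Copier disables the S3 backup path entirely. As a sketch, a helper one might write around the new signature (hypothetical; only the names from this commit are assumed):

    // newTestStore is a hypothetical test helper; nil means "no S3 backup".
    func newTestStore(t *testing.T) *inMemoryStore {
    	db := newInMemoryStore(t.TempDir(), 0, nil) // dir, flushInterval, *s3Copier
    	ctx, cancel := context.WithCancel(context.Background())
    	t.Cleanup(cancel)
    	go db.Serve(ctx)
    	return db
    }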

+ 8 - 44
cmd/stdiscosrv/database.go

@@ -24,10 +24,6 @@ import (
 	"strings"
 	"time"
 
-	"github.com/aws/aws-sdk-go/aws"
-	"github.com/aws/aws-sdk-go/aws/session"
-	"github.com/aws/aws-sdk-go/service/s3"
-	"github.com/aws/aws-sdk-go/service/s3/s3manager"
 	"github.com/puzpuzpuz/xsync/v3"
 	"github.com/syncthing/syncthing/lib/protocol"
 )
@@ -52,25 +48,27 @@ type inMemoryStore struct {
 	m             *xsync.MapOf[protocol.DeviceID, DatabaseRecord]
 	dir           string
 	flushInterval time.Duration
+	s3            *s3Copier
 	clock         clock
 }
 
-func newInMemoryStore(dir string, flushInterval time.Duration) *inMemoryStore {
+func newInMemoryStore(dir string, flushInterval time.Duration, s3 *s3Copier) *inMemoryStore {
 	s := &inMemoryStore{
 		m:             xsync.NewMapOf[protocol.DeviceID, DatabaseRecord](),
 		dir:           dir,
 		flushInterval: flushInterval,
+		s3:            s3,
 		clock:         defaultClock{},
 	}
 	nr, err := s.read()
-	if os.IsNotExist(err) {
+	if os.IsNotExist(err) && s3 != nil {
 		// Try to read from AWS
 		fd, cerr := os.Create(path.Join(s.dir, "records.db"))
 		if cerr != nil {
 			log.Println("Error creating database file:", err)
 			return s
 		}
-		if err := s3Download(fd); err != nil {
+		if err := s3.downloadLatest(fd); err != nil {
 			log.Printf("Error reading database from S3: %v", err)
 		}
 		_ = fd.Close()
@@ -278,16 +276,15 @@ func (s *inMemoryStore) write() (err error) {
 		return err
 	}
 
-	if os.Getenv("PODINDEX") == "0" {
-		// Upload to S3
-		log.Println("Uploading database")
+	// Upload to S3
+	if s.s3 != nil {
 		fd, err = os.Open(dbf)
 		if err != nil {
 			log.Printf("Error uploading database to S3: %v", err)
 			return nil
 		}
 		defer fd.Close()
-		if err := s3Upload(fd); err != nil {
+		if err := s.s3.upload(fd); err != nil {
 			log.Printf("Error uploading database to S3: %v", err)
 		}
 		log.Println("Finished uploading database")
@@ -424,39 +421,6 @@ func expire(addrs []DatabaseAddress, now time.Time) []DatabaseAddress {
 	return naddrs
 }
 
-func s3Upload(r io.Reader) error {
-	sess, err := session.NewSession(&aws.Config{
-		Region:   aws.String("fr-par"),
-		Endpoint: aws.String("s3.fr-par.scw.cloud"),
-	})
-	if err != nil {
-		return err
-	}
-	uploader := s3manager.NewUploader(sess)
-	_, err = uploader.Upload(&s3manager.UploadInput{
-		Bucket: aws.String("syncthing-discovery"),
-		Key:    aws.String("discovery.db"),
-		Body:   r,
-	})
-	return err
-}
-
-func s3Download(w io.WriterAt) error {
-	sess, err := session.NewSession(&aws.Config{
-		Region:   aws.String("fr-par"),
-		Endpoint: aws.String("s3.fr-par.scw.cloud"),
-	})
-	if err != nil {
-		return err
-	}
-	downloader := s3manager.NewDownloader(sess)
-	_, err = downloader.Download(w, &s3.GetObjectInput{
-		Bucket: aws.String("syncthing-discovery"),
-		Key:    aws.String("discovery.db"),
-	})
-	return err
-}
-
 func (d DatabaseAddress) Cmp(other DatabaseAddress) (n int) {
 	if c := cmp.Compare(d.Address, other.Address); c != 0 {
 		return c
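The removed helpers hardcoded the Scaleway endpoint, region, and bucket, and relied on the SDK's default credential chain; the replacement s3Copier (added in cmd/stdiscosrv/s3.go below) carries all of these as explicit fields, including static credentials. A hedged sketch reproducing the old behavior through the new API, reusing the previously hardcoded values; the key and the helper name are hypothetical:

    // backupOnce is illustrative, not part of the commit. It wires the new
    // copier with the endpoint/region/bucket the removed helpers hardcoded.
    func backupOnce(dbPath, accessKeyID, secretKey string) error {
    	s3c := newS3Copier("s3.fr-par.scw.cloud", "fr-par", "syncthing-discovery",
    		"disco-1.db", accessKeyID, secretKey)
    	fd, err := os.Open(dbPath)
    	if err != nil {
    		return err
    	}
    	defer fd.Close()
    	return s3c.upload(fd)
    }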

+ 1 - 1
cmd/stdiscosrv/database_test.go

@@ -16,7 +16,7 @@ import (
 )
 
 func TestDatabaseGetSet(t *testing.T) {
-	db := newInMemoryStore(t.TempDir(), 0)
+	db := newInMemoryStore(t.TempDir(), 0, nil)
 	ctx, cancel := context.WithCancel(context.Background())
 	go db.Serve(ctx)
 	defer cancel()

+ 58 - 47
cmd/stdiscosrv/main.go

@@ -9,7 +9,6 @@ package main
 import (
 	"context"
 	"crypto/tls"
-	"flag"
 	"log"
 	"net"
 	"net/http"
@@ -21,6 +20,7 @@ import (
 
 	_ "net/http/pprof"
 
+	"github.com/alecthomas/kong"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	_ "github.com/syncthing/syncthing/lib/automaxprocs"
 	"github.com/syncthing/syncthing/lib/build"
@@ -58,52 +58,52 @@ const (
 
 var debug = false
 
-func main() {
-	var listen string
-	var dir string
-	var metricsListen string
-	var replicationListen string
-	var replicationPeers string
-	var certFile string
-	var keyFile string
-	var replCertFile string
-	var replKeyFile string
-	var useHTTP bool
-	var compression bool
-	var amqpAddress string
-	var flushInterval time.Duration
+type CLI struct {
+	Cert          string `group:"Listen" help:"Certificate file" default:"./cert.pem" env:"DISCOVERY_CERT_FILE"`
+	Key           string `group:"Listen" help:"Key file" default:"./key.pem" env:"DISCOVERY_KEY_FILE"`
+	HTTP          bool   `group:"Listen" help:"Listen on HTTP (behind an HTTPS proxy)" env:"DISCOVERY_HTTP"`
+	Compression   bool   `group:"Listen" help:"Enable GZIP compression of responses" env:"DISCOVERY_COMPRESSION"`
+	Listen        string `group:"Listen" help:"Listen address" default:":8443" env:"DISCOVERY_LISTEN"`
+	MetricsListen string `group:"Listen" help:"Metrics listen address" env:"DISCOVERY_METRICS_LISTEN"`
+
+	Replicate         []string `group:"Legacy replication" help:"Replication peers, id@address, comma separated" env:"DISCOVERY_REPLICATE"`
+	ReplicationListen string   `group:"Legacy replication" help:"Replication listen address" default:":19200" env:"DISCOVERY_REPLICATION_LISTEN"`
+	ReplicationCert   string   `group:"Legacy replication" help:"Certificate file for replication" env:"DISCOVERY_REPLICATION_CERT_FILE"`
+	ReplicationKey    string   `group:"Legacy replication" help:"Key file for replication" env:"DISCOVERY_REPLICATION_KEY_FILE"`
+
+	AMQPAddress string `group:"AMQP replication" help:"Address to AMQP broker" env:"DISCOVERY_AMQP_ADDRESS"`
+
+	DBDir           string        `group:"Database" help:"Database directory" default:"." env:"DISCOVERY_DB_DIR"`
+	DBFlushInterval time.Duration `group:"Database" help:"Interval between database flushes" default:"5m" env:"DISCOVERY_DB_FLUSH_INTERVAL"`
+
+	DBS3Endpoint    string `name:"db-s3-endpoint" group:"Database (S3 backup)" help:"S3 endpoint for database" env:"DISCOVERY_DB_S3_ENDPOINT"`
+	DBS3Region      string `name:"db-s3-region" group:"Database (S3 backup)" help:"S3 region for database" env:"DISCOVERY_DB_S3_REGION"`
+	DBS3Bucket      string `name:"db-s3-bucket" group:"Database (S3 backup)" help:"S3 bucket for database" env:"DISCOVERY_DB_S3_BUCKET"`
+	DBS3AccessKeyID string `name:"db-s3-access-key-id" group:"Database (S3 backup)" help:"S3 access key ID for database" env:"DISCOVERY_DB_S3_ACCESS_KEY_ID"`
+	DBS3SecretKey   string `name:"db-s3-secret-key" group:"Database (S3 backup)" help:"S3 secret key for database" env:"DISCOVERY_DB_S3_SECRET_KEY"`
+
+	Debug   bool `short:"d" help:"Print debug output" env:"DISCOVERY_DEBUG"`
+	Version bool `short:"v" help:"Print version and exit"`
+}
 
+func main() {
 	log.SetOutput(os.Stdout)
-	// log.SetFlags(0)
-
-	flag.StringVar(&certFile, "cert", "./cert.pem", "Certificate file")
-	flag.StringVar(&keyFile, "key", "./key.pem", "Key file")
-	flag.StringVar(&dir, "db-dir", ".", "Database directory")
-	flag.BoolVar(&debug, "debug", false, "Print debug output")
-	flag.BoolVar(&useHTTP, "http", false, "Listen on HTTP (behind an HTTPS proxy)")
-	flag.BoolVar(&compression, "compression", true, "Enable GZIP compression of responses")
-	flag.StringVar(&listen, "listen", ":8443", "Listen address")
-	flag.StringVar(&metricsListen, "metrics-listen", "", "Metrics listen address")
-	flag.StringVar(&replicationPeers, "replicate", "", "Replication peers, id@address, comma separated")
-	flag.StringVar(&replicationListen, "replication-listen", ":19200", "Replication listen address")
-	flag.StringVar(&replCertFile, "replication-cert", "", "Certificate file for replication")
-	flag.StringVar(&replKeyFile, "replication-key", "", "Key file for replication")
-	flag.StringVar(&amqpAddress, "amqp-address", "", "Address to AMQP broker")
-	flag.DurationVar(&flushInterval, "flush-interval", 5*time.Minute, "Interval between database flushes")
-	showVersion := flag.Bool("version", false, "Show version")
-	flag.Parse()
+
+	var cli CLI
+	kong.Parse(&cli)
+	debug = cli.Debug
 
 	log.Println(build.LongVersionFor("stdiscosrv"))
-	if *showVersion {
+	if cli.Version {
 		return
 	}
 
 	buildInfo.WithLabelValues(build.Version, runtime.Version(), build.User, build.Date.UTC().Format("2006-01-02T15:04:05Z")).Set(1)
 
-	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
+	cert, err := tls.LoadX509KeyPair(cli.Cert, cli.Key)
 	if os.IsNotExist(err) {
 		log.Println("Failed to load keypair. Generating one, this might take a while...")
-		cert, err = tlsutil.NewCertificate(certFile, keyFile, "stdiscosrv", 20*365)
+		cert, err = tlsutil.NewCertificate(cli.Cert, cli.Key, "stdiscosrv", 20*365)
 		if err != nil {
 			log.Fatalln("Failed to generate X509 key pair:", err)
 		}
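All of the hand-rolled flag.StringVar/BoolVar declarations collapse into the kong-tagged CLI struct above, which also gives every option an environment-variable alias and grouped --help output. A minimal standalone sketch of the pattern (not the commit's code):

    package main

    import (
    	"fmt"

    	"github.com/alecthomas/kong"
    )

    type CLI struct {
    	Listen string `help:"Listen address" default:":8443" env:"DISCOVERY_LISTEN"`
    	Debug  bool   `short:"d" help:"Print debug output" env:"DISCOVERY_DEBUG"`
    }

    func main() {
    	var cli CLI
    	kong.Parse(&cli) // --listen / -d on the command line; DISCOVERY_* from the environment
    	fmt.Println(cli.Listen, cli.Debug)
    }

Two side effects of the conversion are visible in the diff: Replicate becomes a []string, so kong takes over the comma splitting that main previously did with strings.Split; and Compression loses its old default of true, since the new field carries no default tag and a kong bool defaults to its zero value, false.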
@@ -114,8 +114,8 @@ func main() {
 	log.Println("Server device ID is", devID)
 
 	replCert := cert
-	if replCertFile != "" && replKeyFile != "" {
-		replCert, err = tls.LoadX509KeyPair(replCertFile, replKeyFile)
+	if cli.ReplicationCert != "" && cli.ReplicationKey != "" {
+		replCert, err = tls.LoadX509KeyPair(cli.ReplicationCert, cli.ReplicationKey)
 		if err != nil {
 			log.Fatalln("Failed to load replication keypair:", err)
 		}
@@ -126,8 +126,7 @@ func main() {
 	// Parse the replication specs, if any.
 	var allowedReplicationPeers []protocol.DeviceID
 	var replicationDestinations []string
-	parts := strings.Split(replicationPeers, ",")
-	for _, part := range parts {
+	for _, part := range cli.Replicate {
 		if part == "" {
 			continue
 		}
@@ -165,10 +164,22 @@ func main() {
 	// Root of the service tree.
 	main := suture.New("main", suture.Spec{
 		PassThroughPanics: true,
+		Timeout:           2 * time.Minute,
 	})
 
+	// If configured, use S3 for database backups.
+	var s3c *s3Copier
+	if cli.DBS3Endpoint != "" {
+		hostname, err := os.Hostname()
+		if err != nil {
+			log.Fatalf("Failed to get hostname: %v", err)
+		}
+		key := hostname + ".db"
+		s3c = newS3Copier(cli.DBS3Endpoint, cli.DBS3Region, cli.DBS3Bucket, key, cli.DBS3AccessKeyID, cli.DBS3SecretKey)
+	}
+
 	// Start the database.
-	db := newInMemoryStore(dir, flushInterval)
+	db := newInMemoryStore(cli.DBDir, cli.DBFlushInterval, s3c)
 	main.Add(db)
 
 	// Start any replication senders.
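Each instance now backs up under a hostname-derived key rather than a single shared object, so several replicas can share one bucket, and a freshly started instance with an empty database directory can restore from whichever replica's backup downloadLatest judges best (see s3.go below). Sketch with illustrative values:

    // Hypothetical deployment: three replicas sharing one bucket, holding
    // disco-0.db, disco-1.db, disco-2.db (one object per hostname).
    hostname, _ := os.Hostname() // e.g. "disco-1"; the commit checks this error
    s3c := newS3Copier(cli.DBS3Endpoint, cli.DBS3Region, cli.DBS3Bucket,
    	hostname+".db", cli.DBS3AccessKeyID, cli.DBS3SecretKey)

The suture.Spec also gains a two-minute Timeout, bounding how long the supervisor waits for a service to shut down.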
@@ -181,28 +192,28 @@ func main() {
 
 	// If we have replication configured, start the replication listener.
 	if len(allowedReplicationPeers) > 0 {
-		rl := newReplicationListener(replicationListen, replCert, allowedReplicationPeers, db)
+		rl := newReplicationListener(cli.ReplicationListen, replCert, allowedReplicationPeers, db)
 		main.Add(rl)
 	}
 
 	// If we have an AMQP broker, start that
-	if amqpAddress != "" {
+	if cli.AMQPAddress != "" {
 		clientID := rand.String(10)
-		kr := newAMQPReplicator(amqpAddress, clientID, db)
+		kr := newAMQPReplicator(cli.AMQPAddress, clientID, db)
 		repl = append(repl, kr)
 		main.Add(kr)
 	}
 
 	// Start the main API server.
-	qs := newAPISrv(listen, cert, db, repl, useHTTP, compression)
+	qs := newAPISrv(cli.Listen, cert, db, repl, cli.HTTP, cli.Compression)
 	main.Add(qs)
 
 	// If we have a metrics port configured, start a metrics handler.
-	if metricsListen != "" {
+	if cli.MetricsListen != "" {
 		go func() {
 			mux := http.NewServeMux()
 			mux.Handle("/metrics", promhttp.Handler())
-			log.Fatal(http.ListenAndServe(metricsListen, mux))
+			log.Fatal(http.ListenAndServe(cli.MetricsListen, mux))
 		}()
 	}
 

+ 97 - 0
cmd/stdiscosrv/s3.go

@@ -0,0 +1,97 @@
+// Copyright (C) 2024 The Syncthing Authors.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at https://mozilla.org/MPL/2.0/.
+
+package main
+
+import (
+	"io"
+	"log"
+	"time"
+
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/credentials"
+	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/aws/aws-sdk-go/service/s3"
+	"github.com/aws/aws-sdk-go/service/s3/s3manager"
+)
+
+type s3Copier struct {
+	endpoint    string
+	region      string
+	bucket      string
+	key         string
+	accessKeyID string
+	secretKey   string
+}
+
+func newS3Copier(endpoint, region, bucket, key, accessKeyID, secretKey string) *s3Copier {
+	return &s3Copier{
+		endpoint:    endpoint,
+		region:      region,
+		bucket:      bucket,
+		key:         key,
+		accessKeyID: accessKeyID,
+		secretKey:   secretKey,
+	}
+}
+
+func (s *s3Copier) upload(r io.Reader) error {
+	sess, err := session.NewSession(&aws.Config{
+		Region:      aws.String(s.region),
+		Endpoint:    aws.String(s.endpoint),
+		Credentials: credentials.NewStaticCredentials(s.accessKeyID, s.secretKey, ""),
+	})
+	if err != nil {
+		return err
+	}
+
+	uploader := s3manager.NewUploader(sess)
+	_, err = uploader.Upload(&s3manager.UploadInput{
+		Bucket: aws.String(s.bucket),
+		Key:    aws.String(s.key),
+		Body:   r,
+	})
+	return err
+}
+
+func (s *s3Copier) downloadLatest(w io.WriterAt) error {
+	sess, err := session.NewSession(&aws.Config{
+		Region:      aws.String(s.region),
+		Endpoint:    aws.String(s.endpoint),
+		Credentials: credentials.NewStaticCredentials(s.accessKeyID, s.secretKey, ""),
+	})
+	if err != nil {
+		return err
+	}
+
+	svc := s3.New(sess)
+	resp, err := svc.ListObjectsV2(&s3.ListObjectsV2Input{Bucket: aws.String(s.bucket)})
+	if err != nil {
+		return err
+	}
+
+	var lastKey string
+	var lastModified time.Time
+	var lastSize int64
+	for _, item := range resp.Contents {
+		if item.LastModified.After(lastModified) && *item.Size > lastSize {
+			lastKey = *item.Key
+			lastModified = *item.LastModified
+			lastSize = *item.Size
+		} else if lastModified.Sub(*item.LastModified) < 5*time.Minute && *item.Size > lastSize {
+			lastKey = *item.Key
+			lastSize = *item.Size
+		}
+	}
+
+	log.Println("Downloading database from", lastKey)
+	downloader := s3manager.NewDownloader(sess)
+	_, err = downloader.Download(w, &s3.GetObjectInput{
+		Bucket: aws.String(s.bucket),
+		Key:    aws.String(lastKey),
+	})
+	return err
+}
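downloadLatest does not simply fetch this instance's own key; it lists the whole bucket and applies a small heuristic: an object becomes the candidate if it is both newer and larger than the current candidate, or if it is larger and within five minutes of the candidate's timestamp. The effect is to prefer the largest of the recent backups, so a fresh but truncated upload does not shadow a slightly older, complete one. Note also the io.Reader/io.WriterAt asymmetry: s3manager's Downloader writes retrieved chunks concurrently and therefore needs an io.WriterAt, which the *os.File passed from the database code satisfies. A standalone restatement of the selection heuristic with a worked example (simplified types; not the commit's code):

    package main

    import (
    	"fmt"
    	"time"
    )

    type object struct {
    	key      string
    	modified time.Time
    	size     int64
    }

    // pickLatest mirrors the selection loop in downloadLatest above.
    func pickLatest(objs []object) string {
    	var bestKey string
    	var bestMod time.Time
    	var bestSize int64
    	for _, o := range objs {
    		if o.modified.After(bestMod) && o.size > bestSize {
    			bestKey, bestMod, bestSize = o.key, o.modified, o.size
    		} else if bestMod.Sub(o.modified) < 5*time.Minute && o.size > bestSize {
    			bestKey, bestSize = o.key, o.size // older (or barely newer) but larger
    		}
    	}
    	return bestKey
    }

    func main() {
    	t := time.Date(2024, 1, 1, 10, 0, 0, 0, time.UTC)
    	fmt.Println(pickLatest([]object{
    		{"a.db", t, 100},                      // first seen: candidate
    		{"b.db", t.Add(4 * time.Minute), 90},  // newer but smaller: rejected
    		{"c.db", t.Add(3 * time.Minute), 120}, // newer and larger: wins
    	})) // prints "c.db"
    }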