
Merge branch 'infrastructure'

* infrastructure:
  chore(stdiscosrv): hide internal/undocumented flags
  chore(stdiscosrv): remove legacy replication
  chore(stdiscosrv): clean up s3 handling
  chore(stdiscosrv): less garbage in statistics
  chore(stdiscosrv): improve expire, logging
  chore(stdiscosrv): sched in loop
  chore(stdiscosrv): database writing logging
  chore(stdiscosrv): use order-preserving expire
  chore(stdiscosrv): simplify sorting
  chore(stdiscosrv): reduce allocations in cert handling
  chore(stdiscosrv): reduce unnecessary allocations in merge
  feat(stdiscosrv): enable HTTP profiler
  feat(discosrv): in-memory storage with S3 backing
  feat(stdiscosrv): make compression optional (and faster)
Jakob Borg, 1 year ago
parent
commit
a156e88eef

+ 15 - 4
cmd/stdiscosrv/amqp.go

@@ -10,8 +10,10 @@ import (
 	"context"
 	"fmt"
 	"io"
+	"log"
 
 	amqp "github.com/rabbitmq/amqp091-go"
+	"github.com/syncthing/syncthing/lib/protocol"
 	"github.com/thejerf/suture/v4"
 )
 
@@ -49,7 +51,7 @@ func newAMQPReplicator(broker, clientID string, db database) *amqpReplicator {
 	}
 }
 
-func (s *amqpReplicator) send(key string, ps []DatabaseAddress, seen int64) {
+func (s *amqpReplicator) send(key *protocol.DeviceID, ps []DatabaseAddress, seen int64) {
 	s.sender.send(key, ps, seen)
 }
 
@@ -109,9 +111,9 @@ func (s *amqpSender) String() string {
 	return fmt.Sprintf("amqpSender(%q)", s.broker)
 }
 
-func (s *amqpSender) send(key string, ps []DatabaseAddress, seen int64) {
+func (s *amqpSender) send(key *protocol.DeviceID, ps []DatabaseAddress, seen int64) {
 	item := ReplicationRecord{
-		Key:       key,
+		Key:       key[:],
 		Addresses: ps,
 		Seen:      seen,
 	}
@@ -161,8 +163,17 @@ func (s *amqpReceiver) Serve(ctx context.Context) error {
 				replicationRecvsTotal.WithLabelValues("error").Inc()
 				return fmt.Errorf("replication unmarshal: %w", err)
 			}
+			id, err := protocol.DeviceIDFromBytes(rec.Key)
+			if err != nil {
+				id, err = protocol.DeviceIDFromString(string(rec.Key))
+			}
+			if err != nil {
+				log.Println("Replication device ID:", err)
+				replicationRecvsTotal.WithLabelValues("error").Inc()
+				continue
+			}
 
-			if err := s.db.merge(rec.Key, rec.Addresses, rec.Seen); err != nil {
+			if err := s.db.merge(&id, rec.Addresses, rec.Seen); err != nil {
 				return fmt.Errorf("replication database merge: %w", err)
 			}
 

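A note on the key change above: replication keys are now the raw 32-byte device ID rather than its string form, and the receiver falls back to string parsing so records produced by older senders still apply. A minimal sketch of that fallback, using the protocol functions as in the diff (decodeReplicationKey is a hypothetical helper, not part of the commit):

	// A replication key may be the raw 32-byte device ID (new format)
	// or its textual form (written by older instances).
	func decodeReplicationKey(key []byte) (protocol.DeviceID, error) {
		id, err := protocol.DeviceIDFromBytes(key)
		if err != nil {
			// Not 32 raw bytes; try the legacy string representation.
			id, err = protocol.DeviceIDFromString(string(key))
		}
		return id, err
	}
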
+ 116 - 59
cmd/stdiscosrv/apisrv.go

@@ -22,7 +22,7 @@ import (
 	"net"
 	"net/http"
 	"net/url"
-	"sort"
+	"slices"
 	"strconv"
 	"strings"
 	"sync"
@@ -45,10 +45,14 @@ type apiSrv struct {
 	listener       net.Listener
 	repl           replicator // optional
 	useHTTP        bool
-	missesIncrease int
+	compression    bool
+	gzipWriters    sync.Pool
+	seenTracker    *retryAfterTracker
+	notSeenTracker *retryAfterTracker
+}
 
-	mapsMut sync.Mutex
-	misses  map[string]int32
+type replicator interface {
+	send(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64)
 }
 
 type requestID int64
@@ -61,19 +65,30 @@ type contextKey int
 
 const idKey contextKey = iota
 
-func newAPISrv(addr string, cert tls.Certificate, db database, repl replicator, useHTTP bool, missesIncrease int) *apiSrv {
+func newAPISrv(addr string, cert tls.Certificate, db database, repl replicator, useHTTP, compression bool) *apiSrv {
 	return &apiSrv{
-		addr:           addr,
-		cert:           cert,
-		db:             db,
-		repl:           repl,
-		useHTTP:        useHTTP,
-		misses:         make(map[string]int32),
-		missesIncrease: missesIncrease,
+		addr:        addr,
+		cert:        cert,
+		db:          db,
+		repl:        repl,
+		useHTTP:     useHTTP,
+		compression: compression,
+		seenTracker: &retryAfterTracker{
+			name:         "seenTracker",
+			bucketStarts: time.Now(),
+			desiredRate:  250,
+			currentDelay: notFoundRetryUnknownMinSeconds,
+		},
+		notSeenTracker: &retryAfterTracker{
+			name:         "notSeenTracker",
+			bucketStarts: time.Now(),
+			desiredRate:  250,
+			currentDelay: notFoundRetryUnknownMaxSeconds / 2,
+		},
 	}
 }
 
-func (s *apiSrv) Serve(_ context.Context) error {
+func (s *apiSrv) Serve(ctx context.Context) error {
 	if s.useHTTP {
 		listener, err := net.Listen("tcp", s.addr)
 		if err != nil {
@@ -107,6 +122,11 @@ func (s *apiSrv) Serve(_ context.Context) error {
 		ErrorLog:       log.New(io.Discard, "", 0),
 	}
 
+	go func() {
+		<-ctx.Done()
+		srv.Shutdown(context.Background())
+	}()
+
 	err := srv.Serve(s.listener)
 	if err != nil {
 		log.Println("Serve:", err)
@@ -183,8 +203,7 @@ func (s *apiSrv) handleGET(w http.ResponseWriter, req *http.Request) {
 		return
 	}
 
-	key := deviceID.String()
-	rec, err := s.db.get(key)
+	rec, err := s.db.get(&deviceID)
 	if err != nil {
 		// some sort of internal error
 		lookupRequestsTotal.WithLabelValues("internal_error").Inc()
@@ -194,27 +213,14 @@ func (s *apiSrv) handleGET(w http.ResponseWriter, req *http.Request) {
 	}
 
 	if len(rec.Addresses) == 0 {
-		lookupRequestsTotal.WithLabelValues("not_found").Inc()
-
-		s.mapsMut.Lock()
-		misses := s.misses[key]
-		if misses < rec.Misses {
-			misses = rec.Misses
-		}
-		misses += int32(s.missesIncrease)
-		s.misses[key] = misses
-		s.mapsMut.Unlock()
-
-		if misses >= notFoundMissesWriteInterval {
-			rec.Misses = misses
-			rec.Missed = time.Now().UnixNano()
-			rec.Addresses = nil
-			// rec.Seen retained from get
-			s.db.put(key, rec)
+		var afterS int
+		if rec.Seen == 0 {
+			afterS = s.notSeenTracker.retryAfterS()
+			lookupRequestsTotal.WithLabelValues("not_found_ever").Inc()
+		} else {
+			afterS = s.seenTracker.retryAfterS()
+			lookupRequestsTotal.WithLabelValues("not_found_recent").Inc()
 		}
-
-		afterS := notFoundRetryAfterSeconds(int(misses))
-		retryAfterHistogram.Observe(float64(afterS))
 		w.Header().Set("Retry-After", strconv.Itoa(afterS))
 		http.Error(w, "Not Found", http.StatusNotFound)
 		return
@@ -226,10 +232,16 @@ func (s *apiSrv) handleGET(w http.ResponseWriter, req *http.Request) {
 	var bw io.Writer = w
 
 	// Use compression if the client asks for it
-	if strings.Contains(req.Header.Get("Accept-Encoding"), "gzip") {
+	if s.compression && strings.Contains(req.Header.Get("Accept-Encoding"), "gzip") {
+		gw, ok := s.gzipWriters.Get().(*gzip.Writer)
+		if ok {
+			gw.Reset(w)
+		} else {
+			gw = gzip.NewWriter(w)
+		}
 		w.Header().Set("Content-Encoding", "gzip")
-		gw := gzip.NewWriter(bw)
 		defer gw.Close()
+		defer s.gzipWriters.Put(gw)
 		bw = gw
 	}
 
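The compression path above pools gzip writers instead of allocating one per request: Get() returns nil the first time (so the type assertion's ok is false and a new writer is created), after which writers are Reset() onto the next response and returned to the pool once closed. A standalone sketch of the pattern (gzipWriters and pooledGzipWriter are hypothetical names for illustration):

	var gzipWriters sync.Pool

	// pooledGzipWriter returns a gzip.Writer targeting w, reusing a
	// pooled instance when one is available.
	func pooledGzipWriter(w io.Writer) *gzip.Writer {
		if gw, ok := gzipWriters.Get().(*gzip.Writer); ok {
			gw.Reset(w) // rewire the reused writer to this response
			return gw
		}
		return gzip.NewWriter(w)
	}

	// Callers must Close() the writer to flush the stream, then Put()
	// it back into the pool for reuse.
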
@@ -292,7 +304,6 @@ func (s *apiSrv) Stop() {
 }
 
 func (s *apiSrv) handleAnnounce(deviceID protocol.DeviceID, addresses []string) error {
-	key := deviceID.String()
 	now := time.Now()
 	expire := now.Add(addressExpiryTime).UnixNano()
 
@@ -304,13 +315,13 @@ func (s *apiSrv) handleAnnounce(deviceID protocol.DeviceID, addresses []string)
 
 	// The address slice must always be sorted for database merges to work
 	// properly.
-	sort.Sort(databaseAddressOrder(dbAddrs))
+	slices.SortFunc(dbAddrs, DatabaseAddress.Cmp)
 
 	seen := now.UnixNano()
 	if s.repl != nil {
-		s.repl.send(key, dbAddrs, seen)
+		s.repl.send(&deviceID, dbAddrs, seen)
 	}
-	return s.db.merge(key, dbAddrs, seen)
+	return s.db.merge(&deviceID, dbAddrs, seen)
 }
 
 func handlePing(w http.ResponseWriter, _ *http.Request) {
@@ -360,7 +371,7 @@ func certificateBytes(req *http.Request) ([]byte, error) {
 		}
 
 		bs = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: hdr})
-	} else if hdr := req.Header.Get("X-Forwarded-Tls-Client-Cert"); hdr != "" {
+	} else if cert := req.Header.Get("X-Forwarded-Tls-Client-Cert"); cert != "" {
 		// Traefik 2 passtlsclientcert
 		//
 		// The certificate is in PEM format, maybe with URL encoding
@@ -368,19 +379,36 @@ func certificateBytes(req *http.Request) ([]byte, error) {
 		// statements. We need to decode, reinstate the newlines every 64
 		// character and add statements for the PEM decoder
 
-		if strings.Contains(hdr, "%") {
-			if unesc, err := url.QueryUnescape(hdr); err == nil {
-				hdr = unesc
+		if strings.Contains(cert, "%") {
+			if unesc, err := url.QueryUnescape(cert); err == nil {
+				cert = unesc
 			}
 		}
 
-		for i := 64; i < len(hdr); i += 65 {
-			hdr = hdr[:i] + "\n" + hdr[i:]
+		const (
+			header = "-----BEGIN CERTIFICATE-----"
+			footer = "-----END CERTIFICATE-----"
+		)
+
+		var b bytes.Buffer
+		b.Grow(len(header) + 1 + len(cert) + len(cert)/64 + 1 + len(footer) + 1)
+
+		b.WriteString(header)
+		b.WriteByte('\n')
+
+		for i := 0; i < len(cert); i += 64 {
+			end := i + 64
+			if end > len(cert) {
+				end = len(cert)
+			}
+			b.WriteString(cert[i:end])
+			b.WriteByte('\n')
 		}
 
-		hdr = "-----BEGIN CERTIFICATE-----\n" + hdr
-		hdr += "\n-----END CERTIFICATE-----\n"
-		bs = []byte(hdr)
+		b.WriteString(footer)
+		b.WriteByte('\n')
+
+		bs = b.Bytes()
 	}
 
 	if bs == nil {
@@ -494,15 +522,44 @@ func errorRetryAfterString() string {
 	return strconv.Itoa(errorRetryAfterSeconds + rand.Intn(errorRetryFuzzSeconds))
 }
 
-func notFoundRetryAfterSeconds(misses int) int {
-	retryAfterS := notFoundRetryMinSeconds + notFoundRetryIncSeconds*misses
-	if retryAfterS > notFoundRetryMaxSeconds {
-		retryAfterS = notFoundRetryMaxSeconds
-	}
-	retryAfterS += rand.Intn(notFoundRetryFuzzSeconds)
-	return retryAfterS
-}
-
 func reannounceAfterString() string {
 	return strconv.Itoa(reannounceAfterSeconds + rand.Intn(reannounzeFuzzSeconds))
 }
+
+type retryAfterTracker struct {
+	name        string
+	desiredRate float64 // requests per second
+
+	mut          sync.Mutex
+	lastCount    int       // requests in the last bucket
+	curCount     int       // requests in the current bucket
+	bucketStarts time.Time // start of the current bucket
+	currentDelay int       // current delay in seconds
+}
+
+func (t *retryAfterTracker) retryAfterS() int {
+	now := time.Now()
+	t.mut.Lock()
+	if durS := now.Sub(t.bucketStarts).Seconds(); durS > float64(t.currentDelay) {
+		t.bucketStarts = now
+		t.lastCount = t.curCount
+		lastRate := float64(t.lastCount) / durS
+
+		switch {
+		case t.currentDelay > notFoundRetryUnknownMinSeconds &&
+			lastRate < 0.75*t.desiredRate:
+			t.currentDelay = max(8*t.currentDelay/10, notFoundRetryUnknownMinSeconds)
+		case t.currentDelay < notFoundRetryUnknownMaxSeconds &&
+			lastRate > 1.25*t.desiredRate:
+			t.currentDelay = min(3*t.currentDelay/2, notFoundRetryUnknownMaxSeconds)
+		}
+
+		t.curCount = 0
+	}
+	if t.curCount == 0 {
+		retryAfterLevel.WithLabelValues(t.name).Set(float64(t.currentDelay))
+	}
+	t.curCount++
+	t.mut.Unlock()
+	return t.currentDelay + rand.Intn(t.currentDelay/4)
+}

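The retryAfterTracker above replaces the per-device miss counters (removed earlier in this diff) with two global adaptive delays: each tracker counts lookups in buckets of currentDelay seconds and, when a bucket completes, grows the delay by half if the observed rate exceeds 1.25× desiredRate, or shrinks it by a fifth if the rate falls below 0.75× desiredRate, clamped between notFoundRetryUnknownMinSeconds and notFoundRetryUnknownMaxSeconds. A sketch of how it behaves, assuming the tracker type and constants above (exampleRetryAfterS is a hypothetical function):

	func exampleRetryAfterS() int {
		t := &retryAfterTracker{
			name:         "example",
			bucketStarts: time.Now(),
			desiredRate:  250, // target lookups per second
			currentDelay: notFoundRetryUnknownMinSeconds,
		}
		// From the 60 s floor, a sustained rate above 312.5/s (1.25×250)
		// grows the delay 60 -> 90 -> 135 -> ... capped at 3600 s; below
		// 187.5/s it decays by a fifth per bucket, floored at 60 s. Up to
		// 25% fuzz is added per reply so clients do not re-synchronize.
		return t.retryAfterS() // 60..74 while currentDelay == 60
	}
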
+ 87 - 0
cmd/stdiscosrv/apisrv_test.go

@@ -7,9 +7,20 @@
 package main
 
 import (
+	"context"
+	"crypto/tls"
 	"fmt"
+	"io"
 	"net"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"regexp"
+	"strings"
 	"testing"
+
+	"github.com/syncthing/syncthing/lib/protocol"
+	"github.com/syncthing/syncthing/lib/tlsutil"
 )
 
 func TestFixupAddresses(t *testing.T) {
@@ -94,3 +105,79 @@ func addr(host string, port int) *net.TCPAddr {
 		Port: port,
 	}
 }
+
+func BenchmarkAPIRequests(b *testing.B) {
+	db := newInMemoryStore(b.TempDir(), 0, nil)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	go db.Serve(ctx)
+	api := newAPISrv("127.0.0.1:0", tls.Certificate{}, db, nil, true, true)
+	srv := httptest.NewServer(http.HandlerFunc(api.handler))
+
+	kf := b.TempDir() + "/cert"
+	crt, err := tlsutil.NewCertificate(kf+".crt", kf+".key", "localhost", 7)
+	if err != nil {
+		b.Fatal(err)
+	}
+	certBs, err := os.ReadFile(kf + ".crt")
+	if err != nil {
+		b.Fatal(err)
+	}
+	certBs = regexp.MustCompile(`---[^\n]+---\n`).ReplaceAll(certBs, nil)
+	certString := strings.ReplaceAll(string(certBs), "\n", " ")
+
+	devID := protocol.NewDeviceID(crt.Certificate[0])
+	devIDString := devID.String()
+
+	b.Run("Announce", func(b *testing.B) {
+		b.ReportAllocs()
+		url := srv.URL + "/v2/?device=" + devIDString
+		for i := 0; i < b.N; i++ {
+			req, _ := http.NewRequest(http.MethodPost, url, strings.NewReader(`{"addresses":["tcp://10.10.10.10:42000"]}`))
+			req.Header.Set("X-Forwarded-Tls-Client-Cert", certString)
+			resp, err := http.DefaultClient.Do(req)
+			if err != nil {
+				b.Fatal(err)
+			}
+			resp.Body.Close()
+			if resp.StatusCode != http.StatusNoContent {
+				b.Fatalf("unexpected status %s", resp.Status)
+			}
+		}
+	})
+
+	b.Run("Lookup", func(b *testing.B) {
+		b.ReportAllocs()
+		url := srv.URL + "/v2/?device=" + devIDString
+		for i := 0; i < b.N; i++ {
+			req, _ := http.NewRequest(http.MethodGet, url, nil)
+			resp, err := http.DefaultClient.Do(req)
+			if err != nil {
+				b.Fatal(err)
+			}
+			io.Copy(io.Discard, resp.Body)
+			resp.Body.Close()
+			if resp.StatusCode != http.StatusOK {
+				b.Fatalf("unexpected status %s", resp.Status)
+			}
+		}
+	})
+
+	b.Run("LookupNoCompression", func(b *testing.B) {
+		b.ReportAllocs()
+		url := srv.URL + "/v2/?device=" + devIDString
+		for i := 0; i < b.N; i++ {
+			req, _ := http.NewRequest(http.MethodGet, url, nil)
+			req.Header.Set("Accept-Encoding", "identity") // disable compression
+			resp, err := http.DefaultClient.Do(req)
+			if err != nil {
+				b.Fatal(err)
+			}
+			io.Copy(io.Discard, resp.Body)
+			resp.Body.Close()
+			if resp.StatusCode != http.StatusOK {
+				b.Fatalf("unexpected status %s", resp.Status)
+			}
+		}
+	})
+}

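The benchmarks above exercise the full HTTP stack through httptest, covering announce (with the Traefik-style certificate header) and lookups with and without gzip; they run with go test -bench=BenchmarkAPIRequests ./cmd/stdiscosrv, and b.ReportAllocs surfaces the allocation counts these changes target.
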
+ 264 - 248
cmd/stdiscosrv/database.go

@@ -10,17 +10,22 @@
 package main
 
 import (
+	"bufio"
+	"cmp"
 	"context"
+	"encoding/binary"
+	"errors"
+	"io"
 	"log"
-	"net"
-	"net/url"
-	"sort"
+	"os"
+	"path"
+	"runtime"
+	"slices"
+	"strings"
 	"time"
 
-	"github.com/syncthing/syncthing/lib/sliceutil"
-	"github.com/syndtr/goleveldb/leveldb"
-	"github.com/syndtr/goleveldb/leveldb/storage"
-	"github.com/syndtr/goleveldb/leveldb/util"
+	"github.com/puzpuzpuz/xsync/v3"
+	"github.com/syncthing/syncthing/lib/protocol"
 )
 
 type clock interface {
@@ -34,270 +39,305 @@ func (defaultClock) Now() time.Time {
 }
 
 type database interface {
-	put(key string, rec DatabaseRecord) error
-	merge(key string, addrs []DatabaseAddress, seen int64) error
-	get(key string) (DatabaseRecord, error)
+	put(key *protocol.DeviceID, rec DatabaseRecord) error
+	merge(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64) error
+	get(key *protocol.DeviceID) (DatabaseRecord, error)
 }
 
-type levelDBStore struct {
-	db         *leveldb.DB
-	inbox      chan func()
-	clock      clock
-	marshalBuf []byte
+type inMemoryStore struct {
+	m             *xsync.MapOf[protocol.DeviceID, DatabaseRecord]
+	dir           string
+	flushInterval time.Duration
+	s3            *s3Copier
+	clock         clock
 }
 
-func newLevelDBStore(dir string) (*levelDBStore, error) {
-	db, err := leveldb.OpenFile(dir, levelDBOptions)
-	if err != nil {
-		return nil, err
+func newInMemoryStore(dir string, flushInterval time.Duration, s3 *s3Copier) *inMemoryStore {
+	s := &inMemoryStore{
+		m:             xsync.NewMapOf[protocol.DeviceID, DatabaseRecord](),
+		dir:           dir,
+		flushInterval: flushInterval,
+		s3:            s3,
+		clock:         defaultClock{},
 	}
-	return &levelDBStore{
-		db:    db,
-		inbox: make(chan func(), 16),
-		clock: defaultClock{},
-	}, nil
-}
-
-func newMemoryLevelDBStore() (*levelDBStore, error) {
-	db, err := leveldb.Open(storage.NewMemStorage(), nil)
-	if err != nil {
-		return nil, err
-	}
-	return &levelDBStore{
-		db:    db,
-		inbox: make(chan func(), 16),
-		clock: defaultClock{},
-	}, nil
-}
-
-func (s *levelDBStore) put(key string, rec DatabaseRecord) error {
-	t0 := time.Now()
-	defer func() {
-		databaseOperationSeconds.WithLabelValues(dbOpPut).Observe(time.Since(t0).Seconds())
-	}()
-
-	rc := make(chan error)
-
-	s.inbox <- func() {
-		size := rec.Size()
-		if len(s.marshalBuf) < size {
-			s.marshalBuf = make([]byte, size)
+	nr, err := s.read()
+	if os.IsNotExist(err) && s3 != nil {
+		// Try to read from AWS
+		fd, cerr := os.Create(path.Join(s.dir, "records.db"))
+		if cerr != nil {
+			log.Println("Error creating database file:", cerr)
+			return s
+		}
+		if err := s3.downloadLatest(fd); err != nil {
+			log.Printf("Error reading database from S3: %v", err)
 		}
-		n, _ := rec.MarshalTo(s.marshalBuf)
-		rc <- s.db.Put([]byte(key), s.marshalBuf[:n], nil)
+		_ = fd.Close()
+		nr, err = s.read()
 	}
-
-	err := <-rc
 	if err != nil {
-		databaseOperations.WithLabelValues(dbOpPut, dbResError).Inc()
-	} else {
-		databaseOperations.WithLabelValues(dbOpPut, dbResSuccess).Inc()
+		log.Println("Error reading database:", err)
 	}
+	log.Printf("Read %d records from database", nr)
+	s.calculateStatistics()
+	return s
+}
 
-	return err
+func (s *inMemoryStore) put(key *protocol.DeviceID, rec DatabaseRecord) error {
+	t0 := time.Now()
+	s.m.Store(*key, rec)
+	databaseOperations.WithLabelValues(dbOpPut, dbResSuccess).Inc()
+	databaseOperationSeconds.WithLabelValues(dbOpPut).Observe(time.Since(t0).Seconds())
+	return nil
 }
 
-func (s *levelDBStore) merge(key string, addrs []DatabaseAddress, seen int64) error {
+func (s *inMemoryStore) merge(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64) error {
 	t0 := time.Now()
-	defer func() {
-		databaseOperationSeconds.WithLabelValues(dbOpMerge).Observe(time.Since(t0).Seconds())
-	}()
 
-	rc := make(chan error)
 	newRec := DatabaseRecord{
 		Addresses: addrs,
 		Seen:      seen,
 	}
 
-	s.inbox <- func() {
-		// grab the existing record
-		oldRec, err := s.get(key)
-		if err != nil {
-			// "not found" is not an error from get, so this is serious
-			// stuff only
-			rc <- err
-			return
-		}
-		newRec = merge(newRec, oldRec)
-
-		// We replicate s.put() functionality here ourselves instead of
-		// calling it because we want to serialize our get above together
-		// with the put in the same function.
-		size := newRec.Size()
-		if len(s.marshalBuf) < size {
-			s.marshalBuf = make([]byte, size)
-		}
-		n, _ := newRec.MarshalTo(s.marshalBuf)
-		rc <- s.db.Put([]byte(key), s.marshalBuf[:n], nil)
-	}
+	oldRec, _ := s.m.Load(*key)
+	newRec = merge(newRec, oldRec)
+	s.m.Store(*key, newRec)
 
-	err := <-rc
-	if err != nil {
-		databaseOperations.WithLabelValues(dbOpMerge, dbResError).Inc()
-	} else {
-		databaseOperations.WithLabelValues(dbOpMerge, dbResSuccess).Inc()
-	}
+	databaseOperations.WithLabelValues(dbOpMerge, dbResSuccess).Inc()
+	databaseOperationSeconds.WithLabelValues(dbOpMerge).Observe(time.Since(t0).Seconds())
 
-	return err
+	return nil
 }
 
-func (s *levelDBStore) get(key string) (DatabaseRecord, error) {
+func (s *inMemoryStore) get(key *protocol.DeviceID) (DatabaseRecord, error) {
 	t0 := time.Now()
 	defer func() {
 		databaseOperationSeconds.WithLabelValues(dbOpGet).Observe(time.Since(t0).Seconds())
 	}()
 
-	keyBs := []byte(key)
-	val, err := s.db.Get(keyBs, nil)
-	if err == leveldb.ErrNotFound {
+	rec, ok := s.m.Load(*key)
+	if !ok {
 		databaseOperations.WithLabelValues(dbOpGet, dbResNotFound).Inc()
 		return DatabaseRecord{}, nil
 	}
-	if err != nil {
-		databaseOperations.WithLabelValues(dbOpGet, dbResError).Inc()
-		return DatabaseRecord{}, err
-	}
 
-	var rec DatabaseRecord
-
-	if err := rec.Unmarshal(val); err != nil {
-		databaseOperations.WithLabelValues(dbOpGet, dbResUnmarshalError).Inc()
-		return DatabaseRecord{}, nil
-	}
-
-	rec.Addresses = expire(rec.Addresses, s.clock.Now().UnixNano())
+	rec.Addresses = expire(rec.Addresses, s.clock.Now())
 	databaseOperations.WithLabelValues(dbOpGet, dbResSuccess).Inc()
 	return rec, nil
 }
 
-func (s *levelDBStore) Serve(ctx context.Context) error {
-	t := time.NewTimer(0)
+func (s *inMemoryStore) Serve(ctx context.Context) error {
+	t := time.NewTimer(s.flushInterval)
 	defer t.Stop()
-	defer s.db.Close()
 
-	// Start the statistics serve routine. It will exit with us when
-	// statisticsTrigger is closed.
-	statisticsTrigger := make(chan struct{})
-	statisticsDone := make(chan struct{})
-	go s.statisticsServe(statisticsTrigger, statisticsDone)
+	if s.flushInterval <= 0 {
+		t.Stop()
+	}
 
 loop:
 	for {
 		select {
-		case fn := <-s.inbox:
-			// Run function in serialized order.
-			fn()
-
 		case <-t.C:
-			// Trigger the statistics routine to do its thing in the
-			// background.
-			statisticsTrigger <- struct{}{}
-
-		case <-statisticsDone:
-			// The statistics routine is done with one iteratation, schedule
-			// the next.
-			t.Reset(databaseStatisticsInterval)
+			log.Println("Calculating statistics")
+			s.calculateStatistics()
+			log.Println("Flushing database")
+			if err := s.write(); err != nil {
+				log.Println("Error writing database:", err)
+			}
+			log.Println("Finished flushing database")
+			t.Reset(s.flushInterval)
 
 		case <-ctx.Done():
 			// We're done.
-			close(statisticsTrigger)
 			break loop
 		}
 	}
 
-	// Also wait for statisticsServe to return
-	<-statisticsDone
-
-	return nil
+	return s.write()
 }
 
-func (s *levelDBStore) statisticsServe(trigger <-chan struct{}, done chan<- struct{}) {
-	defer close(done)
-
-	for range trigger {
-		t0 := time.Now()
-		nowNanos := t0.UnixNano()
-		cutoff24h := t0.Add(-24 * time.Hour).UnixNano()
-		cutoff1w := t0.Add(-7 * 24 * time.Hour).UnixNano()
-		cutoff2Mon := t0.Add(-60 * 24 * time.Hour).UnixNano()
-		current, currentIPv4, currentIPv6, last24h, last1w, inactive, errors := 0, 0, 0, 0, 0, 0, 0
-
-		iter := s.db.NewIterator(&util.Range{}, nil)
-		for iter.Next() {
-			// Attempt to unmarshal the record and count the
-			// failure if there's something wrong with it.
-			var rec DatabaseRecord
-			if err := rec.Unmarshal(iter.Value()); err != nil {
-				errors++
-				continue
-			}
+func (s *inMemoryStore) calculateStatistics() {
+	now := s.clock.Now()
+	cutoff24h := now.Add(-24 * time.Hour).UnixNano()
+	cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano()
+	current, currentIPv4, currentIPv6, last24h, last1w := 0, 0, 0, 0, 0
 
-			// If there are addresses that have not expired it's a current
-			// record, otherwise account it based on when it was last seen
-			// (last 24 hours or last week) or finally as inactice.
-			addrs := expire(rec.Addresses, nowNanos)
-			switch {
-			case len(addrs) > 0:
-				current++
-				seenIPv4, seenIPv6 := false, false
-				for _, addr := range addrs {
-					uri, err := url.Parse(addr.Address)
-					if err != nil {
-						continue
-					}
-					host, _, err := net.SplitHostPort(uri.Host)
-					if err != nil {
-						continue
-					}
-					if ip := net.ParseIP(host); ip != nil && ip.To4() != nil {
-						seenIPv4 = true
-					} else if ip != nil {
-						seenIPv6 = true
-					}
-					if seenIPv4 && seenIPv6 {
-						break
-					}
-				}
-				if seenIPv4 {
-					currentIPv4++
-				}
-				if seenIPv6 {
-					currentIPv6++
-				}
-			case rec.Seen > cutoff24h:
-				last24h++
-			case rec.Seen > cutoff1w:
-				last1w++
-			case rec.Seen > cutoff2Mon:
-				inactive++
-			case rec.Missed < cutoff2Mon:
-				// It hasn't been seen lately and we haven't recorded
-				// someone asking for this device in a long time either;
-				// delete the record.
-				if err := s.db.Delete(iter.Key(), nil); err != nil {
-					databaseOperations.WithLabelValues(dbOpDelete, dbResError).Inc()
+	n := 0
+	s.m.Range(func(key protocol.DeviceID, rec DatabaseRecord) bool {
+		if n%1000 == 0 {
+			runtime.Gosched()
+		}
+		n++
+
+		addresses := expire(rec.Addresses, now)
+		switch {
+		case len(addresses) > 0:
+			current++
+			seenIPv4, seenIPv6 := false, false
+			for _, addr := range addresses {
+				if strings.Contains(addr.Address, "[") {
+					seenIPv6 = true
 				} else {
-					databaseOperations.WithLabelValues(dbOpDelete, dbResSuccess).Inc()
+					seenIPv4 = true
+				}
+				if seenIPv4 && seenIPv6 {
+					break
 				}
-			default:
-				inactive++
 			}
+			if seenIPv4 {
+				currentIPv4++
+			}
+			if seenIPv6 {
+				currentIPv6++
+			}
+		case rec.Seen > cutoff24h:
+			last24h++
+		case rec.Seen > cutoff1w:
+			last1w++
+		default:
+			// drop the record if it's older than a week
+			s.m.Delete(key)
 		}
+		return true
+	})
+
+	databaseKeys.WithLabelValues("current").Set(float64(current))
+	databaseKeys.WithLabelValues("currentIPv4").Set(float64(currentIPv4))
+	databaseKeys.WithLabelValues("currentIPv6").Set(float64(currentIPv6))
+	databaseKeys.WithLabelValues("last24h").Set(float64(last24h))
+	databaseKeys.WithLabelValues("last1w").Set(float64(last1w))
+	databaseStatisticsSeconds.Set(time.Since(now).Seconds())
+}
 
-		iter.Release()
+func (s *inMemoryStore) write() (err error) {
+	t0 := time.Now()
+	defer func() {
+		if err == nil {
+			databaseWriteSeconds.Set(time.Since(t0).Seconds())
+			databaseLastWritten.Set(float64(t0.Unix()))
+		}
+	}()
+
+	dbf := path.Join(s.dir, "records.db")
+	fd, err := os.Create(dbf + ".tmp")
+	if err != nil {
+		return err
+	}
+	bw := bufio.NewWriter(fd)
+
+	var buf []byte
+	var rangeErr error
+	now := s.clock.Now()
+	cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano()
+	n := 0
+	s.m.Range(func(key protocol.DeviceID, value DatabaseRecord) bool {
+		if n%1000 == 0 {
+			runtime.Gosched()
+		}
+		n++
+
+		if value.Seen < cutoff1w {
+			// drop the record if it's older than a week
+			return true
+		}
+		rec := ReplicationRecord{
+			Key:       key[:],
+			Addresses: value.Addresses,
+			Seen:      value.Seen,
+		}
+		s := rec.Size()
+		if s+4 > len(buf) {
+			buf = make([]byte, s+4)
+		}
+		n, err := rec.MarshalTo(buf[4:])
+		if err != nil {
+			rangeErr = err
+			return false
+		}
+		binary.BigEndian.PutUint32(buf, uint32(n))
+		if _, err := bw.Write(buf[:n+4]); err != nil {
+			rangeErr = err
+			return false
+		}
+		return true
+	})
+	if rangeErr != nil {
+		_ = fd.Close()
+		return rangeErr
+	}
 
-		databaseKeys.WithLabelValues("current").Set(float64(current))
-		databaseKeys.WithLabelValues("currentIPv4").Set(float64(currentIPv4))
-		databaseKeys.WithLabelValues("currentIPv6").Set(float64(currentIPv6))
-		databaseKeys.WithLabelValues("last24h").Set(float64(last24h))
-		databaseKeys.WithLabelValues("last1w").Set(float64(last1w))
-		databaseKeys.WithLabelValues("inactive").Set(float64(inactive))
-		databaseKeys.WithLabelValues("error").Set(float64(errors))
-		databaseStatisticsSeconds.Set(time.Since(t0).Seconds())
+	if err := bw.Flush(); err != nil {
+		_ = fd.Close()
+		return err
+	}
+	if err := fd.Close(); err != nil {
+		return err
+	}
+	if err := os.Rename(dbf+".tmp", dbf); err != nil {
+		return err
+	}
 
-		// Signal that we are done and can be scheduled again.
-		done <- struct{}{}
+	// Upload to S3
+	if s.s3 != nil {
+		fd, err = os.Open(dbf)
+		if err != nil {
+			log.Printf("Error uploading database to S3: %v", err)
+			return nil
+		}
+		defer fd.Close()
+		if err := s.s3.upload(fd); err != nil {
+			log.Printf("Error uploading database to S3: %v", err)
+		}
+		log.Println("Finished uploading database")
 	}
+
+	return nil
+}
+
+func (s *inMemoryStore) read() (int, error) {
+	fd, err := os.Open(path.Join(s.dir, "records.db"))
+	if err != nil {
+		return 0, err
+	}
+	defer fd.Close()
+
+	br := bufio.NewReader(fd)
+	var buf []byte
+	nr := 0
+	for {
+		var n uint32
+		if err := binary.Read(br, binary.BigEndian, &n); err != nil {
+			if errors.Is(err, io.EOF) {
+				break
+			}
+			return nr, err
+		}
+		if int(n) > len(buf) {
+			buf = make([]byte, n)
+		}
+		if _, err := io.ReadFull(br, buf[:n]); err != nil {
+			return nr, err
+		}
+		rec := ReplicationRecord{}
+		if err := rec.Unmarshal(buf[:n]); err != nil {
+			return nr, err
+		}
+		key, err := protocol.DeviceIDFromBytes(rec.Key)
+		if err != nil {
+			key, err = protocol.DeviceIDFromString(string(rec.Key))
+		}
+		if err != nil {
+			log.Println("Bad device ID:", err)
+			continue
+		}
+
+		slices.SortFunc(rec.Addresses, DatabaseAddress.Cmp)
+		s.m.Store(key, DatabaseRecord{
+			Addresses: expire(rec.Addresses, s.clock.Now()),
+			Seen:      rec.Seen,
+		})
+		nr++
+	}
+	return nr, nil
 }
 
 // merge returns the merged result of the two database records a and b. The
@@ -305,18 +345,9 @@ func (s *levelDBStore) statisticsServe(trigger <-chan struct{}, done chan<- stru
 // chosen for any duplicates.
 func merge(a, b DatabaseRecord) DatabaseRecord {
 	// Both lists must be sorted for this to work.
-	if !sort.IsSorted(databaseAddressOrder(a.Addresses)) {
-		log.Println("Warning: bug: addresses not correctly sorted in merge")
-		a.Addresses = sortedAddressCopy(a.Addresses)
-	}
-	if !sort.IsSorted(databaseAddressOrder(b.Addresses)) {
-		// no warning because this is the side we read from disk and it may
-		// legitimately predate correct sorting.
-		b.Addresses = sortedAddressCopy(b.Addresses)
-	}
 
 	res := DatabaseRecord{
-		Addresses: make([]DatabaseAddress, 0, len(a.Addresses)+len(b.Addresses)),
+		Addresses: make([]DatabaseAddress, 0, max(len(a.Addresses), len(b.Addresses))),
 		Seen:      a.Seen,
 	}
 	if b.Seen > a.Seen {
@@ -378,36 +409,21 @@ loop:
 
 // expire returns the list of addresses after removing expired entries.
 // Expiration happens in place, so the slice given as the parameter is
-// destroyed. Internal order is not preserved.
-func expire(addrs []DatabaseAddress, now int64) []DatabaseAddress {
-	i := 0
-	for i < len(addrs) {
-		if addrs[i].Expires < now {
-			addrs = sliceutil.RemoveAndZero(addrs, i)
-			continue
+// destroyed. Internal order is preserved.
+func expire(addrs []DatabaseAddress, now time.Time) []DatabaseAddress {
+	cutoff := now.UnixNano()
+	naddrs := addrs[:0]
+	for i := range addrs {
+		if addrs[i].Expires >= cutoff {
+			naddrs = append(naddrs, addrs[i])
 		}
-		i++
 	}
-	return addrs
-}
-
-func sortedAddressCopy(addrs []DatabaseAddress) []DatabaseAddress {
-	sorted := make([]DatabaseAddress, len(addrs))
-	copy(sorted, addrs)
-	sort.Sort(databaseAddressOrder(sorted))
-	return sorted
+	return naddrs
 }
 
-type databaseAddressOrder []DatabaseAddress
-
-func (s databaseAddressOrder) Less(a, b int) bool {
-	return s[a].Address < s[b].Address
-}
-
-func (s databaseAddressOrder) Swap(a, b int) {
-	s[a], s[b] = s[b], s[a]
-}
-
-func (s databaseAddressOrder) Len() int {
-	return len(s)
+func (d DatabaseAddress) Cmp(other DatabaseAddress) (n int) {
+	if c := cmp.Compare(d.Address, other.Address); c != 0 {
+		return c
+	}
+	return cmp.Compare(d.Expires, other.Expires)
 }

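records.db, as written and read above, is a flat stream of length-prefixed protobuf messages: a 4-byte big-endian length followed by that many bytes of a marshalled ReplicationRecord, repeated until EOF. A standalone reader sketch for that framing (readAllRecords is a hypothetical helper; ReplicationRecord is the generated type from database.pb.go):

	// readAllRecords decodes a stream of [4-byte big-endian length |
	// ReplicationRecord bytes] frames until EOF.
	func readAllRecords(r io.Reader) ([]ReplicationRecord, error) {
		br := bufio.NewReader(r)
		var recs []ReplicationRecord
		for {
			var n uint32
			if err := binary.Read(br, binary.BigEndian, &n); err != nil {
				if errors.Is(err, io.EOF) {
					return recs, nil // clean end of stream
				}
				return recs, err
			}
			buf := make([]byte, n)
			if _, err := io.ReadFull(br, buf); err != nil {
				return recs, err
			}
			var rec ReplicationRecord
			if err := rec.Unmarshal(buf); err != nil {
				return recs, err
			}
			recs = append(recs, rec)
		}
	}
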
+ 26 - 81
cmd/stdiscosrv/database.pb.go

@@ -25,9 +25,7 @@ const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
 
 type DatabaseRecord struct {
 	Addresses []DatabaseAddress `protobuf:"bytes,1,rep,name=addresses,proto3" json:"addresses"`
-	Misses    int32             `protobuf:"varint,2,opt,name=misses,proto3" json:"misses,omitempty"`
 	Seen      int64             `protobuf:"varint,3,opt,name=seen,proto3" json:"seen,omitempty"`
-	Missed    int64             `protobuf:"varint,4,opt,name=missed,proto3" json:"missed,omitempty"`
 }
 
 func (m *DatabaseRecord) Reset()         { *m = DatabaseRecord{} }
@@ -64,7 +62,7 @@ func (m *DatabaseRecord) XXX_DiscardUnknown() {
 var xxx_messageInfo_DatabaseRecord proto.InternalMessageInfo
 
 type ReplicationRecord struct {
-	Key       string            `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
+	Key       []byte            `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
 	Addresses []DatabaseAddress `protobuf:"bytes,2,rep,name=addresses,proto3" json:"addresses"`
 	Seen      int64             `protobuf:"varint,3,opt,name=seen,proto3" json:"seen,omitempty"`
 }
@@ -149,24 +147,23 @@ func init() {
 func init() { proto.RegisterFile("database.proto", fileDescriptor_b90fe3356ea5df07) }
 
 var fileDescriptor_b90fe3356ea5df07 = []byte{
-	// 270 bytes of a gzipped FileDescriptorProto
-	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x90, 0x41, 0x4a, 0xc4, 0x30,
-	0x18, 0x85, 0x9b, 0x49, 0x1d, 0x99, 0x08, 0xa3, 0x06, 0x94, 0x20, 0x12, 0x4b, 0xdd, 0x74, 0xd5,
-	0x01, 0x5d, 0xb9, 0x74, 0xd0, 0x0b, 0xe4, 0x06, 0xe9, 0xe4, 0x77, 0x08, 0x3a, 0x4d, 0x49, 0x2a,
-	0xe8, 0x29, 0xf4, 0x58, 0x5d, 0xce, 0xd2, 0x95, 0x68, 0x7b, 0x11, 0x69, 0x26, 0x55, 0x14, 0x37,
-	0xb3, 0x7b, 0xdf, 0xff, 0xbf, 0x97, 0xbc, 0x84, 0x4c, 0x95, 0xac, 0x65, 0x21, 0x1d, 0xe4, 0x95,
-	0x35, 0xb5, 0xa1, 0xf1, 0x4a, 0xea, 0xf2, 0xe4, 0xdc, 0x42, 0x65, 0xdc, 0xcc, 0x8f, 0x8a, 0xc7,
-	0xbb, 0xd9, 0xd2, 0x2c, 0x8d, 0x07, 0xaf, 0x36, 0xd6, 0xf4, 0x05, 0x91, 0xe9, 0x4d, 0x48, 0x0b,
-	0x58, 0x18, 0xab, 0xe8, 0x15, 0x99, 0x48, 0xa5, 0x2c, 0x38, 0x07, 0x8e, 0xa1, 0x04, 0x67, 0x7b,
-	0x17, 0x47, 0x79, 0x7f, 0x62, 0x3e, 0x18, 0xaf, 0x37, 0xeb, 0x79, 0xdc, 0xbc, 0x9f, 0x45, 0xe2,
-	0xc7, 0x4d, 0x8f, 0xc9, 0x78, 0xa5, 0x7d, 0x6e, 0x94, 0xa0, 0x6c, 0x47, 0x04, 0xa2, 0x94, 0xc4,
-	0x0e, 0xa0, 0x64, 0x38, 0x41, 0x19, 0x16, 0x5e, 0x7f, 0x7b, 0x15, 0x8b, 0xfd, 0x34, 0x50, 0x5a,
-	0x93, 0x43, 0x01, 0xd5, 0x83, 0x5e, 0xc8, 0x5a, 0x9b, 0x32, 0x74, 0x3a, 0x20, 0xf8, 0x1e, 0x9e,
-	0x19, 0x4a, 0x50, 0x36, 0x11, 0xbd, 0xfc, 0xdd, 0x72, 0xb4, 0x55, 0xcb, 0x7f, 0xda, 0xa4, 0xb7,
-	0x64, 0xff, 0x4f, 0x8e, 0x32, 0xb2, 0x1b, 0x32, 0xe1, 0xde, 0x01, 0xfb, 0x0d, 0x3c, 0x55, 0xda,
-	0x86, 0x77, 0x62, 0x31, 0xe0, 0xfc, 0xb4, 0xf9, 0xe4, 0x51, 0xd3, 0x72, 0xb4, 0x6e, 0x39, 0xfa,
-	0x68, 0x39, 0x7a, 0xed, 0x78, 0xb4, 0xee, 0x78, 0xf4, 0xd6, 0xf1, 0xa8, 0x18, 0xfb, 0x3f, 0xbf,
-	0xfc, 0x0a, 0x00, 0x00, 0xff, 0xff, 0x7a, 0xa2, 0xf6, 0x1e, 0xb0, 0x01, 0x00, 0x00,
+	// 243 bytes of a gzipped FileDescriptorProto
+	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4b, 0x49, 0x2c, 0x49,
+	0x4c, 0x4a, 0x2c, 0x4e, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0xc9, 0x4d, 0xcc, 0xcc,
+	0x93, 0x52, 0x2e, 0x4a, 0x2d, 0xc8, 0x2f, 0xd6, 0x07, 0x0b, 0x25, 0x95, 0xa6, 0xe9, 0xa7, 0xe7,
+	0xa7, 0xe7, 0x83, 0x39, 0x60, 0x16, 0x44, 0xa9, 0x52, 0x3c, 0x17, 0x9f, 0x0b, 0x54, 0x73, 0x50,
+	0x6a, 0x72, 0x7e, 0x51, 0x8a, 0x90, 0x25, 0x17, 0x67, 0x62, 0x4a, 0x4a, 0x51, 0x6a, 0x71, 0x71,
+	0x6a, 0xb1, 0x04, 0xa3, 0x02, 0xb3, 0x06, 0xb7, 0x91, 0xa8, 0x1e, 0xc8, 0x40, 0x3d, 0x98, 0x42,
+	0x47, 0x88, 0xb4, 0x13, 0xcb, 0x89, 0x7b, 0xf2, 0x0c, 0x41, 0x08, 0xd5, 0x42, 0x42, 0x5c, 0x2c,
+	0xc5, 0xa9, 0xa9, 0x79, 0x12, 0xcc, 0x0a, 0x8c, 0x1a, 0xcc, 0x41, 0x60, 0xb6, 0x52, 0x09, 0x97,
+	0x60, 0x50, 0x6a, 0x41, 0x4e, 0x66, 0x72, 0x62, 0x49, 0x66, 0x7e, 0x1e, 0xd4, 0x0e, 0x01, 0x2e,
+	0xe6, 0xec, 0xd4, 0x4a, 0x09, 0x46, 0x05, 0x46, 0x0d, 0x9e, 0x20, 0x10, 0x13, 0xd5, 0x56, 0x26,
+	0x8a, 0x6d, 0x75, 0xe5, 0xe2, 0x47, 0xd3, 0x27, 0x24, 0xc1, 0xc5, 0x0e, 0xd5, 0x03, 0xb6, 0x97,
+	0x33, 0x08, 0xc6, 0x05, 0xc9, 0xa4, 0x56, 0x14, 0x64, 0x16, 0x81, 0x6d, 0x06, 0x99, 0x01, 0xe3,
+	0x3a, 0xc9, 0x9c, 0x78, 0x28, 0xc7, 0x70, 0xe2, 0x91, 0x1c, 0xe3, 0x85, 0x47, 0x72, 0x8c, 0x0f,
+	0x1e, 0xc9, 0x31, 0x4e, 0x78, 0x2c, 0xc7, 0x70, 0xe1, 0xb1, 0x1c, 0xc3, 0x8d, 0xc7, 0x72, 0x0c,
+	0x49, 0x6c, 0xe0, 0x20, 0x34, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0xc6, 0x0b, 0x9b, 0x77, 0x7f,
+	0x01, 0x00, 0x00,
 }
 
 func (m *DatabaseRecord) Marshal() (dAtA []byte, err error) {
@@ -189,21 +186,11 @@ func (m *DatabaseRecord) MarshalToSizedBuffer(dAtA []byte) (int, error) {
 	_ = i
 	var l int
 	_ = l
-	if m.Missed != 0 {
-		i = encodeVarintDatabase(dAtA, i, uint64(m.Missed))
-		i--
-		dAtA[i] = 0x20
-	}
 	if m.Seen != 0 {
 		i = encodeVarintDatabase(dAtA, i, uint64(m.Seen))
 		i--
 		dAtA[i] = 0x18
 	}
-	if m.Misses != 0 {
-		i = encodeVarintDatabase(dAtA, i, uint64(m.Misses))
-		i--
-		dAtA[i] = 0x10
-	}
 	if len(m.Addresses) > 0 {
 		for iNdEx := len(m.Addresses) - 1; iNdEx >= 0; iNdEx-- {
 			{
@@ -328,15 +315,9 @@ func (m *DatabaseRecord) Size() (n int) {
 			n += 1 + l + sovDatabase(uint64(l))
 		}
 	}
-	if m.Misses != 0 {
-		n += 1 + sovDatabase(uint64(m.Misses))
-	}
 	if m.Seen != 0 {
 		n += 1 + sovDatabase(uint64(m.Seen))
 	}
-	if m.Missed != 0 {
-		n += 1 + sovDatabase(uint64(m.Missed))
-	}
 	return n
 }
 
@@ -447,25 +428,6 @@ func (m *DatabaseRecord) Unmarshal(dAtA []byte) error {
 				return err
 			}
 			iNdEx = postIndex
-		case 2:
-			if wireType != 0 {
-				return fmt.Errorf("proto: wrong wireType = %d for field Misses", wireType)
-			}
-			m.Misses = 0
-			for shift := uint(0); ; shift += 7 {
-				if shift >= 64 {
-					return ErrIntOverflowDatabase
-				}
-				if iNdEx >= l {
-					return io.ErrUnexpectedEOF
-				}
-				b := dAtA[iNdEx]
-				iNdEx++
-				m.Misses |= int32(b&0x7F) << shift
-				if b < 0x80 {
-					break
-				}
-			}
 		case 3:
 			if wireType != 0 {
 				return fmt.Errorf("proto: wrong wireType = %d for field Seen", wireType)
@@ -485,25 +447,6 @@ func (m *DatabaseRecord) Unmarshal(dAtA []byte) error {
 					break
 				}
 			}
-		case 4:
-			if wireType != 0 {
-				return fmt.Errorf("proto: wrong wireType = %d for field Missed", wireType)
-			}
-			m.Missed = 0
-			for shift := uint(0); ; shift += 7 {
-				if shift >= 64 {
-					return ErrIntOverflowDatabase
-				}
-				if iNdEx >= l {
-					return io.ErrUnexpectedEOF
-				}
-				b := dAtA[iNdEx]
-				iNdEx++
-				m.Missed |= int64(b&0x7F) << shift
-				if b < 0x80 {
-					break
-				}
-			}
 		default:
 			iNdEx = preIndex
 			skippy, err := skipDatabase(dAtA[iNdEx:])
@@ -558,7 +501,7 @@ func (m *ReplicationRecord) Unmarshal(dAtA []byte) error {
 			if wireType != 2 {
 				return fmt.Errorf("proto: wrong wireType = %d for field Key", wireType)
 			}
-			var stringLen uint64
+			var byteLen int
 			for shift := uint(0); ; shift += 7 {
 				if shift >= 64 {
 					return ErrIntOverflowDatabase
@@ -568,23 +511,25 @@ func (m *ReplicationRecord) Unmarshal(dAtA []byte) error {
 				}
 				b := dAtA[iNdEx]
 				iNdEx++
-				stringLen |= uint64(b&0x7F) << shift
+				byteLen |= int(b&0x7F) << shift
 				if b < 0x80 {
 					break
 				}
 			}
-			intStringLen := int(stringLen)
-			if intStringLen < 0 {
+			if byteLen < 0 {
 				return ErrInvalidLengthDatabase
 			}
-			postIndex := iNdEx + intStringLen
+			postIndex := iNdEx + byteLen
 			if postIndex < 0 {
 				return ErrInvalidLengthDatabase
 			}
 			if postIndex > l {
 				return io.ErrUnexpectedEOF
 			}
-			m.Key = string(dAtA[iNdEx:postIndex])
+			m.Key = append(m.Key[:0], dAtA[iNdEx:postIndex]...)
+			if m.Key == nil {
+				m.Key = []byte{}
+			}
 			iNdEx = postIndex
 		case 2:
 			if wireType != 2 {

+ 1 - 5
cmd/stdiscosrv/database.proto

@@ -17,15 +17,11 @@ option (gogoproto.goproto_sizecache_all) = false;
 
 message DatabaseRecord {
     repeated DatabaseAddress addresses = 1 [(gogoproto.nullable) = false];
-    int32                    misses    = 2; // Number of lookups* without hits
     int64                    seen      = 3; // Unix nanos, last device announce
-    int64                    missed    = 4; // Unix nanos, last* failed lookup
 }
 
-// *) Not every lookup results in a write, so may not be completely accurate
-
 message ReplicationRecord {
-    string                   key       = 1;
+    bytes                    key       = 1; // raw 32 byte device ID
     repeated DatabaseAddress addresses = 2 [(gogoproto.nullable) = false];
     int64                    seen      = 3; // Unix nanos, last device announce
 }

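Changing key from string to bytes is compatible on the wire: both encodings use the same length-delimited wire type for field 1, so records serialized before this change still unmarshal. The raw form is exactly 32 bytes while the textual device ID is longer, which is what lets the fallbacks in amqp.go and database.go distinguish the two formats.
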
+ 12 - 42
cmd/stdiscosrv/database_test.go

@@ -11,29 +11,25 @@ import (
 	"fmt"
 	"testing"
 	"time"
+
+	"github.com/syncthing/syncthing/lib/protocol"
 )
 
 func TestDatabaseGetSet(t *testing.T) {
-	db, err := newMemoryLevelDBStore()
-	if err != nil {
-		t.Fatal(err)
-	}
+	db := newInMemoryStore(t.TempDir(), 0, nil)
 	ctx, cancel := context.WithCancel(context.Background())
 	go db.Serve(ctx)
 	defer cancel()
 
 	// Check missing record
 
-	rec, err := db.get("abcd")
+	rec, err := db.get(&protocol.EmptyDeviceID)
 	if err != nil {
 		t.Error("not found should not be an error")
 	}
 	if len(rec.Addresses) != 0 {
 		t.Error("addresses should be empty")
 	}
-	if rec.Misses != 0 {
-		t.Error("missing should be zero")
-	}
 
 	// Set up a clock
 
@@ -46,13 +42,13 @@ func TestDatabaseGetSet(t *testing.T) {
 	rec.Addresses = []DatabaseAddress{
 		{Address: "tcp://1.2.3.4:5", Expires: tc.Now().Add(time.Minute).UnixNano()},
 	}
-	if err := db.put("abcd", rec); err != nil {
+	if err := db.put(&protocol.EmptyDeviceID, rec); err != nil {
 		t.Fatal(err)
 	}
 
 	// Verify it
 
-	rec, err = db.get("abcd")
+	rec, err = db.get(&protocol.EmptyDeviceID)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -72,13 +68,13 @@ func TestDatabaseGetSet(t *testing.T) {
 	addrs := []DatabaseAddress{
 		{Address: "tcp://6.7.8.9:0", Expires: tc.Now().Add(time.Minute).UnixNano()},
 	}
-	if err := db.merge("abcd", addrs, tc.Now().UnixNano()); err != nil {
+	if err := db.merge(&protocol.EmptyDeviceID, addrs, tc.Now().UnixNano()); err != nil {
 		t.Fatal(err)
 	}
 
 	// Verify it
 
-	rec, err = db.get("abcd")
+	rec, err = db.get(&protocol.EmptyDeviceID)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -101,7 +97,7 @@ func TestDatabaseGetSet(t *testing.T) {
 
 	// Verify it
 
-	rec, err = db.get("abcd")
+	rec, err = db.get(&protocol.EmptyDeviceID)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -114,40 +110,18 @@ func TestDatabaseGetSet(t *testing.T) {
 		t.Error("incorrect address")
 	}
 
-	// Put a record with misses
-
-	rec = DatabaseRecord{Misses: 42, Missed: tc.Now().UnixNano()}
-	if err := db.put("efgh", rec); err != nil {
-		t.Fatal(err)
-	}
-
-	// Verify it
-
-	rec, err = db.get("efgh")
-	if err != nil {
-		t.Fatal(err)
-	}
-	if len(rec.Addresses) != 0 {
-		t.Log(rec.Addresses)
-		t.Fatal("should have no addresses")
-	}
-	if rec.Misses != 42 {
-		t.Log(rec.Misses)
-		t.Error("incorrect misses")
-	}
-
 	// Set an address
 
 	addrs = []DatabaseAddress{
 		{Address: "tcp://6.7.8.9:0", Expires: tc.Now().Add(time.Minute).UnixNano()},
 	}
-	if err := db.merge("efgh", addrs, tc.Now().UnixNano()); err != nil {
+	if err := db.merge(&protocol.GlobalDeviceID, addrs, tc.Now().UnixNano()); err != nil {
 		t.Fatal(err)
 	}
 
 	// Verify it
 
-	rec, err = db.get("efgh")
+	rec, err = db.get(&protocol.GlobalDeviceID)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -155,10 +129,6 @@ func TestDatabaseGetSet(t *testing.T) {
 		t.Log(rec.Addresses)
 		t.Fatal("should have one address")
 	}
-	if rec.Misses != 0 {
-		t.Log(rec.Misses)
-		t.Error("should have no misses")
-	}
 }
 
 func TestFilter(t *testing.T) {
@@ -190,7 +160,7 @@ func TestFilter(t *testing.T) {
 	}
 
 	for _, tc := range cases {
-		res := expire(tc.a, 10)
+		res := expire(tc.a, time.Unix(0, 10))
 		if fmt.Sprint(res) != fmt.Sprint(tc.b) {
 			t.Errorf("Incorrect result %v, expected %v", res, tc.b)
 		}

+ 84 - 156
cmd/stdiscosrv/main.go

@@ -9,22 +9,22 @@ package main
 import (
 	"context"
 	"crypto/tls"
-	"flag"
 	"log"
-	"net"
 	"net/http"
 	"os"
+	"os/signal"
 	"runtime"
-	"strings"
 	"time"
 
+	_ "net/http/pprof"
+
+	"github.com/alecthomas/kong"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	_ "github.com/syncthing/syncthing/lib/automaxprocs"
 	"github.com/syncthing/syncthing/lib/build"
 	"github.com/syncthing/syncthing/lib/protocol"
 	"github.com/syncthing/syncthing/lib/rand"
 	"github.com/syncthing/syncthing/lib/tlsutil"
-	"github.com/syndtr/goleveldb/leveldb/opt"
 	"github.com/thejerf/suture/v4"
 )
 
@@ -39,17 +39,12 @@ const (
 	errorRetryAfterSeconds = 1500
 	errorRetryFuzzSeconds  = 300
 
-	// Retry for not found is minSeconds + failures * incSeconds +
-	// random(fuzz), where failures is the number of consecutive lookups
-	// with no answer, up to maxSeconds. The fuzz is applied after capping
-	// to maxSeconds.
-	notFoundRetryMinSeconds  = 60
-	notFoundRetryMaxSeconds  = 3540
-	notFoundRetryIncSeconds  = 10
-	notFoundRetryFuzzSeconds = 60
-
-	// How often (in requests) we serialize the missed counter to database.
-	notFoundMissesWriteInterval = 10
+	// Retry-After for lookups without a result is adjusted dynamically
+	// between these bounds by a retryAfterTracker, separately for
+	// devices we have seen an announcement for (but that are not active
+	// right now) and for devices never seen, or not seen in the last week.
+	notFoundRetryUnknownMinSeconds = 60
+	notFoundRetryUnknownMaxSeconds = 3600
 
 	httpReadTimeout    = 5 * time.Second
 	httpWriteTimeout   = 5 * time.Second
@@ -59,184 +54,117 @@ const (
 	replicationOutboxSize = 10000
 )
 
-// These options make the database a little more optimized for writes, at
-// the expense of some memory usage and risk of losing writes in a (system)
-// crash.
-var levelDBOptions = &opt.Options{
-	NoSync:      true,
-	WriteBuffer: 32 << 20, // default 4<<20
-}
-
 var debug = false
 
-func main() {
-	var listen string
-	var dir string
-	var metricsListen string
-	var replicationListen string
-	var replicationPeers string
-	var certFile string
-	var keyFile string
-	var replCertFile string
-	var replKeyFile string
-	var useHTTP bool
-	var largeDB bool
-	var amqpAddress string
-	missesIncrease := 1
+type CLI struct {
+	Cert          string `group:"Listen" help:"Certificate file" default:"./cert.pem" env:"DISCOVERY_CERT_FILE"`
+	Key           string `group:"Listen" help:"Key file" default:"./key.pem" env:"DISCOVERY_KEY_FILE"`
+	HTTP          bool   `group:"Listen" help:"Listen on HTTP (behind an HTTPS proxy)" env:"DISCOVERY_HTTP"`
+	Compression   bool   `group:"Listen" help:"Enable GZIP compression of responses" env:"DISCOVERY_COMPRESSION"`
+	Listen        string `group:"Listen" help:"Listen address" default:":8443" env:"DISCOVERY_LISTEN"`
+	MetricsListen string `group:"Listen" help:"Metrics listen address" env:"DISCOVERY_METRICS_LISTEN"`
 
-	log.SetOutput(os.Stdout)
-	log.SetFlags(0)
-
-	flag.StringVar(&certFile, "cert", "./cert.pem", "Certificate file")
-	flag.StringVar(&keyFile, "key", "./key.pem", "Key file")
-	flag.StringVar(&dir, "db-dir", "./discovery.db", "Database directory")
-	flag.BoolVar(&debug, "debug", false, "Print debug output")
-	flag.BoolVar(&useHTTP, "http", false, "Listen on HTTP (behind an HTTPS proxy)")
-	flag.StringVar(&listen, "listen", ":8443", "Listen address")
-	flag.StringVar(&metricsListen, "metrics-listen", "", "Metrics listen address")
-	flag.StringVar(&replicationPeers, "replicate", "", "Replication peers, id@address, comma separated")
-	flag.StringVar(&replicationListen, "replication-listen", ":19200", "Replication listen address")
-	flag.StringVar(&replCertFile, "replication-cert", "", "Certificate file for replication")
-	flag.StringVar(&replKeyFile, "replication-key", "", "Key file for replication")
-	flag.BoolVar(&largeDB, "large-db", false, "Use larger database settings")
-	flag.StringVar(&amqpAddress, "amqp-address", "", "Address to AMQP broker")
-	flag.IntVar(&missesIncrease, "misses-increase", 1, "How many times to increase the misses counter on each miss")
-	showVersion := flag.Bool("version", false, "Show version")
-	flag.Parse()
+	DBDir           string        `group:"Database" help:"Database directory" default:"." env:"DISCOVERY_DB_DIR"`
+	DBFlushInterval time.Duration `group:"Database" help:"Interval between database flushes" default:"5m" env:"DISCOVERY_DB_FLUSH_INTERVAL"`
 
-	log.Println(build.LongVersionFor("stdiscosrv"))
-	if *showVersion {
-		return
-	}
+	DBS3Endpoint    string `name:"db-s3-endpoint" group:"Database (S3 backup)" hidden:"true" help:"S3 endpoint for database" env:"DISCOVERY_DB_S3_ENDPOINT"`
+	DBS3Region      string `name:"db-s3-region" group:"Database (S3 backup)" hidden:"true" help:"S3 region for database" env:"DISCOVERY_DB_S3_REGION"`
+	DBS3Bucket      string `name:"db-s3-bucket" group:"Database (S3 backup)" hidden:"true" help:"S3 bucket for database" env:"DISCOVERY_DB_S3_BUCKET"`
+	DBS3AccessKeyID string `name:"db-s3-access-key-id" group:"Database (S3 backup)" hidden:"true" help:"S3 access key ID for database" env:"DISCOVERY_DB_S3_ACCESS_KEY_ID"`
+	DBS3SecretKey   string `name:"db-s3-secret-key" group:"Database (S3 backup)" hidden:"true" help:"S3 secret key for database" env:"DISCOVERY_DB_S3_SECRET_KEY"`
 
-	buildInfo.WithLabelValues(build.Version, runtime.Version(), build.User, build.Date.UTC().Format("2006-01-02T15:04:05Z")).Set(1)
+	AMQPAddress string `group:"AMQP replication" hidden:"true" help:"Address to AMQP broker" env:"DISCOVERY_AMQP_ADDRESS"`
 
-	if largeDB {
-		levelDBOptions.BlockCacheCapacity = 64 << 20
-		levelDBOptions.BlockSize = 64 << 10
-		levelDBOptions.CompactionTableSize = 16 << 20
-		levelDBOptions.CompactionTableSizeMultiplier = 2.0
-		levelDBOptions.WriteBuffer = 64 << 20
-		levelDBOptions.CompactionL0Trigger = 8
-	}
+	Debug   bool `short:"d" help:"Print debug output" env:"DISCOVERY_DEBUG"`
+	Version bool `short:"v" help:"Print version and exit"`
+}
 
-	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
-	if os.IsNotExist(err) {
-		log.Println("Failed to load keypair. Generating one, this might take a while...")
-		cert, err = tlsutil.NewCertificate(certFile, keyFile, "stdiscosrv", 20*365)
-		if err != nil {
-			log.Fatalln("Failed to generate X509 key pair:", err)
-		}
-	} else if err != nil {
-		log.Fatalln("Failed to load keypair:", err)
-	}
-	devID := protocol.NewDeviceID(cert.Certificate[0])
-	log.Println("Server device ID is", devID)
+func main() {
+	log.SetOutput(os.Stdout)
 
-	replCert := cert
-	if replCertFile != "" && replKeyFile != "" {
-		replCert, err = tls.LoadX509KeyPair(replCertFile, replKeyFile)
-		if err != nil {
-			log.Fatalln("Failed to load replication keypair:", err)
-		}
+	var cli CLI
+	kong.Parse(&cli)
+	debug = cli.Debug
+
+	log.Println(build.LongVersionFor("stdiscosrv"))
+	if cli.Version {
+		return
 	}
-	replDevID := protocol.NewDeviceID(replCert.Certificate[0])
-	log.Println("Replication device ID is", replDevID)
-
-	// Parse the replication specs, if any.
-	var allowedReplicationPeers []protocol.DeviceID
-	var replicationDestinations []string
-	parts := strings.Split(replicationPeers, ",")
-	for _, part := range parts {
-		if part == "" {
-			continue
-		}
 
-		fields := strings.Split(part, "@")
-		switch len(fields) {
-		case 2:
-			// This is an id@address specification. Grab the address for the
-			// destination list. Try to resolve it once to catch obvious
-			// syntax errors here rather than having the sender service fail
-			// repeatedly later.
-			_, err := net.ResolveTCPAddr("tcp", fields[1])
-			if err != nil {
-				log.Fatalln("Resolving address:", err)
-			}
-			replicationDestinations = append(replicationDestinations, fields[1])
-			fallthrough // N.B.
+	buildInfo.WithLabelValues(build.Version, runtime.Version(), build.User, build.Date.UTC().Format("2006-01-02T15:04:05Z")).Set(1)
 
-		case 1:
-			// The first part is always a device ID.
-			id, err := protocol.DeviceIDFromString(fields[0])
+	var cert tls.Certificate
+	if !cli.HTTP {
+		var err error
+		cert, err = tls.LoadX509KeyPair(cli.Cert, cli.Key)
+		if os.IsNotExist(err) {
+			log.Println("Failed to load keypair. Generating one, this might take a while...")
+			cert, err = tlsutil.NewCertificate(cli.Cert, cli.Key, "stdiscosrv", 20*365)
 			if err != nil {
-				log.Fatalln("Parsing device ID:", err)
+				log.Fatalln("Failed to generate X509 key pair:", err)
 			}
-			if id == protocol.EmptyDeviceID {
-				log.Fatalf("Missing device ID for peer in %q", part)
-			}
-			allowedReplicationPeers = append(allowedReplicationPeers, id)
-
-		default:
-			log.Fatalln("Unrecognized replication spec:", part)
+		} else if err != nil {
+			log.Fatalln("Failed to load keypair:", err)
 		}
+		devID := protocol.NewDeviceID(cert.Certificate[0])
+		log.Println("Server device ID is", devID)
 	}
 
 	// Root of the service tree.
 	main := suture.New("main", suture.Spec{
 		PassThroughPanics: true,
+		Timeout:           2 * time.Minute,
 	})
 
-	// Start the database.
-	db, err := newLevelDBStore(dir)
-	if err != nil {
-		log.Fatalln("Open database:", err)
-	}
-	main.Add(db)
-
-	// Start any replication senders.
-	var repl replicationMultiplexer
-	for _, dst := range replicationDestinations {
-		rs := newReplicationSender(dst, replCert, allowedReplicationPeers)
-		main.Add(rs)
-		repl = append(repl, rs)
+	// If configured, use S3 for database backups.
+	var s3c *s3Copier
+	if cli.DBS3Endpoint != "" {
+		hostname, err := os.Hostname()
+		if err != nil {
+			log.Fatalf("Failed to get hostname: %v", err)
+		}
+		key := hostname + ".db"
+		s3c = newS3Copier(cli.DBS3Endpoint, cli.DBS3Region, cli.DBS3Bucket, key, cli.DBS3AccessKeyID, cli.DBS3SecretKey)
 	}
 
-	// If we have replication configured, start the replication listener.
-	if len(allowedReplicationPeers) > 0 {
-		rl := newReplicationListener(replicationListen, replCert, allowedReplicationPeers, db)
-		main.Add(rl)
-	}
+	// Start the database.
+	db := newInMemoryStore(cli.DBDir, cli.DBFlushInterval, s3c)
+	main.Add(db)
 
-	// If we have an AMQP broker, start that
-	if amqpAddress != "" {
+	// If we have an AMQP broker for replication, start that
+	var repl replicator
+	if cli.AMQPAddress != "" {
 		clientID := rand.String(10)
-		kr := newAMQPReplicator(amqpAddress, clientID, db)
-		repl = append(repl, kr)
+		kr := newAMQPReplicator(cli.AMQPAddress, clientID, db)
 		main.Add(kr)
+		repl = kr
 	}
 
-	go func() {
-		for range time.NewTicker(time.Second).C {
-			for _, r := range repl {
-				r.send("<heartbeat>", nil, time.Now().UnixNano())
-			}
-		}
-	}()
-
 	// Start the main API server.
-	qs := newAPISrv(listen, cert, db, repl, useHTTP, missesIncrease)
+	qs := newAPISrv(cli.Listen, cert, db, repl, cli.HTTP, cli.Compression)
 	main.Add(qs)
 
 	// If we have a metrics port configured, start a metrics handler.
-	if metricsListen != "" {
+	if cli.MetricsListen != "" {
 		go func() {
 			mux := http.NewServeMux()
 			mux.Handle("/metrics", promhttp.Handler())
-			log.Fatal(http.ListenAndServe(metricsListen, mux))
+			log.Fatal(http.ListenAndServe(cli.MetricsListen, mux))
 		}()
 	}
 
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	// Cancel on signal
+	signalChan := make(chan os.Signal, 1)
+	signal.Notify(signalChan, os.Interrupt)
+	go func() {
+		sig := <-signalChan
+		log.Printf("Received signal %s; shutting down", sig)
+		cancel()
+	}()
+
 	// Engage!
-	main.Serve(context.Background())
+	main.Serve(ctx)
 }

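Since kong binds each field's env tag, every option in the CLI struct above can be supplied as a DISCOVERY_* environment variable instead of a flag (DISCOVERY_LISTEN for --listen, DISCOVERY_DB_DIR for --db-dir, and so on), and the S3 and AMQP options are kept out of --help by hidden:"true" while remaining fully functional.
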
+ 0 - 325
cmd/stdiscosrv/replication.go

@@ -1,325 +0,0 @@
-// Copyright (C) 2018 The Syncthing Authors.
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this file,
-// You can obtain one at https://mozilla.org/MPL/2.0/.
-
-package main
-
-import (
-	"context"
-	"crypto/tls"
-	"encoding/binary"
-	"fmt"
-	io "io"
-	"log"
-	"net"
-	"time"
-
-	"github.com/syncthing/syncthing/lib/protocol"
-)
-
-const (
-	replicationReadTimeout       = time.Minute
-	replicationWriteTimeout      = 30 * time.Second
-	replicationHeartbeatInterval = time.Second * 30
-)
-
-type replicator interface {
-	send(key string, addrs []DatabaseAddress, seen int64)
-}
-
-// a replicationSender tries to connect to the remote address and provide
-// them with a feed of replication updates.
-type replicationSender struct {
-	dst        string
-	cert       tls.Certificate // our certificate
-	allowedIDs []protocol.DeviceID
-	outbox     chan ReplicationRecord
-}
-
-func newReplicationSender(dst string, cert tls.Certificate, allowedIDs []protocol.DeviceID) *replicationSender {
-	return &replicationSender{
-		dst:        dst,
-		cert:       cert,
-		allowedIDs: allowedIDs,
-		outbox:     make(chan ReplicationRecord, replicationOutboxSize),
-	}
-}
-
-func (s *replicationSender) Serve(ctx context.Context) error {
-	// Sleep a little at startup. Peers often restart at the same time, and
-	// this avoid the service failing and entering backoff state
-	// unnecessarily, while also reducing the reconnect rate to something
-	// reasonable by default.
-	time.Sleep(2 * time.Second)
-
-	tlsCfg := &tls.Config{
-		Certificates:       []tls.Certificate{s.cert},
-		MinVersion:         tls.VersionTLS12,
-		InsecureSkipVerify: true,
-	}
-
-	// Dial the TLS connection.
-	conn, err := tls.Dial("tcp", s.dst, tlsCfg)
-	if err != nil {
-		log.Println("Replication connect:", err)
-		return err
-	}
-	defer func() {
-		conn.SetWriteDeadline(time.Now().Add(time.Second))
-		conn.Close()
-	}()
-
-	// The replication stream is not especially latency sensitive, but it is
-	// quite a lot of data in small writes. Make it more efficient.
-	if tcpc, ok := conn.NetConn().(*net.TCPConn); ok {
-		_ = tcpc.SetNoDelay(false)
-	}
-
-	// Get the other side device ID.
-	remoteID, err := deviceID(conn)
-	if err != nil {
-		log.Println("Replication connect:", err)
-		return err
-	}
-
-	// Verify it's in the set of allowed device IDs.
-	if !deviceIDIn(remoteID, s.allowedIDs) {
-		log.Println("Replication connect: unexpected device ID:", remoteID)
-		return err
-	}
-
-	heartBeatTicker := time.NewTicker(replicationHeartbeatInterval)
-	defer heartBeatTicker.Stop()
-
-	// Send records.
-	buf := make([]byte, 1024)
-	for {
-		select {
-		case <-heartBeatTicker.C:
-			if len(s.outbox) > 0 {
-				// No need to send heartbeats if there are events/prevrious
-				// heartbeats to send, they will keep the connection alive.
-				continue
-			}
-			// Empty replication message is the heartbeat:
-			s.outbox <- ReplicationRecord{}
-
-		case rec := <-s.outbox:
-			// Buffer must hold record plus four bytes for size
-			size := rec.Size()
-			if len(buf) < size+4 {
-				buf = make([]byte, size+4)
-			}
-
-			// Record comes after the four-byte size
-			n, err := rec.MarshalTo(buf[4:])
-			if err != nil {
-				// odd to get an error here, but we haven't sent anything
-				// yet so it's not fatal
-				replicationSendsTotal.WithLabelValues("error").Inc()
-				log.Println("Replication marshal:", err)
-				continue
-			}
-			binary.BigEndian.PutUint32(buf, uint32(n))
-
-			// Send
-			conn.SetWriteDeadline(time.Now().Add(replicationWriteTimeout))
-			if _, err := conn.Write(buf[:4+n]); err != nil {
-				replicationSendsTotal.WithLabelValues("error").Inc()
-				log.Println("Replication write:", err)
-				// Yes, we are losing the replication event here.
-				return err
-			}
-			replicationSendsTotal.WithLabelValues("success").Inc()
-
-		case <-ctx.Done():
-			return nil
-		}
-	}
-}
-
-func (s *replicationSender) String() string {
-	return fmt.Sprintf("replicationSender(%q)", s.dst)
-}
-
-func (s *replicationSender) send(key string, ps []DatabaseAddress, seen int64) {
-	item := ReplicationRecord{
-		Key:       key,
-		Addresses: ps,
-		Seen:      seen,
-	}
-
-	// The send should never block. The outbox is suitably buffered for at
-	// least a few seconds of stalls, which shouldn't happen in practice.
-	select {
-	case s.outbox <- item:
-	default:
-		replicationSendsTotal.WithLabelValues("drop").Inc()
-	}
-}
-
-// a replicationMultiplexer sends to multiple replicators
-type replicationMultiplexer []replicator
-
-func (m replicationMultiplexer) send(key string, ps []DatabaseAddress, seen int64) {
-	for _, s := range m {
-		// each send is nonblocking
-		s.send(key, ps, seen)
-	}
-}
-
-// replicationListener accepts incoming connections and reads replication
-// items from them. Incoming items are applied to the KV store.
-type replicationListener struct {
-	addr       string
-	cert       tls.Certificate
-	allowedIDs []protocol.DeviceID
-	db         database
-}
-
-func newReplicationListener(addr string, cert tls.Certificate, allowedIDs []protocol.DeviceID, db database) *replicationListener {
-	return &replicationListener{
-		addr:       addr,
-		cert:       cert,
-		allowedIDs: allowedIDs,
-		db:         db,
-	}
-}
-
-func (l *replicationListener) Serve(ctx context.Context) error {
-	tlsCfg := &tls.Config{
-		Certificates:       []tls.Certificate{l.cert},
-		ClientAuth:         tls.RequestClientCert,
-		MinVersion:         tls.VersionTLS12,
-		InsecureSkipVerify: true,
-	}
-
-	lst, err := tls.Listen("tcp", l.addr, tlsCfg)
-	if err != nil {
-		log.Println("Replication listen:", err)
-		return err
-	}
-	defer lst.Close()
-
-	for {
-		select {
-		case <-ctx.Done():
-			return nil
-		default:
-		}
-
-		// Accept a connection
-		conn, err := lst.Accept()
-		if err != nil {
-			log.Println("Replication accept:", err)
-			return err
-		}
-
-		// Figure out the other side device ID
-		remoteID, err := deviceID(conn.(*tls.Conn))
-		if err != nil {
-			log.Println("Replication accept:", err)
-			conn.SetWriteDeadline(time.Now().Add(time.Second))
-			conn.Close()
-			continue
-		}
-
-		// Verify it is in the set of allowed device IDs
-		if !deviceIDIn(remoteID, l.allowedIDs) {
-			log.Println("Replication accept: unexpected device ID:", remoteID)
-			conn.SetWriteDeadline(time.Now().Add(time.Second))
-			conn.Close()
-			continue
-		}
-
-		go l.handle(ctx, conn)
-	}
-}
-
-func (l *replicationListener) String() string {
-	return fmt.Sprintf("replicationListener(%q)", l.addr)
-}
-
-func (l *replicationListener) handle(ctx context.Context, conn net.Conn) {
-	defer func() {
-		conn.SetWriteDeadline(time.Now().Add(time.Second))
-		conn.Close()
-	}()
-
-	buf := make([]byte, 1024)
-
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		default:
-		}
-
-		conn.SetReadDeadline(time.Now().Add(replicationReadTimeout))
-
-		// First four bytes are the size
-		if _, err := io.ReadFull(conn, buf[:4]); err != nil {
-			log.Println("Replication read size:", err)
-			replicationRecvsTotal.WithLabelValues("error").Inc()
-			return
-		}
-
-		// Read the rest of the record
-		size := int(binary.BigEndian.Uint32(buf[:4]))
-		if len(buf) < size {
-			buf = make([]byte, size)
-		}
-
-		if size == 0 {
-			// Heartbeat, ignore
-			continue
-		}
-
-		if _, err := io.ReadFull(conn, buf[:size]); err != nil {
-			log.Println("Replication read record:", err)
-			replicationRecvsTotal.WithLabelValues("error").Inc()
-			return
-		}
-
-		// Unmarshal
-		var rec ReplicationRecord
-		if err := rec.Unmarshal(buf[:size]); err != nil {
-			log.Println("Replication unmarshal:", err)
-			replicationRecvsTotal.WithLabelValues("error").Inc()
-			continue
-		}
-
-		// Store
-		l.db.merge(rec.Key, rec.Addresses, rec.Seen)
-		replicationRecvsTotal.WithLabelValues("success").Inc()
-	}
-}
-
-func deviceID(conn *tls.Conn) (protocol.DeviceID, error) {
-	// Handshake may not be complete on the server side yet; we need it to
-	// be before we can read the client certificate.
-	if !conn.ConnectionState().HandshakeComplete {
-		if err := conn.Handshake(); err != nil {
-			return protocol.DeviceID{}, err
-		}
-	}
-
-	// We expect exactly one certificate.
-	certs := conn.ConnectionState().PeerCertificates
-	if len(certs) != 1 {
-		return protocol.DeviceID{}, fmt.Errorf("unexpected number of certificates (%d != 1)", len(certs))
-	}
-
-	return protocol.NewDeviceID(certs[0].Raw), nil
-}
-
-func deviceIDIn(id protocol.DeviceID, ids []protocol.DeviceID) bool {
-	for _, candidate := range ids {
-		if id == candidate {
-			return true
-		}
-	}
-	return false
-}
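
For reference, the wire format implemented by the code removed above is a four-byte big-endian length prefix followed by the marshalled record, with a zero length acting as a heartbeat. A minimal, self-contained sketch of a frame reader for that format (illustrative only, since the feature is gone):

package main

import (
	"encoding/binary"
	"io"
)

// readFrame reads one length-prefixed frame from r. A zero-length
// frame is a heartbeat and yields a nil payload.
func readFrame(r io.Reader) ([]byte, error) {
	var hdr [4]byte
	if _, err := io.ReadFull(r, hdr[:]); err != nil {
		return nil, err
	}
	size := binary.BigEndian.Uint32(hdr[:])
	if size == 0 {
		return nil, nil // heartbeat
	}
	buf := make([]byte, size)
	if _, err := io.ReadFull(r, buf); err != nil {
		return nil, err
	}
	return buf, nil
}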

+ 97 - 0
cmd/stdiscosrv/s3.go

@@ -0,0 +1,97 @@
+// Copyright (C) 2024 The Syncthing Authors.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at https://mozilla.org/MPL/2.0/.
+
+package main
+
+import (
+	"io"
+	"log"
+	"time"
+
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/credentials"
+	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/aws/aws-sdk-go/service/s3"
+	"github.com/aws/aws-sdk-go/service/s3/s3manager"
+)
+
+type s3Copier struct {
+	endpoint    string
+	region      string
+	bucket      string
+	key         string
+	accessKeyID string
+	secretKey   string
+}
+
+func newS3Copier(endpoint, region, bucket, key, accessKeyID, secretKey string) *s3Copier {
+	return &s3Copier{
+		endpoint:    endpoint,
+		region:      region,
+		bucket:      bucket,
+		key:         key,
+		accessKeyID: accessKeyID,
+		secretKey:   secretKey,
+	}
+}
+
+func (s *s3Copier) upload(r io.Reader) error {
+	sess, err := session.NewSession(&aws.Config{
+		Region:      aws.String(s.region),
+		Endpoint:    aws.String(s.endpoint),
+		Credentials: credentials.NewStaticCredentials(s.accessKeyID, s.secretKey, ""),
+	})
+	if err != nil {
+		return err
+	}
+
+	uploader := s3manager.NewUploader(sess)
+	_, err = uploader.Upload(&s3manager.UploadInput{
+		Bucket: aws.String(s.bucket),
+		Key:    aws.String(s.key),
+		Body:   r,
+	})
+	return err
+}
+
+func (s *s3Copier) downloadLatest(w io.WriterAt) error {
+	sess, err := session.NewSession(&aws.Config{
+		Region:      aws.String(s.region),
+		Endpoint:    aws.String(s.endpoint),
+		Credentials: credentials.NewStaticCredentials(s.accessKeyID, s.secretKey, ""),
+	})
+	if err != nil {
+		return err
+	}
+
+	svc := s3.New(sess)
+	resp, err := svc.ListObjectsV2(&s3.ListObjectsV2Input{Bucket: aws.String(s.bucket)})
+	if err != nil {
+		return err
+	}
+
+	var lastKey string
+	var lastModified time.Time
+	var lastSize int64
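+	// An object replaces the current pick only if it is larger, and
+	// either newer or modified within five minutes of the pick: a
+	// larger snapshot is assumed to be a more complete one.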
+	for _, item := range resp.Contents {
+		if item.LastModified.After(lastModified) && *item.Size > lastSize {
+			lastKey = *item.Key
+			lastModified = *item.LastModified
+			lastSize = *item.Size
+		} else if lastModified.Sub(*item.LastModified) < 5*time.Minute && *item.Size > lastSize {
+			lastKey = *item.Key
+			lastSize = *item.Size
+		}
+	}
+
+	log.Println("Downloading database from", lastKey)
+	downloader := s3manager.NewDownloader(sess)
+	_, err = downloader.Download(w, &s3.GetObjectInput{
+		Bucket: aws.String(s.bucket),
+		Key:    aws.String(lastKey),
+	})
+	return err
+}
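
A usage sketch for the copier above; the endpoint, bucket, key, and local path are placeholders, and an os import is assumed in addition to those shown. Note that downloadLatest takes an io.WriterAt because the s3manager downloader fetches parts concurrently, so the target must be a real file rather than a stream:

func restoreDatabase() error {
	copier := newS3Copier("s3.example.com", "us-east-1",
		"discosrv-db", "db-host1", "ACCESS_KEY", "SECRET_KEY")

	// The downloader needs an io.WriterAt, so write into a file.
	fd, err := os.Create("/var/discosrv/database.snapshot")
	if err != nil {
		return err
	}
	defer fd.Close()
	return copier.downloadLatest(fd)
}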

+ 24 - 8
cmd/stdiscosrv/stats.go

@@ -96,13 +96,28 @@ var (
 			Objectives: map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.99: 0.001},
 		}, []string{"operation"})
 
-	retryAfterHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{
-		Namespace: "syncthing",
-		Subsystem: "discovery",
-		Name:      "retry_after_seconds",
-		Help:      "Retry-After header value in seconds.",
-		Buckets:   prometheus.ExponentialBuckets(60, 2, 7), // 60, 120, 240, 480, 960, 1920, 3840
-	})
+	databaseWriteSeconds = prometheus.NewGauge(
+		prometheus.GaugeOpts{
+			Namespace: "syncthing",
+			Subsystem: "discovery",
+			Name:      "database_write_seconds",
+			Help:      "Time spent writing the database.",
+		})
+	databaseLastWritten = prometheus.NewGauge(
+		prometheus.GaugeOpts{
+			Namespace: "syncthing",
+			Subsystem: "discovery",
+			Name:      "database_last_written",
+			Help:      "Timestamp of the last successful database write.",
+		})
+
+	retryAfterLevel = prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Namespace: "syncthing",
+			Subsystem: "discovery",
+			Name:      "retry_after_seconds",
+			Help:      "Retry-After header value in seconds.",
+		}, []string{"name"})
 )
 
 const (
@@ -123,5 +138,6 @@ func init() {
 		replicationSendsTotal, replicationRecvsTotal,
 		databaseKeys, databaseStatisticsSeconds,
 		databaseOperations, databaseOperationSeconds,
-		retryAfterHistogram)
+		databaseWriteSeconds, databaseLastWritten,
+		retryAfterLevel)
 }
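
Since retry_after_seconds is now a gauge vector keyed by name rather than a histogram, each tracker publishes only its current backoff level. A sketch of how a tracker might update it; the retryAfterTracker shape shown here is an assumption, not code from this commit:

type retryAfterTracker struct {
	name       string
	currentSec float64
}

// setLevel records the tracker's current backoff level and exports
// it under the tracker's name label.
func (t *retryAfterTracker) setLevel(seconds float64) {
	t.currentSec = seconds
	retryAfterLevel.WithLabelValues(t.name).Set(seconds)
}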

+ 3 - 0
go.mod

@@ -5,6 +5,7 @@ go 1.22.0
 require (
 	github.com/AudriusButkevicius/recli v0.0.7-0.20220911121932-d000ce8fbf0f
 	github.com/alecthomas/kong v0.9.0
+	github.com/aws/aws-sdk-go v1.55.5
 	github.com/calmh/incontainer v1.0.0
 	github.com/calmh/xdr v1.1.0
 	github.com/ccding/go-stun v0.1.5
@@ -28,6 +29,7 @@ require (
 	github.com/oschwald/geoip2-golang v1.11.0
 	github.com/pierrec/lz4/v4 v4.1.21
 	github.com/prometheus/client_golang v1.19.1
+	github.com/puzpuzpuz/xsync/v3 v3.4.0
 	github.com/quic-go/quic-go v0.46.0
 	github.com/rabbitmq/amqp091-go v1.10.0
 	github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
@@ -67,6 +69,7 @@ require (
 	github.com/google/uuid v1.6.0 // indirect
 	github.com/hashicorp/errwrap v1.1.0 // indirect
 	github.com/hashicorp/go-multierror v1.1.1 // indirect
+	github.com/jmespath/go-jmespath v0.4.0 // indirect
 	github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
 	github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
 	github.com/nxadm/tail v1.4.11 // indirect

+ 10 - 0
go.sum

@@ -11,6 +11,8 @@ github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc
 github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
 github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI=
 github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
+github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU=
+github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
 github.com/calmh/glob v0.0.0-20220615080505-1d823af5017b h1:Fjm4GuJ+TGMgqfGHN42IQArJb77CfD/mAwLbDUoJe6g=
@@ -124,6 +126,10 @@ github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh6
 github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs=
 github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY=
 github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
+github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
+github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
+github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
+github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
 github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
 github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
@@ -194,6 +200,8 @@ github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G
 github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8=
 github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
 github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
+github.com/puzpuzpuz/xsync/v3 v3.4.0 h1:DuVBAdXuGFHv8adVXjWWZ63pJq+NRXOWVXlKDBZ+mJ4=
+github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
 github.com/quic-go/quic-go v0.46.0 h1:uuwLClEEyk1DNvchH8uCByQVjo3yKL9opKulExNDs7Y=
 github.com/quic-go/quic-go v0.46.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI=
 github.com/rabbitmq/amqp091-go v1.10.0 h1:STpn5XsHlHGcecLmMFCtg7mqq0RnD+zFr4uzukfVhBw=
@@ -381,7 +389,9 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkep
 gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
 gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
 gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=