// Copyright (C) 2018 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.

package main

import (
    "bufio"
    "cmp"
    "context"
    "encoding/binary"
    "errors"
    "io"
    "log"
    "os"
    "path"
    "runtime"
    "slices"
    "strings"
    "time"

    "github.com/puzpuzpuz/xsync/v3"
    "google.golang.org/protobuf/proto"

    "github.com/syncthing/syncthing/internal/gen/discosrv"
    "github.com/syncthing/syncthing/internal/protoutil"
    "github.com/syncthing/syncthing/lib/protocol"
    "github.com/syncthing/syncthing/lib/rand"
    "github.com/syncthing/syncthing/lib/s3"
)
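
// clock abstracts the time source so a fake implementation can be
// substituted (for example in tests); defaultClock returns the real
// wall-clock time.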
type clock interface {
    Now() time.Time
}

type defaultClock struct{}

func (defaultClock) Now() time.Time {
    return time.Now()
}
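
// database is the storage interface used by the discovery server: put
// replaces the record for a device, merge folds additional addresses into
// the existing record, and get retrieves the current record.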
type database interface {
    put(key *protocol.DeviceID, rec *discosrv.DatabaseRecord) error
    merge(key *protocol.DeviceID, addrs []*discosrv.DatabaseAddress, seen int64) error
    get(key *protocol.DeviceID) (*discosrv.DatabaseRecord, error)
}
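
// inMemoryStore is a database implementation that keeps all records in a
// concurrent map, expires stale addresses periodically, and persists
// snapshots to a local file and, optionally, to S3.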
type inMemoryStore struct {
    m             *xsync.MapOf[protocol.DeviceID, *discosrv.DatabaseRecord]
    dir           string
    flushInterval time.Duration
    s3            *s3.Session
    objKey        string
    clock         clock
}
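
// newInMemoryStore creates a store backed by dir/records.db. If the local
// database file is missing and an S3 session is given, it attempts to
// download the latest snapshot from S3 before starting empty.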
func newInMemoryStore(dir string, flushInterval time.Duration, s3sess *s3.Session) *inMemoryStore {
    hn, err := os.Hostname()
    if err != nil {
        hn = rand.String(8)
    }
    s := &inMemoryStore{
        m:             xsync.NewMapOf[protocol.DeviceID, *discosrv.DatabaseRecord](),
        dir:           dir,
        flushInterval: flushInterval,
        s3:            s3sess,
        objKey:        hn + ".db",
        clock:         defaultClock{},
    }
    nr, err := s.read()
    if os.IsNotExist(err) && s3sess != nil {
        // Try to read from AWS
        latestKey, cerr := s3sess.LatestKey()
        if cerr != nil {
            log.Println("Error reading database from S3:", cerr)
            return s
        }
        fd, cerr := os.Create(path.Join(s.dir, "records.db"))
        if cerr != nil {
            log.Println("Error creating database file:", cerr)
            return s
        }
        if cerr := s3sess.Download(fd, latestKey); cerr != nil {
            log.Printf("Error reading database from S3: %v", cerr)
        }
        _ = fd.Close()
        nr, err = s.read()
    }
    if err != nil {
        log.Println("Error reading database:", err)
    }
    log.Printf("Read %d records from database", nr)
    s.expireAndCalculateStatistics()
    return s
}
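
// put unconditionally replaces the stored record for the given device ID.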
func (s *inMemoryStore) put(key *protocol.DeviceID, rec *discosrv.DatabaseRecord) error {
    t0 := time.Now()
    s.m.Store(*key, rec)
    databaseOperations.WithLabelValues(dbOpPut, dbResSuccess).Inc()
    databaseOperationSeconds.WithLabelValues(dbOpPut).Observe(time.Since(t0).Seconds())
    return nil
}
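
// merge folds the given addresses and seen timestamp into the existing
// record for the device, creating a new record if none exists.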
func (s *inMemoryStore) merge(key *protocol.DeviceID, addrs []*discosrv.DatabaseAddress, seen int64) error {
    t0 := time.Now()

    newRec := &discosrv.DatabaseRecord{
        Addresses: addrs,
        Seen:      seen,
    }
    if oldRec, ok := s.m.Load(*key); ok {
        newRec = merge(oldRec, newRec)
    }
    s.m.Store(*key, newRec)

    databaseOperations.WithLabelValues(dbOpMerge, dbResSuccess).Inc()
    databaseOperationSeconds.WithLabelValues(dbOpMerge).Observe(time.Since(t0).Seconds())
    return nil
}
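
// get returns the record for the device with expired addresses pruned. A
// missing device yields an empty record and a nil error.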
func (s *inMemoryStore) get(key *protocol.DeviceID) (*discosrv.DatabaseRecord, error) {
    t0 := time.Now()
    defer func() {
        databaseOperationSeconds.WithLabelValues(dbOpGet).Observe(time.Since(t0).Seconds())
    }()

    rec, ok := s.m.Load(*key)
    if !ok {
        databaseOperations.WithLabelValues(dbOpGet, dbResNotFound).Inc()
        return &discosrv.DatabaseRecord{}, nil
    }

    rec.Addresses = expire(rec.Addresses, s.clock.Now())
    databaseOperations.WithLabelValues(dbOpGet, dbResSuccess).Inc()
    return rec, nil
}
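
// Serve periodically expires stale entries, recalculates statistics and
// flushes the database to disk until the context is cancelled, then performs
// a final flush on the way out.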
func (s *inMemoryStore) Serve(ctx context.Context) error {
    if s.flushInterval <= 0 {
        <-ctx.Done()
        return nil
    }

    t := time.NewTimer(s.flushInterval)
    defer t.Stop()

loop:
    for {
        select {
        case <-t.C:
            log.Println("Calculating statistics")
            s.expireAndCalculateStatistics()
            log.Println("Flushing database")
            if err := s.write(); err != nil {
                log.Println("Error writing database:", err)
            }
            log.Println("Finished flushing database")
            t.Reset(s.flushInterval)
        case <-ctx.Done():
            // We're done.
            break loop
        }
    }

    return s.write()
}
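
// expireAndCalculateStatistics walks all records, removes expired addresses,
// drops records not seen for more than a week, and updates the database
// gauges (current devices, IPv4/IPv6 breakdown, recently seen devices).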
func (s *inMemoryStore) expireAndCalculateStatistics() {
    now := s.clock.Now()
    cutoff24h := now.Add(-24 * time.Hour).UnixNano()
    cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano()
    current, currentIPv4, currentIPv6, currentIPv6GUA, last24h, last1w := 0, 0, 0, 0, 0, 0

    n := 0
    s.m.Range(func(key protocol.DeviceID, rec *discosrv.DatabaseRecord) bool {
        if n%1000 == 0 {
            runtime.Gosched()
        }
        n++

        addresses := expire(rec.Addresses, now)
        if len(addresses) == 0 {
            rec.Addresses = nil
            s.m.Store(key, rec)
        } else if len(addresses) != len(rec.Addresses) {
            rec.Addresses = addresses
            s.m.Store(key, rec)
        }

        switch {
        case len(rec.Addresses) > 0:
            current++
            seenIPv4, seenIPv6, seenIPv6GUA := false, false, false
            for _, addr := range rec.Addresses {
                // We do fast and loose matching on strings here instead of
                // parsing the address and the IP and doing "proper" checks,
                // to keep things fast and generate less garbage.
                if strings.Contains(addr.Address, "[") {
                    seenIPv6 = true
                    if strings.Contains(addr.Address, "[2") {
                        seenIPv6GUA = true
                    }
                } else {
                    seenIPv4 = true
                }
                if seenIPv4 && seenIPv6 && seenIPv6GUA {
                    break
                }
            }
            if seenIPv4 {
                currentIPv4++
            }
            if seenIPv6 {
                currentIPv6++
            }
            if seenIPv6GUA {
                currentIPv6GUA++
            }
        case rec.Seen > cutoff24h:
            last24h++
        case rec.Seen > cutoff1w:
            last1w++
        default:
            // drop the record if it's older than a week
            s.m.Delete(key)
        }
        return true
    })

    databaseKeys.WithLabelValues("current").Set(float64(current))
    databaseKeys.WithLabelValues("currentIPv4").Set(float64(currentIPv4))
    databaseKeys.WithLabelValues("currentIPv6").Set(float64(currentIPv6))
    databaseKeys.WithLabelValues("currentIPv6GUA").Set(float64(currentIPv6GUA))
    databaseKeys.WithLabelValues("last24h").Set(float64(last24h))
    databaseKeys.WithLabelValues("last1w").Set(float64(last1w))
    databaseStatisticsSeconds.Set(time.Since(now).Seconds())
}
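
// write atomically persists the database to dir/records.db: each record is
// written as a four-byte big-endian length followed by the marshalled
// ReplicationRecord, into a temporary file that is then renamed into place.
// If an S3 session is configured, the resulting file is also uploaded.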
func (s *inMemoryStore) write() (err error) {
    t0 := time.Now()
    defer func() {
        if err == nil {
            databaseWriteSeconds.Set(time.Since(t0).Seconds())
            databaseLastWritten.Set(float64(t0.Unix()))
        }
    }()

    dbf := path.Join(s.dir, "records.db")
    fd, err := os.Create(dbf + ".tmp")
    if err != nil {
        return err
    }
    bw := bufio.NewWriter(fd)

    var buf []byte
    var rangeErr error
    now := s.clock.Now()
    cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano()
    n := 0
    s.m.Range(func(key protocol.DeviceID, value *discosrv.DatabaseRecord) bool {
        if n%1000 == 0 {
            runtime.Gosched()
        }
        n++

        if value.Seen < cutoff1w {
            // drop the record if it's older than a week
            return true
        }
        rec := &discosrv.ReplicationRecord{
            Key:       key[:],
            Addresses: value.Addresses,
            Seen:      value.Seen,
        }
        size := proto.Size(rec)
        if size+4 > len(buf) {
            buf = make([]byte, size+4)
        }
        length, err := protoutil.MarshalTo(buf[4:], rec)
        if err != nil {
            rangeErr = err
            return false
        }
        binary.BigEndian.PutUint32(buf, uint32(length))
        if _, err := bw.Write(buf[:length+4]); err != nil {
            rangeErr = err
            return false
        }
        return true
    })
    if rangeErr != nil {
        _ = fd.Close()
        return rangeErr
    }

    if err := bw.Flush(); err != nil {
        _ = fd.Close()
        return err
    }
    if err := fd.Close(); err != nil {
        return err
    }
    if err := os.Rename(dbf+".tmp", dbf); err != nil {
        return err
    }

    // Upload to S3
    if s.s3 != nil {
        fd, err = os.Open(dbf)
        if err != nil {
            log.Printf("Error uploading database to S3: %v", err)
            return nil
        }
        defer fd.Close()
        if err := s.s3.Upload(fd, s.objKey); err != nil {
            log.Printf("Error uploading database to S3: %v", err)
        }
        log.Println("Finished uploading database")
    }

    return nil
}
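
// read loads records from dir/records.db. Device IDs are parsed from their
// binary form, falling back to string form; addresses are sorted,
// deduplicated and expired on load. It returns the number of records read.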
func (s *inMemoryStore) read() (int, error) {
    fd, err := os.Open(path.Join(s.dir, "records.db"))
    if err != nil {
        return 0, err
    }
    defer fd.Close()

    br := bufio.NewReader(fd)
    var buf []byte
    nr := 0
    for {
        var n uint32
        if err := binary.Read(br, binary.BigEndian, &n); err != nil {
            if errors.Is(err, io.EOF) {
                break
            }
            return nr, err
        }
        if int(n) > len(buf) {
            buf = make([]byte, n)
        }
        if _, err := io.ReadFull(br, buf[:n]); err != nil {
            return nr, err
        }
        rec := &discosrv.ReplicationRecord{}
        if err := proto.Unmarshal(buf[:n], rec); err != nil {
            return nr, err
        }
        key, err := protocol.DeviceIDFromBytes(rec.Key)
        if err != nil {
            key, err = protocol.DeviceIDFromString(string(rec.Key))
        }
        if err != nil {
            log.Println("Bad device ID:", err)
            continue
        }

        slices.SortFunc(rec.Addresses, Cmp)
        rec.Addresses = slices.CompactFunc(rec.Addresses, Equal)

        s.m.Store(key, &discosrv.DatabaseRecord{
            Addresses: expire(rec.Addresses, s.clock.Now()),
            Seen:      rec.Seen,
        })
        nr++
    }
    return nr, nil
}

// merge returns the merged result of the two database records a and b. The
// result is the union of the two address sets, with the newer expiry time
// chosen for any duplicates. The address list in a is overwritten and
// reused for the result.
func merge(a, b *discosrv.DatabaseRecord) *discosrv.DatabaseRecord {
    // Both lists must be sorted for this to work.
    a.Seen = max(a.Seen, b.Seen)

    aIdx := 0
    bIdx := 0
    for aIdx < len(a.Addresses) && bIdx < len(b.Addresses) {
        switch cmp.Compare(a.Addresses[aIdx].Address, b.Addresses[bIdx].Address) {
        case 0:
            // a == b, choose the newer expiry time
            a.Addresses[aIdx].Expires = max(a.Addresses[aIdx].Expires, b.Addresses[bIdx].Expires)
            aIdx++
            bIdx++
        case -1:
            // a < b, keep a and move on
            aIdx++
        case 1:
            // a > b, insert b before a
            a.Addresses = append(a.Addresses[:aIdx], append([]*discosrv.DatabaseAddress{b.Addresses[bIdx]}, a.Addresses[aIdx:]...)...)
            bIdx++
        }
    }
    if bIdx < len(b.Addresses) {
        a.Addresses = append(a.Addresses, b.Addresses[bIdx:]...)
    }

    return a
}
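
// For illustration (example values, not used anywhere in the code), merging
//
//     a: {"tcp://10.0.0.1:22000", expires 10}, {"tcp://[2001:db8::1]:22000", expires 30}
//     b: {"tcp://10.0.0.1:22000", expires 20}
//
// yields {"tcp://10.0.0.1:22000", expires 20}, {"tcp://[2001:db8::1]:22000", expires 30}:
// the duplicate address keeps the newer expiry and the remaining entries are
// carried over unchanged, preserving sorted order.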

// expire returns the list of addresses after removing expired entries and
// adjacent duplicates. Expiration happens in place, so the slice given as
// the parameter is destroyed. Internal order is preserved.
func expire(addrs []*discosrv.DatabaseAddress, now time.Time) []*discosrv.DatabaseAddress {
    cutoff := now.UnixNano()
    naddrs := addrs[:0]
    for i := range addrs {
        if i > 0 && addrs[i].Address == addrs[i-1].Address {
            // Skip duplicates
            continue
        }
        if addrs[i].Expires >= cutoff {
            naddrs = append(naddrs, addrs[i])
        }
    }
    if len(naddrs) == 0 {
        return nil
    }
    return naddrs
}
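
// Cmp orders two addresses by address string, then by expiry time.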
func Cmp(d, other *discosrv.DatabaseAddress) (n int) {
    if c := cmp.Compare(d.Address, other.Address); c != 0 {
        return c
    }
    return cmp.Compare(d.Expires, other.Expires)
}
func Equal(d, other *discosrv.DatabaseAddress) bool {
    return d.Address == other.Address
}