// Copyright (C) 2018 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.

//go:generate go run ../../proto/scripts/protofmt.go database.proto
//go:generate protoc -I ../../ -I . --gogofast_out=. database.proto

package main

import (
	"bufio"
	"cmp"
	"context"
	"encoding/binary"
	"errors"
	"io"
	"log"
	"os"
	"path"
	"runtime"
	"slices"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
	"github.com/puzpuzpuz/xsync/v3"

	"github.com/syncthing/syncthing/lib/protocol"
)
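
// clock provides the current time. defaultClock delegates to time.Now; an
// alternative implementation can be substituted, for example in tests.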
type clock interface {
	Now() time.Time
}

type defaultClock struct{}

func (defaultClock) Now() time.Time {
	return time.Now()
}
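
// database is the interface implemented by the record store: put replaces a
// device's record, merge folds new addresses into the existing record, and
// get retrieves the record for a device.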
type database interface {
	put(key *protocol.DeviceID, rec DatabaseRecord) error
	merge(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64) error
	get(key *protocol.DeviceID) (DatabaseRecord, error)
}
type inMemoryStore struct {
	m             *xsync.MapOf[protocol.DeviceID, DatabaseRecord]
	dir           string
	flushInterval time.Duration
	clock         clock
}
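
// newInMemoryStore creates a store and seeds it from dir/records.db. If that
// file does not exist, it attempts to download a copy from S3 before reading.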
func newInMemoryStore(dir string, flushInterval time.Duration) *inMemoryStore {
	s := &inMemoryStore{
		m:             xsync.NewMapOf[protocol.DeviceID, DatabaseRecord](),
		dir:           dir,
		flushInterval: flushInterval,
		clock:         defaultClock{},
	}
	nr, err := s.read()
	if os.IsNotExist(err) {
		// Try to read from AWS
		fd, cerr := os.Create(path.Join(s.dir, "records.db"))
		if cerr != nil {
			log.Println("Error creating database file:", cerr)
			return s
		}
		if err := s3Download(fd); err != nil {
			log.Printf("Error reading database from S3: %v", err)
		}
		_ = fd.Close()
		nr, err = s.read()
	}
	if err != nil {
		log.Println("Error reading database:", err)
	}
	log.Printf("Read %d records from database", nr)
	s.calculateStatistics()
	return s
}
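
// put stores rec for the given device ID, overwriting any existing record,
// and updates the corresponding metrics.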
func (s *inMemoryStore) put(key *protocol.DeviceID, rec DatabaseRecord) error {
	t0 := time.Now()
	s.m.Store(*key, rec)
	databaseOperations.WithLabelValues(dbOpPut, dbResSuccess).Inc()
	databaseOperationSeconds.WithLabelValues(dbOpPut).Observe(time.Since(t0).Seconds())
	return nil
}
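
// merge folds the given addresses and last-seen time into the existing
// record for the device, keeping the union of addresses with the newer
// expiry for duplicates.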
func (s *inMemoryStore) merge(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64) error {
	t0 := time.Now()
	newRec := DatabaseRecord{
		Addresses: addrs,
		Seen:      seen,
	}
	oldRec, _ := s.m.Load(*key)
	newRec = merge(newRec, oldRec)
	s.m.Store(*key, newRec)
	databaseOperations.WithLabelValues(dbOpMerge, dbResSuccess).Inc()
	databaseOperationSeconds.WithLabelValues(dbOpMerge).Observe(time.Since(t0).Seconds())
	return nil
}
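
// get returns the record for the given device ID with expired addresses
// removed. A missing key yields an empty record and a nil error.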
func (s *inMemoryStore) get(key *protocol.DeviceID) (DatabaseRecord, error) {
	t0 := time.Now()
	defer func() {
		databaseOperationSeconds.WithLabelValues(dbOpGet).Observe(time.Since(t0).Seconds())
	}()
	rec, ok := s.m.Load(*key)
	if !ok {
		databaseOperations.WithLabelValues(dbOpGet, dbResNotFound).Inc()
		return DatabaseRecord{}, nil
	}
	rec.Addresses = expire(rec.Addresses, s.clock.Now())
	databaseOperations.WithLabelValues(dbOpGet, dbResSuccess).Inc()
	return rec, nil
}
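
// Serve periodically recalculates statistics and flushes the database to
// disk until the context is cancelled, performing a final flush on exit.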
func (s *inMemoryStore) Serve(ctx context.Context) error {
	t := time.NewTimer(s.flushInterval)
	defer t.Stop()
	if s.flushInterval <= 0 {
		t.Stop()
	}
loop:
	for {
		select {
		case <-t.C:
			log.Println("Calculating statistics")
			s.calculateStatistics()
			log.Println("Flushing database")
			if err := s.write(); err != nil {
				log.Println("Error writing database:", err)
			}
			log.Println("Finished flushing database")
			t.Reset(s.flushInterval)
		case <-ctx.Done():
			// We're done.
			break loop
		}
	}
	return s.write()
}
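
// calculateStatistics walks all records, updates the databaseKeys gauges for
// currently announced devices (split by IP version) and recently seen ones,
// and deletes records that have not been seen for over a week.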
func (s *inMemoryStore) calculateStatistics() {
	now := s.clock.Now()
	cutoff24h := now.Add(-24 * time.Hour).UnixNano()
	cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano()
	current, currentIPv4, currentIPv6, last24h, last1w := 0, 0, 0, 0, 0
	n := 0
	s.m.Range(func(key protocol.DeviceID, rec DatabaseRecord) bool {
		if n%1000 == 0 {
			runtime.Gosched()
		}
		n++
		addresses := expire(rec.Addresses, now)
		switch {
		case len(addresses) > 0:
			current++
			seenIPv4, seenIPv6 := false, false
			for _, addr := range rec.Addresses {
				if strings.Contains(addr.Address, "[") {
					seenIPv6 = true
				} else {
					seenIPv4 = true
				}
				if seenIPv4 && seenIPv6 {
					break
				}
			}
			if seenIPv4 {
				currentIPv4++
			}
			if seenIPv6 {
				currentIPv6++
			}
		case rec.Seen > cutoff24h:
			last24h++
		case rec.Seen > cutoff1w:
			last1w++
		default:
			// drop the record if it's older than a week
			s.m.Delete(key)
		}
		return true
	})
	databaseKeys.WithLabelValues("current").Set(float64(current))
	databaseKeys.WithLabelValues("currentIPv4").Set(float64(currentIPv4))
	databaseKeys.WithLabelValues("currentIPv6").Set(float64(currentIPv6))
	databaseKeys.WithLabelValues("last24h").Set(float64(last24h))
	databaseKeys.WithLabelValues("last1w").Set(float64(last1w))
	databaseStatisticsSeconds.Set(time.Since(now).Seconds())
}
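
// write serializes all records seen within the last week to dir/records.db,
// writing to a temporary file first and renaming it into place. On the pod
// with PODINDEX=0 the resulting file is also uploaded to S3.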
func (s *inMemoryStore) write() (err error) {
	t0 := time.Now()
	defer func() {
		if err == nil {
			databaseWriteSeconds.Set(time.Since(t0).Seconds())
			databaseLastWritten.Set(float64(t0.Unix()))
		}
	}()
	dbf := path.Join(s.dir, "records.db")
	fd, err := os.Create(dbf + ".tmp")
	if err != nil {
		return err
	}
	bw := bufio.NewWriter(fd)
	var buf []byte
	var rangeErr error
	now := s.clock.Now()
	cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano()
	n := 0
	s.m.Range(func(key protocol.DeviceID, value DatabaseRecord) bool {
		if n%1000 == 0 {
			runtime.Gosched()
		}
		n++
		if value.Seen < cutoff1w {
			// drop the record if it's older than a week
			return true
		}
		rec := ReplicationRecord{
			Key:       key[:],
			Addresses: value.Addresses,
			Seen:      value.Seen,
		}
		s := rec.Size()
		if s+4 > len(buf) {
			buf = make([]byte, s+4)
		}
		n, err := rec.MarshalTo(buf[4:])
		if err != nil {
			rangeErr = err
			return false
		}
		binary.BigEndian.PutUint32(buf, uint32(n))
		if _, err := bw.Write(buf[:n+4]); err != nil {
			rangeErr = err
			return false
		}
		return true
	})
	if rangeErr != nil {
		_ = fd.Close()
		return rangeErr
	}
	if err := bw.Flush(); err != nil {
		_ = fd.Close()
		return err
	}
	if err := fd.Close(); err != nil {
		return err
	}
	if err := os.Rename(dbf+".tmp", dbf); err != nil {
		return err
	}
	if os.Getenv("PODINDEX") == "0" {
		// Upload to S3
		log.Println("Uploading database")
		fd, err = os.Open(dbf)
		if err != nil {
			log.Printf("Error uploading database to S3: %v", err)
			return nil
		}
		defer fd.Close()
		if err := s3Upload(fd); err != nil {
			log.Printf("Error uploading database to S3: %v", err)
		}
		log.Println("Finished uploading database")
	}
	return nil
}
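
// read loads records from dir/records.db, which consists of length-prefixed,
// protobuf-encoded ReplicationRecords, and returns the number of records
// successfully loaded.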
func (s *inMemoryStore) read() (int, error) {
	fd, err := os.Open(path.Join(s.dir, "records.db"))
	if err != nil {
		return 0, err
	}
	defer fd.Close()
	br := bufio.NewReader(fd)
	var buf []byte
	nr := 0
	for {
		var n uint32
		if err := binary.Read(br, binary.BigEndian, &n); err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			return nr, err
		}
		if int(n) > len(buf) {
			buf = make([]byte, n)
		}
		if _, err := io.ReadFull(br, buf[:n]); err != nil {
			return nr, err
		}
		rec := ReplicationRecord{}
		if err := rec.Unmarshal(buf[:n]); err != nil {
			return nr, err
		}
		key, err := protocol.DeviceIDFromBytes(rec.Key)
		if err != nil {
			key, err = protocol.DeviceIDFromString(string(rec.Key))
		}
		if err != nil {
			log.Println("Bad device ID:", err)
			continue
		}
		slices.SortFunc(rec.Addresses, DatabaseAddress.Cmp)
		s.m.Store(key, DatabaseRecord{
			Addresses: expire(rec.Addresses, s.clock.Now()),
			Seen:      rec.Seen,
		})
		nr++
	}
	return nr, nil
}

// merge returns the merged result of the two database records a and b. The
// result is the union of the two address sets, with the newer expiry time
// chosen for any duplicates.
func merge(a, b DatabaseRecord) DatabaseRecord {
	// Both lists must be sorted for this to work.
	res := DatabaseRecord{
		Addresses: make([]DatabaseAddress, 0, max(len(a.Addresses), len(b.Addresses))),
		Seen:      a.Seen,
	}
	if b.Seen > a.Seen {
		res.Seen = b.Seen
	}
	aIdx := 0
	bIdx := 0
	aAddrs := a.Addresses
	bAddrs := b.Addresses
loop:
	for {
		switch {
		case aIdx == len(aAddrs) && bIdx == len(bAddrs):
			// both lists are exhausted, we are done
			break loop
		case aIdx == len(aAddrs):
			// a is exhausted, pick from b and continue
			res.Addresses = append(res.Addresses, bAddrs[bIdx])
			bIdx++
			continue
		case bIdx == len(bAddrs):
			// b is exhausted, pick from a and continue
			res.Addresses = append(res.Addresses, aAddrs[aIdx])
			aIdx++
			continue
		}
		// We have values left on both sides.
		aVal := aAddrs[aIdx]
		bVal := bAddrs[bIdx]
		switch {
		case aVal.Address == bVal.Address:
			// update for same address, pick newer
			if aVal.Expires > bVal.Expires {
				res.Addresses = append(res.Addresses, aVal)
			} else {
				res.Addresses = append(res.Addresses, bVal)
			}
			aIdx++
			bIdx++
		case aVal.Address < bVal.Address:
			// a is smallest, pick it and continue
			res.Addresses = append(res.Addresses, aVal)
			aIdx++
		default:
			// b is smallest, pick it and continue
			res.Addresses = append(res.Addresses, bVal)
			bIdx++
		}
	}
	return res
}

// expire returns the list of addresses after removing expired entries.
// Expiration happens in place, so the slice given as the parameter is
// destroyed. Internal order is preserved.
func expire(addrs []DatabaseAddress, now time.Time) []DatabaseAddress {
	cutoff := now.UnixNano()
	naddrs := addrs[:0]
	for i := range addrs {
		if addrs[i].Expires >= cutoff {
			naddrs = append(naddrs, addrs[i])
		}
	}
	return naddrs
}
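
// s3Upload uploads the database file to the "syncthing-discovery" bucket on
// Scaleway's S3-compatible object storage (fr-par).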
func s3Upload(r io.Reader) error {
	sess, err := session.NewSession(&aws.Config{
		Region:   aws.String("fr-par"),
		Endpoint: aws.String("s3.fr-par.scw.cloud"),
	})
	if err != nil {
		return err
	}
	uploader := s3manager.NewUploader(sess)
	_, err = uploader.Upload(&s3manager.UploadInput{
		Bucket: aws.String("syncthing-discovery"),
		Key:    aws.String("discovery.db"),
		Body:   r,
	})
	return err
}
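
// s3Download downloads the database object from the same bucket into w; it
// is used to seed a store that has no local database file.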
func s3Download(w io.WriterAt) error {
	sess, err := session.NewSession(&aws.Config{
		Region:   aws.String("fr-par"),
		Endpoint: aws.String("s3.fr-par.scw.cloud"),
	})
	if err != nil {
		return err
	}
	downloader := s3manager.NewDownloader(sess)
	_, err = downloader.Download(w, &s3.GetObjectInput{
		Bucket: aws.String("syncthing-discovery"),
		Key:    aws.String("discovery.db"),
	})
	return err
}
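
// Cmp orders DatabaseAddresses by address, then by expiry time, for use with
// slices.SortFunc.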
func (d DatabaseAddress) Cmp(other DatabaseAddress) (n int) {
	if c := cmp.Compare(d.Address, other.Address); c != 0 {
		return c
	}
	return cmp.Compare(d.Expires, other.Expires)
}