// Copyright (C) 2018 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.

//go:generate go run ../../proto/scripts/protofmt.go database.proto
//go:generate protoc -I ../../ -I . --gogofast_out=. database.proto

package main

import (
	"bufio"
	"cmp"
	"context"
	"encoding/binary"
	"io"
	"log"
	"net"
	"net/url"
	"os"
	"path"
	"slices"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
	"github.com/puzpuzpuz/xsync/v3"
	"github.com/syncthing/syncthing/lib/protocol"
)
type clock interface {
	Now() time.Time
}

type defaultClock struct{}

func (defaultClock) Now() time.Time {
	return time.Now()
}
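// database is the interface offered by the address store. inMemoryStore
// below is the only implementation here.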
type database interface {
	put(key *protocol.DeviceID, rec DatabaseRecord) error
	merge(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64) error
	get(key *protocol.DeviceID) (DatabaseRecord, error)
}
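// inMemoryStore is a database backed by a concurrent in-memory map. It is
// periodically flushed to records.db in dir and, when the PODINDEX
// environment variable is "0" (presumably the first replica), mirrored to S3.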
type inMemoryStore struct {
	m             *xsync.MapOf[protocol.DeviceID, DatabaseRecord]
	dir           string
	flushInterval time.Duration
	clock         clock
}
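// newInMemoryStore creates a store and loads any existing records.db from
// dir. If the file does not exist, it tries to seed one by downloading a
// copy from S3 before reading again.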
func newInMemoryStore(dir string, flushInterval time.Duration) *inMemoryStore {
	s := &inMemoryStore{
		m:             xsync.NewMapOf[protocol.DeviceID, DatabaseRecord](),
		dir:           dir,
		flushInterval: flushInterval,
		clock:         defaultClock{},
	}
	err := s.read()
	if os.IsNotExist(err) {
		// Try to read from AWS
		fd, cerr := os.Create(path.Join(s.dir, "records.db"))
		if cerr != nil {
			log.Println("Error creating database file:", cerr)
			return s
		}
		if err := s3Download(fd); err != nil {
			log.Printf("Error reading database from S3: %v", err)
		}
		_ = fd.Close()
		err = s.read()
	}
	if err != nil {
		log.Println("Error reading database:", err)
	}
	s.calculateStatistics()
	return s
}
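// put unconditionally replaces the record stored for the given device ID.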
func (s *inMemoryStore) put(key *protocol.DeviceID, rec DatabaseRecord) error {
	t0 := time.Now()
	s.m.Store(*key, rec)
	databaseOperations.WithLabelValues(dbOpPut, dbResSuccess).Inc()
	databaseOperationSeconds.WithLabelValues(dbOpPut).Observe(time.Since(t0).Seconds())
	return nil
}
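// merge folds the given addresses and last-seen time into the existing
// record for the device, keeping the newest expiry for duplicate addresses.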
func (s *inMemoryStore) merge(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64) error {
	t0 := time.Now()
	newRec := DatabaseRecord{
		Addresses: addrs,
		Seen:      seen,
	}
	oldRec, _ := s.m.Load(*key)
	newRec = merge(newRec, oldRec)
	s.m.Store(*key, newRec)
	databaseOperations.WithLabelValues(dbOpMerge, dbResSuccess).Inc()
	databaseOperationSeconds.WithLabelValues(dbOpMerge).Observe(time.Since(t0).Seconds())
	return nil
}
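// get returns the record for the device ID with expired addresses removed.
// A missing key is not an error; it yields an empty record.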
func (s *inMemoryStore) get(key *protocol.DeviceID) (DatabaseRecord, error) {
	t0 := time.Now()
	defer func() {
		databaseOperationSeconds.WithLabelValues(dbOpGet).Observe(time.Since(t0).Seconds())
	}()
	rec, ok := s.m.Load(*key)
	if !ok {
		databaseOperations.WithLabelValues(dbOpGet, dbResNotFound).Inc()
		return DatabaseRecord{}, nil
	}
	rec.Addresses = expire(rec.Addresses, s.clock.Now().UnixNano())
	databaseOperations.WithLabelValues(dbOpGet, dbResSuccess).Inc()
	return rec, nil
}
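// Serve flushes the database to disk every flushInterval (never, if the
// interval is non-positive) and performs a final write when the context is
// cancelled.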
func (s *inMemoryStore) Serve(ctx context.Context) error {
	t := time.NewTimer(s.flushInterval)
	defer t.Stop()
	if s.flushInterval <= 0 {
		t.Stop()
	}
loop:
	for {
		select {
		case <-t.C:
			if err := s.write(); err != nil {
				log.Println("Error writing database:", err)
			}
			s.calculateStatistics()
			t.Reset(s.flushInterval)
		case <-ctx.Done():
			// We're done.
			break loop
		}
	}
	return s.write()
}
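// calculateStatistics walks the database and updates the Prometheus gauges
// for current (per IP family), recently seen and stale records. Records not
// seen for more than a week are deleted as a side effect.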
func (s *inMemoryStore) calculateStatistics() {
	t0 := time.Now()
	nowNanos := t0.UnixNano()
	cutoff24h := t0.Add(-24 * time.Hour).UnixNano()
	cutoff1w := t0.Add(-7 * 24 * time.Hour).UnixNano()
	current, currentIPv4, currentIPv6, last24h, last1w, errors := 0, 0, 0, 0, 0, 0
	s.m.Range(func(key protocol.DeviceID, rec DatabaseRecord) bool {
		// If there are addresses that have not expired it's a current
		// record, otherwise account it based on when it was last seen
		// (last 24 hours or last week) or finally as inactive.
		addrs := expire(rec.Addresses, nowNanos)
		switch {
		case len(addrs) > 0:
			current++
			seenIPv4, seenIPv6 := false, false
			for _, addr := range addrs {
				uri, err := url.Parse(addr.Address)
				if err != nil {
					continue
				}
				host, _, err := net.SplitHostPort(uri.Host)
				if err != nil {
					continue
				}
				if ip := net.ParseIP(host); ip != nil && ip.To4() != nil {
					seenIPv4 = true
				} else if ip != nil {
					seenIPv6 = true
				}
				if seenIPv4 && seenIPv6 {
					break
				}
			}
			if seenIPv4 {
				currentIPv4++
			}
			if seenIPv6 {
				currentIPv6++
			}
		case rec.Seen > cutoff24h:
			last24h++
		case rec.Seen > cutoff1w:
			last1w++
		default:
			// drop the record if it's older than a week
			s.m.Delete(key)
		}
		return true
	})
	databaseKeys.WithLabelValues("current").Set(float64(current))
	databaseKeys.WithLabelValues("currentIPv4").Set(float64(currentIPv4))
	databaseKeys.WithLabelValues("currentIPv6").Set(float64(currentIPv6))
	databaseKeys.WithLabelValues("last24h").Set(float64(last24h))
	databaseKeys.WithLabelValues("last1w").Set(float64(last1w))
	databaseKeys.WithLabelValues("error").Set(float64(errors))
	databaseStatisticsSeconds.Set(time.Since(t0).Seconds())
}
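// write atomically rewrites records.db: every record seen within the last
// week is stored as a 4-byte big-endian length prefix followed by the
// marshalled ReplicationRecord, with expired addresses dropped. The file is
// written under a temporary name and renamed into place. When the PODINDEX
// environment variable is "0" the result is also uploaded to S3.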
func (s *inMemoryStore) write() (err error) {
	t0 := time.Now()
	log.Println("Writing database")
	defer func() {
		log.Println("Finished writing database")
		if err == nil {
			databaseWriteSeconds.Set(time.Since(t0).Seconds())
			databaseLastWritten.Set(float64(t0.Unix()))
		}
	}()
	dbf := path.Join(s.dir, "records.db")
	fd, err := os.Create(dbf + ".tmp")
	if err != nil {
		return err
	}
	bw := bufio.NewWriter(fd)
	var buf []byte
	var rangeErr error
	now := s.clock.Now().UnixNano()
	cutoff1w := s.clock.Now().Add(-7 * 24 * time.Hour).UnixNano()
	s.m.Range(func(key protocol.DeviceID, value DatabaseRecord) bool {
		if value.Seen < cutoff1w {
			// drop the record if it's older than a week
			return true
		}
		rec := ReplicationRecord{
			Key:       key[:],
			Addresses: expire(value.Addresses, now),
			Seen:      value.Seen,
		}
		s := rec.Size()
		if s+4 > len(buf) {
			buf = make([]byte, s+4)
		}
		n, err := rec.MarshalTo(buf[4:])
		if err != nil {
			rangeErr = err
			return false
		}
		binary.BigEndian.PutUint32(buf, uint32(n))
		if _, err := bw.Write(buf[:n+4]); err != nil {
			rangeErr = err
			return false
		}
		return true
	})
	if rangeErr != nil {
		_ = fd.Close()
		return rangeErr
	}
	if err := bw.Flush(); err != nil {
		_ = fd.Close()
		return err
	}
	if err := fd.Close(); err != nil {
		return err
	}
	if err := os.Rename(dbf+".tmp", dbf); err != nil {
		return err
	}
	if os.Getenv("PODINDEX") == "0" {
		// Upload to S3
		log.Println("Uploading database")
		fd, err = os.Open(dbf)
		if err != nil {
			log.Printf("Error uploading database to S3: %v", err)
			return nil
		}
		defer fd.Close()
		if err := s3Upload(fd); err != nil {
			log.Printf("Error uploading database to S3: %v", err)
		}
		log.Println("Finished uploading database")
	}
	return nil
}
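// read loads records.db, the length-prefixed format produced by write. Keys
// are accepted either as raw device ID bytes or, apparently for backwards
// compatibility, as the textual device ID form; addresses are sorted before
// being stored.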
func (s *inMemoryStore) read() error {
	fd, err := os.Open(path.Join(s.dir, "records.db"))
	if err != nil {
		return err
	}
	defer fd.Close()
	br := bufio.NewReader(fd)
	var buf []byte
	for {
		var n uint32
		if err := binary.Read(br, binary.BigEndian, &n); err != nil {
			if err == io.EOF {
				break
			}
			return err
		}
		if int(n) > len(buf) {
			buf = make([]byte, n)
		}
		if _, err := io.ReadFull(br, buf[:n]); err != nil {
			return err
		}
		rec := ReplicationRecord{}
		if err := rec.Unmarshal(buf[:n]); err != nil {
			return err
		}
		key, err := protocol.DeviceIDFromBytes(rec.Key)
		if err != nil {
			key, err = protocol.DeviceIDFromString(string(rec.Key))
		}
		if err != nil {
			log.Println("Bad device ID:", err)
			continue
		}
		slices.SortFunc(rec.Addresses, DatabaseAddress.Cmp)
		s.m.Store(key, DatabaseRecord{
			Addresses: rec.Addresses,
			Seen:      rec.Seen,
		})
	}
	return nil
}
// merge returns the merged result of the two database records a and b. The
// result is the union of the two address sets, with the newer expiry time
// chosen for any duplicates.
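//
// For example, merging {tcp://a expires 10} with {tcp://a expires 20,
// tcp://b expires 5} yields {tcp://a expires 20, tcp://b expires 5}.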
func merge(a, b DatabaseRecord) DatabaseRecord {
	// Both lists must be sorted for this to work.
	res := DatabaseRecord{
		Addresses: make([]DatabaseAddress, 0, max(len(a.Addresses), len(b.Addresses))),
		Seen:      a.Seen,
	}
	if b.Seen > a.Seen {
		res.Seen = b.Seen
	}
	aIdx := 0
	bIdx := 0
	aAddrs := a.Addresses
	bAddrs := b.Addresses
loop:
	for {
		switch {
		case aIdx == len(aAddrs) && bIdx == len(bAddrs):
			// both lists are exhausted, we are done
			break loop
		case aIdx == len(aAddrs):
			// a is exhausted, pick from b and continue
			res.Addresses = append(res.Addresses, bAddrs[bIdx])
			bIdx++
			continue
		case bIdx == len(bAddrs):
			// b is exhausted, pick from a and continue
			res.Addresses = append(res.Addresses, aAddrs[aIdx])
			aIdx++
			continue
		}
		// We have values left on both sides.
		aVal := aAddrs[aIdx]
		bVal := bAddrs[bIdx]
		switch {
		case aVal.Address == bVal.Address:
			// update for same address, pick newer
			if aVal.Expires > bVal.Expires {
				res.Addresses = append(res.Addresses, aVal)
			} else {
				res.Addresses = append(res.Addresses, bVal)
			}
			aIdx++
			bIdx++
		case aVal.Address < bVal.Address:
			// a is smallest, pick it and continue
			res.Addresses = append(res.Addresses, aVal)
			aIdx++
		default:
			// b is smallest, pick it and continue
			res.Addresses = append(res.Addresses, bVal)
			bIdx++
		}
	}
	return res
}
// expire returns the list of addresses after removing expired entries.
// Expiration happens in place, so the slice given as the parameter is
// destroyed. Internal order is preserved.
func expire(addrs []DatabaseAddress, now int64) []DatabaseAddress {
	i := 0
	for i < len(addrs) {
		if addrs[i].Expires < now {
			copy(addrs[i:], addrs[i+1:])
			addrs[len(addrs)-1] = DatabaseAddress{}
			addrs = addrs[:len(addrs)-1]
			continue
		}
		i++
	}
	return addrs
}
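// s3Upload stores the database as the discovery.db object in the
// syncthing-discovery bucket at Scaleway's fr-par S3 endpoint, relying on
// the AWS SDK's default credential chain.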
func s3Upload(r io.Reader) error {
	sess, err := session.NewSession(&aws.Config{
		Region:   aws.String("fr-par"),
		Endpoint: aws.String("s3.fr-par.scw.cloud"),
	})
	if err != nil {
		return err
	}
	uploader := s3manager.NewUploader(sess)
	_, err = uploader.Upload(&s3manager.UploadInput{
		Bucket: aws.String("syncthing-discovery"),
		Key:    aws.String("discovery.db"),
		Body:   r,
	})
	return err
}
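// s3Download fetches the same object into w; it is used to seed a fresh
// instance that has no local records.db.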
func s3Download(w io.WriterAt) error {
	sess, err := session.NewSession(&aws.Config{
		Region:   aws.String("fr-par"),
		Endpoint: aws.String("s3.fr-par.scw.cloud"),
	})
	if err != nil {
		return err
	}
	downloader := s3manager.NewDownloader(sess)
	_, err = downloader.Download(w, &s3.GetObjectInput{
		Bucket: aws.String("syncthing-discovery"),
		Key:    aws.String("discovery.db"),
	})
	return err
}
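// Cmp orders DatabaseAddresses by address, then by expiry time, for use with
// slices.SortFunc.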
func (d DatabaseAddress) Cmp(other DatabaseAddress) (n int) {
	if c := cmp.Compare(d.Address, other.Address); c != 0 {
		return c
	}
	return cmp.Compare(d.Expires, other.Expires)
}