database.go 10 KB


  1. // Copyright (C) 2018 The Syncthing Authors.
  2. //
  3. // This Source Code Form is subject to the terms of the Mozilla Public
  4. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  5. // You can obtain one at https://mozilla.org/MPL/2.0/.
  6. //go:generate go run ../../proto/scripts/protofmt.go database.proto
  7. //go:generate protoc -I ../../ -I . --gogofast_out=. database.proto
  8. package main
  9. import (
  10. "bufio"
  11. "cmp"
  12. "context"
  13. "encoding/binary"
  14. "errors"
  15. "io"
  16. "log"
  17. "os"
  18. "path"
  19. "runtime"
  20. "slices"
  21. "strings"
  22. "time"
  23. "github.com/puzpuzpuz/xsync/v3"
  24. "github.com/syncthing/syncthing/lib/protocol"
  25. )
  // clock is the time source used by inMemoryStore. It exists so that the
  // notion of "now" can be substituted; defaultClock is the implementation
  // installed by newInMemoryStore.
  type clock interface {
  	// Now reports the current time.
  	Now() time.Time
  }
  // defaultClock is the clock implementation backed by the system clock.
  type defaultClock struct{}

  // Now returns whatever time.Now reports at the moment of the call.
  func (defaultClock) Now() time.Time {
  	current := time.Now()
  	return current
  }
  33. type database interface {
  34. put(key *protocol.DeviceID, rec DatabaseRecord) error
  35. merge(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64) error
  36. get(key *protocol.DeviceID) (DatabaseRecord, error)
  37. }
  38. type inMemoryStore struct {
  39. m *xsync.MapOf[protocol.DeviceID, DatabaseRecord]
  40. dir string
  41. flushInterval time.Duration
  42. s3 *s3Copier
  43. clock clock
  44. }
  45. func newInMemoryStore(dir string, flushInterval time.Duration, s3 *s3Copier) *inMemoryStore {
  46. s := &inMemoryStore{
  47. m: xsync.NewMapOf[protocol.DeviceID, DatabaseRecord](),
  48. dir: dir,
  49. flushInterval: flushInterval,
  50. s3: s3,
  51. clock: defaultClock{},
  52. }
  53. nr, err := s.read()
  54. if os.IsNotExist(err) && s3 != nil {
  55. // Try to read from AWS
  56. fd, cerr := os.Create(path.Join(s.dir, "records.db"))
  57. if cerr != nil {
  58. log.Println("Error creating database file:", err)
  59. return s
  60. }
  61. if err := s3.downloadLatest(fd); err != nil {
  62. log.Printf("Error reading database from S3: %v", err)
  63. }
  64. _ = fd.Close()
  65. nr, err = s.read()
  66. }
  67. if err != nil {
  68. log.Println("Error reading database:", err)
  69. }
  70. log.Printf("Read %d records from database", nr)
  71. s.expireAndCalculateStatistics()
  72. return s
  73. }
  74. func (s *inMemoryStore) put(key *protocol.DeviceID, rec DatabaseRecord) error {
  75. t0 := time.Now()
  76. s.m.Store(*key, rec)
  77. databaseOperations.WithLabelValues(dbOpPut, dbResSuccess).Inc()
  78. databaseOperationSeconds.WithLabelValues(dbOpPut).Observe(time.Since(t0).Seconds())
  79. return nil
  80. }
  81. func (s *inMemoryStore) merge(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64) error {
  82. t0 := time.Now()
  83. newRec := DatabaseRecord{
  84. Addresses: addrs,
  85. Seen: seen,
  86. }
  87. oldRec, _ := s.m.Load(*key)
  88. newRec = merge(oldRec, newRec)
  89. s.m.Store(*key, newRec)
  90. databaseOperations.WithLabelValues(dbOpMerge, dbResSuccess).Inc()
  91. databaseOperationSeconds.WithLabelValues(dbOpMerge).Observe(time.Since(t0).Seconds())
  92. return nil
  93. }
  94. func (s *inMemoryStore) get(key *protocol.DeviceID) (DatabaseRecord, error) {
  95. t0 := time.Now()
  96. defer func() {
  97. databaseOperationSeconds.WithLabelValues(dbOpGet).Observe(time.Since(t0).Seconds())
  98. }()
  99. rec, ok := s.m.Load(*key)
  100. if !ok {
  101. databaseOperations.WithLabelValues(dbOpGet, dbResNotFound).Inc()
  102. return DatabaseRecord{}, nil
  103. }
  104. rec.Addresses = expire(rec.Addresses, s.clock.Now())
  105. databaseOperations.WithLabelValues(dbOpGet, dbResSuccess).Inc()
  106. return rec, nil
  107. }
  108. func (s *inMemoryStore) Serve(ctx context.Context) error {
  109. if s.flushInterval <= 0 {
  110. <-ctx.Done()
  111. return nil
  112. }
  113. t := time.NewTimer(s.flushInterval)
  114. defer t.Stop()
  115. loop:
  116. for {
  117. select {
  118. case <-t.C:
  119. log.Println("Calculating statistics")
  120. s.expireAndCalculateStatistics()
  121. log.Println("Flushing database")
  122. if err := s.write(); err != nil {
  123. log.Println("Error writing database:", err)
  124. }
  125. log.Println("Finished flushing database")
  126. t.Reset(s.flushInterval)
  127. case <-ctx.Done():
  128. // We're done.
  129. break loop
  130. }
  131. }
  132. return s.write()
  133. }
  134. func (s *inMemoryStore) expireAndCalculateStatistics() {
  135. now := s.clock.Now()
  136. cutoff24h := now.Add(-24 * time.Hour).UnixNano()
  137. cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano()
  138. current, currentIPv4, currentIPv6, currentIPv6GUA, last24h, last1w := 0, 0, 0, 0, 0, 0
  139. n := 0
  140. s.m.Range(func(key protocol.DeviceID, rec DatabaseRecord) bool {
  141. if n%1000 == 0 {
  142. runtime.Gosched()
  143. }
  144. n++
  145. addresses := expire(rec.Addresses, now)
  146. if len(addresses) == 0 {
  147. rec.Addresses = nil
  148. s.m.Store(key, rec)
  149. } else if len(addresses) != len(rec.Addresses) {
  150. rec.Addresses = addresses
  151. s.m.Store(key, rec)
  152. }
  153. switch {
  154. case len(rec.Addresses) > 0:
  155. current++
  156. seenIPv4, seenIPv6, seenIPv6GUA := false, false, false
  157. for _, addr := range rec.Addresses {
  158. // We do fast and loose matching on strings here instead of
  159. // parsing the address and the IP and doing "proper" checks,
  160. // to keep things fast and generate less garbage.
  161. if strings.Contains(addr.Address, "[") {
  162. seenIPv6 = true
  163. if strings.Contains(addr.Address, "[2") {
  164. seenIPv6GUA = true
  165. }
  166. } else {
  167. seenIPv4 = true
  168. }
  169. if seenIPv4 && seenIPv6 && seenIPv6GUA {
  170. break
  171. }
  172. }
  173. if seenIPv4 {
  174. currentIPv4++
  175. }
  176. if seenIPv6 {
  177. currentIPv6++
  178. }
  179. if seenIPv6GUA {
  180. currentIPv6GUA++
  181. }
  182. case rec.Seen > cutoff24h:
  183. last24h++
  184. case rec.Seen > cutoff1w:
  185. last1w++
  186. default:
  187. // drop the record if it's older than a week
  188. s.m.Delete(key)
  189. }
  190. return true
  191. })
  192. databaseKeys.WithLabelValues("current").Set(float64(current))
  193. databaseKeys.WithLabelValues("currentIPv4").Set(float64(currentIPv4))
  194. databaseKeys.WithLabelValues("currentIPv6").Set(float64(currentIPv6))
  195. databaseKeys.WithLabelValues("currentIPv6GUA").Set(float64(currentIPv6GUA))
  196. databaseKeys.WithLabelValues("last24h").Set(float64(last24h))
  197. databaseKeys.WithLabelValues("last1w").Set(float64(last1w))
  198. databaseStatisticsSeconds.Set(time.Since(now).Seconds())
  199. }
  200. func (s *inMemoryStore) write() (err error) {
  201. t0 := time.Now()
  202. defer func() {
  203. if err == nil {
  204. databaseWriteSeconds.Set(time.Since(t0).Seconds())
  205. databaseLastWritten.Set(float64(t0.Unix()))
  206. }
  207. }()
  208. dbf := path.Join(s.dir, "records.db")
  209. fd, err := os.Create(dbf + ".tmp")
  210. if err != nil {
  211. return err
  212. }
  213. bw := bufio.NewWriter(fd)
  214. var buf []byte
  215. var rangeErr error
  216. now := s.clock.Now()
  217. cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano()
  218. n := 0
  219. s.m.Range(func(key protocol.DeviceID, value DatabaseRecord) bool {
  220. if n%1000 == 0 {
  221. runtime.Gosched()
  222. }
  223. n++
  224. if value.Seen < cutoff1w {
  225. // drop the record if it's older than a week
  226. return true
  227. }
  228. rec := ReplicationRecord{
  229. Key: key[:],
  230. Addresses: value.Addresses,
  231. Seen: value.Seen,
  232. }
  233. s := rec.Size()
  234. if s+4 > len(buf) {
  235. buf = make([]byte, s+4)
  236. }
  237. n, err := rec.MarshalTo(buf[4:])
  238. if err != nil {
  239. rangeErr = err
  240. return false
  241. }
  242. binary.BigEndian.PutUint32(buf, uint32(n))
  243. if _, err := bw.Write(buf[:n+4]); err != nil {
  244. rangeErr = err
  245. return false
  246. }
  247. return true
  248. })
  249. if rangeErr != nil {
  250. _ = fd.Close()
  251. return rangeErr
  252. }
  253. if err := bw.Flush(); err != nil {
  254. _ = fd.Close
  255. return err
  256. }
  257. if err := fd.Close(); err != nil {
  258. return err
  259. }
  260. if err := os.Rename(dbf+".tmp", dbf); err != nil {
  261. return err
  262. }
  263. // Upload to S3
  264. if s.s3 != nil {
  265. fd, err = os.Open(dbf)
  266. if err != nil {
  267. log.Printf("Error uploading database to S3: %v", err)
  268. return nil
  269. }
  270. defer fd.Close()
  271. if err := s.s3.upload(fd); err != nil {
  272. log.Printf("Error uploading database to S3: %v", err)
  273. }
  274. log.Println("Finished uploading database")
  275. }
  276. return nil
  277. }
  278. func (s *inMemoryStore) read() (int, error) {
  279. fd, err := os.Open(path.Join(s.dir, "records.db"))
  280. if err != nil {
  281. return 0, err
  282. }
  283. defer fd.Close()
  284. br := bufio.NewReader(fd)
  285. var buf []byte
  286. nr := 0
  287. for {
  288. var n uint32
  289. if err := binary.Read(br, binary.BigEndian, &n); err != nil {
  290. if errors.Is(err, io.EOF) {
  291. break
  292. }
  293. return nr, err
  294. }
  295. if int(n) > len(buf) {
  296. buf = make([]byte, n)
  297. }
  298. if _, err := io.ReadFull(br, buf[:n]); err != nil {
  299. return nr, err
  300. }
  301. rec := ReplicationRecord{}
  302. if err := rec.Unmarshal(buf[:n]); err != nil {
  303. return nr, err
  304. }
  305. key, err := protocol.DeviceIDFromBytes(rec.Key)
  306. if err != nil {
  307. key, err = protocol.DeviceIDFromString(string(rec.Key))
  308. }
  309. if err != nil {
  310. log.Println("Bad device ID:", err)
  311. continue
  312. }
  313. slices.SortFunc(rec.Addresses, DatabaseAddress.Cmp)
  314. rec.Addresses = slices.CompactFunc(rec.Addresses, DatabaseAddress.Equal)
  315. s.m.Store(key, DatabaseRecord{
  316. Addresses: expire(rec.Addresses, s.clock.Now()),
  317. Seen: rec.Seen,
  318. })
  319. nr++
  320. }
  321. return nr, nil
  322. }
  323. // merge returns the merged result of the two database records a and b. The
  324. // result is the union of the two address sets, with the newer expiry time
  325. // chosen for any duplicates. The address list in a is overwritten and
  326. // reused for the result.
  327. func merge(a, b DatabaseRecord) DatabaseRecord {
  328. // Both lists must be sorted for this to work.
  329. a.Seen = max(a.Seen, b.Seen)
  330. aIdx := 0
  331. bIdx := 0
  332. for aIdx < len(a.Addresses) && bIdx < len(b.Addresses) {
  333. switch cmp.Compare(a.Addresses[aIdx].Address, b.Addresses[bIdx].Address) {
  334. case 0:
  335. // a == b, choose the newer expiry time
  336. a.Addresses[aIdx].Expires = max(a.Addresses[aIdx].Expires, b.Addresses[bIdx].Expires)
  337. aIdx++
  338. bIdx++
  339. case -1:
  340. // a < b, keep a and move on
  341. aIdx++
  342. case 1:
  343. // a > b, insert b before a
  344. a.Addresses = append(a.Addresses[:aIdx], append([]DatabaseAddress{b.Addresses[bIdx]}, a.Addresses[aIdx:]...)...)
  345. bIdx++
  346. }
  347. }
  348. if bIdx < len(b.Addresses) {
  349. a.Addresses = append(a.Addresses, b.Addresses[bIdx:]...)
  350. }
  351. return a
  352. }
  353. // expire returns the list of addresses after removing expired entries.
  354. // Expiration happen in place, so the slice given as the parameter is
  355. // destroyed. Internal order is preserved.
  356. func expire(addrs []DatabaseAddress, now time.Time) []DatabaseAddress {
  357. cutoff := now.UnixNano()
  358. naddrs := addrs[:0]
  359. for i := range addrs {
  360. if i > 0 && addrs[i].Address == addrs[i-1].Address {
  361. // Skip duplicates
  362. continue
  363. }
  364. if addrs[i].Expires >= cutoff {
  365. naddrs = append(naddrs, addrs[i])
  366. }
  367. }
  368. if len(naddrs) == 0 {
  369. return nil
  370. }
  371. return naddrs
  372. }
  373. func (d DatabaseAddress) Cmp(other DatabaseAddress) (n int) {
  374. if c := cmp.Compare(d.Address, other.Address); c != 0 {
  375. return c
  376. }
  377. return cmp.Compare(d.Expires, other.Expires)
  378. }
  379. func (d DatabaseAddress) Equal(other DatabaseAddress) bool {
  380. return d.Address == other.Address
  381. }