浏览代码

chore(stdiscosrv): improve expire, logging

Jakob Borg 1 年之前
父节点
当前提交
f3f5557c8e
共有 2 个文件被更改,包括 31 次插入33 次删除
  1. 30 32
      cmd/stdiscosrv/database.go
  2. 1 1
      cmd/stdiscosrv/database_test.go

+ 30 - 32
cmd/stdiscosrv/database.go

@@ -63,7 +63,7 @@ func newInMemoryStore(dir string, flushInterval time.Duration) *inMemoryStore {
 		flushInterval: flushInterval,
 		clock:         defaultClock{},
 	}
-	err := s.read()
+	nr, err := s.read()
 	if os.IsNotExist(err) {
 		// Try to read from AWS
 		fd, cerr := os.Create(path.Join(s.dir, "records.db"))
@@ -75,11 +75,12 @@ func newInMemoryStore(dir string, flushInterval time.Duration) *inMemoryStore {
 			log.Printf("Error reading database from S3: %v", err)
 		}
 		_ = fd.Close()
-		err = s.read()
+		nr, err = s.read()
 	}
 	if err != nil {
 		log.Println("Error reading database:", err)
 	}
+	log.Printf("Read %d records from database", nr)
 	s.calculateStatistics()
 	return s
 }
@@ -122,7 +123,7 @@ func (s *inMemoryStore) get(key *protocol.DeviceID) (DatabaseRecord, error) {
 		return DatabaseRecord{}, nil
 	}
 
-	rec.Addresses = expire(rec.Addresses, s.clock.Now().UnixNano())
+	rec.Addresses = expire(rec.Addresses, s.clock.Now())
 	databaseOperations.WithLabelValues(dbOpGet, dbResSuccess).Inc()
 	return rec, nil
 }
@@ -139,10 +140,13 @@ loop:
 	for {
 		select {
 		case <-t.C:
+			log.Println("Flushing database")
 			if err := s.write(); err != nil {
 				log.Println("Error writing database:", err)
 			}
+			log.Println("Calculating statistics")
 			s.calculateStatistics()
+			log.Println("Finished calculating statistics")
 			t.Reset(s.flushInterval)
 
 		case <-ctx.Done():
@@ -155,10 +159,9 @@ loop:
 }
 
 func (s *inMemoryStore) calculateStatistics() {
-	t0 := time.Now()
-	nowNanos := t0.UnixNano()
-	cutoff24h := t0.Add(-24 * time.Hour).UnixNano()
-	cutoff1w := t0.Add(-7 * 24 * time.Hour).UnixNano()
+	now := s.clock.Now()
+	cutoff24h := now.Add(-24 * time.Hour).UnixNano()
+	cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano()
 	current, currentIPv4, currentIPv6, last24h, last1w, errors := 0, 0, 0, 0, 0, 0
 
 	n := 0
@@ -168,15 +171,11 @@ func (s *inMemoryStore) calculateStatistics() {
 		}
 		n++
 
-		// If there are addresses that have not expired it's a current
-		// record, otherwise account it based on when it was last seen
-		// (last 24 hours or last week) or finally as inactice.
-		addrs := expire(rec.Addresses, nowNanos)
 		switch {
-		case len(addrs) > 0:
+		case len(rec.Addresses) > 0:
 			current++
 			seenIPv4, seenIPv6 := false, false
-			for _, addr := range addrs {
+			for _, addr := range rec.Addresses {
 				uri, err := url.Parse(addr.Address)
 				if err != nil {
 					continue
@@ -217,7 +216,7 @@ func (s *inMemoryStore) calculateStatistics() {
 	databaseKeys.WithLabelValues("last24h").Set(float64(last24h))
 	databaseKeys.WithLabelValues("last1w").Set(float64(last1w))
 	databaseKeys.WithLabelValues("error").Set(float64(errors))
-	databaseStatisticsSeconds.Set(time.Since(t0).Seconds())
+	databaseStatisticsSeconds.Set(time.Since(now).Seconds())
 }
 
 func (s *inMemoryStore) write() (err error) {
@@ -240,8 +239,8 @@ func (s *inMemoryStore) write() (err error) {
 
 	var buf []byte
 	var rangeErr error
-	now := s.clock.Now().UnixNano()
-	cutoff1w := s.clock.Now().Add(-7 * 24 * time.Hour).UnixNano()
+	now := s.clock.Now()
+	cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano()
 	n := 0
 	s.m.Range(func(key protocol.DeviceID, value DatabaseRecord) bool {
 		if n%1000 == 0 {
@@ -308,32 +307,33 @@ func (s *inMemoryStore) write() (err error) {
 	return nil
 }
 
-func (s *inMemoryStore) read() error {
+func (s *inMemoryStore) read() (int, error) {
 	fd, err := os.Open(path.Join(s.dir, "records.db"))
 	if err != nil {
-		return err
+		return 0, err
 	}
 	defer fd.Close()
 
 	br := bufio.NewReader(fd)
 	var buf []byte
+	nr := 0
 	for {
 		var n uint32
 		if err := binary.Read(br, binary.BigEndian, &n); err != nil {
 			if errors.Is(err, io.EOF) {
 				break
 			}
-			return err
+			return nr, err
 		}
 		if int(n) > len(buf) {
 			buf = make([]byte, n)
 		}
 		if _, err := io.ReadFull(br, buf[:n]); err != nil {
-			return err
+			return nr, err
 		}
 		rec := ReplicationRecord{}
 		if err := rec.Unmarshal(buf[:n]); err != nil {
-			return err
+			return nr, err
 		}
 		key, err := protocol.DeviceIDFromBytes(rec.Key)
 		if err != nil {
@@ -349,8 +349,9 @@ func (s *inMemoryStore) read() error {
 			Addresses: rec.Addresses,
 			Seen:      rec.Seen,
 		})
+		nr++
 	}
-	return nil
+	return nr, nil
 }
 
 // merge returns the merged result of the two database records a and b. The
@@ -423,18 +424,15 @@ loop:
 // expire returns the list of addresses after removing expired entries.
 // Expiration happen in place, so the slice given as the parameter is
 // destroyed. Internal order is preserved.
-func expire(addrs []DatabaseAddress, now int64) []DatabaseAddress {
-	i := 0
-	for i < len(addrs) {
-		if addrs[i].Expires < now {
-			copy(addrs[i:], addrs[i+1:])
-			addrs[len(addrs)-1] = DatabaseAddress{}
-			addrs = addrs[:len(addrs)-1]
-			continue
+func expire(addrs []DatabaseAddress, now time.Time) []DatabaseAddress {
+	cutoff := now.UnixNano()
+	naddrs := addrs[:0]
+	for i := range addrs {
+		if addrs[i].Expires >= cutoff {
+			naddrs = append(naddrs, addrs[i])
 		}
-		i++
 	}
-	return addrs
+	return naddrs
 }
 
 func s3Upload(r io.Reader) error {

+ 1 - 1
cmd/stdiscosrv/database_test.go

@@ -160,7 +160,7 @@ func TestFilter(t *testing.T) {
 	}
 
 	for _, tc := range cases {
-		res := expire(tc.a, 10)
+		res := expire(tc.a, time.Unix(0, 10))
 		if fmt.Sprint(res) != fmt.Sprint(tc.b) {
 			t.Errorf("Incorrect result %v, expected %v", res, tc.b)
 		}