浏览代码

Merge branch 'infrastructure'

* infrastructure:
  chore(stdiscosrv): ensure incoming addresses are sorted and unique
  chore(stdiscosrv): use zero-allocation merge in the common case
  chore(stdiscosrv): properly clean out old addresses from memory
  chore(stdiscosrv): calculate IPv6 GUA
Jakob Borg 1 年之前
父节点
当前提交
0343bca257
共有 3 个文件被更改,包括 148 次插入67 次删除
  1. 5 4
      cmd/stdiscosrv/apisrv.go
  2. 61 63
      cmd/stdiscosrv/database.go
  3. 82 0
      cmd/stdiscosrv/database_test.go

+ 5 - 4
cmd/stdiscosrv/apisrv.go

@@ -307,16 +307,17 @@ func (s *apiSrv) handleAnnounce(deviceID protocol.DeviceID, addresses []string)
 	now := time.Now()
 	expire := now.Add(addressExpiryTime).UnixNano()
 
+	// The address slice must always be sorted for database merges to work
+	// properly.
+	slices.Sort(addresses)
+	addresses = slices.Compact(addresses)
+
 	dbAddrs := make([]DatabaseAddress, len(addresses))
 	for i := range addresses {
 		dbAddrs[i].Address = addresses[i]
 		dbAddrs[i].Expires = expire
 	}
 
-	// The address slice must always be sorted for database merges to work
-	// properly.
-	slices.SortFunc(dbAddrs, DatabaseAddress.Cmp)
-
 	seen := now.UnixNano()
 	if s.repl != nil {
 		s.repl.send(&deviceID, dbAddrs, seen)

+ 61 - 63
cmd/stdiscosrv/database.go

@@ -78,7 +78,7 @@ func newInMemoryStore(dir string, flushInterval time.Duration, s3 *s3Copier) *in
 		log.Println("Error reading database:", err)
 	}
 	log.Printf("Read %d records from database", nr)
-	s.calculateStatistics()
+	s.expireAndCalculateStatistics()
 	return s
 }
 
@@ -99,7 +99,7 @@ func (s *inMemoryStore) merge(key *protocol.DeviceID, addrs []DatabaseAddress, s
 	}
 
 	oldRec, _ := s.m.Load(*key)
-	newRec = merge(newRec, oldRec)
+	newRec = merge(oldRec, newRec)
 	s.m.Store(*key, newRec)
 
 	databaseOperations.WithLabelValues(dbOpMerge, dbResSuccess).Inc()
@@ -126,19 +126,20 @@ func (s *inMemoryStore) get(key *protocol.DeviceID) (DatabaseRecord, error) {
 }
 
 func (s *inMemoryStore) Serve(ctx context.Context) error {
-	t := time.NewTimer(s.flushInterval)
-	defer t.Stop()
-
 	if s.flushInterval <= 0 {
-		t.Stop()
+		<-ctx.Done()
+		return nil
 	}
 
+	t := time.NewTimer(s.flushInterval)
+	defer t.Stop()
+
 loop:
 	for {
 		select {
 		case <-t.C:
 			log.Println("Calculating statistics")
-			s.calculateStatistics()
+			s.expireAndCalculateStatistics()
 			log.Println("Flushing database")
 			if err := s.write(); err != nil {
 				log.Println("Error writing database:", err)
@@ -155,11 +156,11 @@ loop:
 	return s.write()
 }
 
-func (s *inMemoryStore) calculateStatistics() {
+func (s *inMemoryStore) expireAndCalculateStatistics() {
 	now := s.clock.Now()
 	cutoff24h := now.Add(-24 * time.Hour).UnixNano()
 	cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano()
-	current, currentIPv4, currentIPv6, last24h, last1w := 0, 0, 0, 0, 0
+	current, currentIPv4, currentIPv6, currentIPv6GUA, last24h, last1w := 0, 0, 0, 0, 0, 0
 
 	n := 0
 	s.m.Range(func(key protocol.DeviceID, rec DatabaseRecord) bool {
@@ -169,17 +170,31 @@ func (s *inMemoryStore) calculateStatistics() {
 		n++
 
 		addresses := expire(rec.Addresses, now)
+		if len(addresses) == 0 {
+			rec.Addresses = nil
+			s.m.Store(key, rec)
+		} else if len(addresses) != len(rec.Addresses) {
+			rec.Addresses = addresses
+			s.m.Store(key, rec)
+		}
+
 		switch {
-		case len(addresses) > 0:
+		case len(rec.Addresses) > 0:
 			current++
-			seenIPv4, seenIPv6 := false, false
+			seenIPv4, seenIPv6, seenIPv6GUA := false, false, false
 			for _, addr := range rec.Addresses {
+				// We do fast and loose matching on strings here instead of
+				// parsing the address and the IP and doing "proper" checks,
+				// to keep things fast and generate less garbage.
 				if strings.Contains(addr.Address, "[") {
 					seenIPv6 = true
+					if strings.Contains(addr.Address, "[2") {
+						seenIPv6GUA = true
+					}
 				} else {
 					seenIPv4 = true
 				}
-				if seenIPv4 && seenIPv6 {
+				if seenIPv4 && seenIPv6 && seenIPv6GUA {
 					break
 				}
 			}
@@ -189,6 +204,9 @@ func (s *inMemoryStore) calculateStatistics() {
 			if seenIPv6 {
 				currentIPv6++
 			}
+			if seenIPv6GUA {
+				currentIPv6GUA++
+			}
 		case rec.Seen > cutoff24h:
 			last24h++
 		case rec.Seen > cutoff1w:
@@ -203,6 +221,7 @@ func (s *inMemoryStore) calculateStatistics() {
 	databaseKeys.WithLabelValues("current").Set(float64(current))
 	databaseKeys.WithLabelValues("currentIPv4").Set(float64(currentIPv4))
 	databaseKeys.WithLabelValues("currentIPv6").Set(float64(currentIPv6))
+	databaseKeys.WithLabelValues("currentIPv6GUA").Set(float64(currentIPv6GUA))
 	databaseKeys.WithLabelValues("last24h").Set(float64(last24h))
 	databaseKeys.WithLabelValues("last1w").Set(float64(last1w))
 	databaseStatisticsSeconds.Set(time.Since(now).Seconds())
@@ -331,6 +350,7 @@ func (s *inMemoryStore) read() (int, error) {
 		}
 
 		slices.SortFunc(rec.Addresses, DatabaseAddress.Cmp)
+		rec.Addresses = slices.CompactFunc(rec.Addresses, DatabaseAddress.Equal)
 		s.m.Store(key, DatabaseRecord{
 			Addresses: expire(rec.Addresses, s.clock.Now()),
 			Seen:      rec.Seen,
@@ -342,69 +362,36 @@ func (s *inMemoryStore) read() (int, error) {
 
 // merge returns the merged result of the two database records a and b. The
 // result is the union of the two address sets, with the newer expiry time
-// chosen for any duplicates.
+// chosen for any duplicates. The address list in a is overwritten and
+// reused for the result.
 func merge(a, b DatabaseRecord) DatabaseRecord {
 	// Both lists must be sorted for this to work.
 
-	res := DatabaseRecord{
-		Addresses: make([]DatabaseAddress, 0, max(len(a.Addresses), len(b.Addresses))),
-		Seen:      a.Seen,
-	}
-	if b.Seen > a.Seen {
-		res.Seen = b.Seen
-	}
+	a.Seen = max(a.Seen, b.Seen)
 
 	aIdx := 0
 	bIdx := 0
-	aAddrs := a.Addresses
-	bAddrs := b.Addresses
-loop:
-	for {
-		switch {
-		case aIdx == len(aAddrs) && bIdx == len(bAddrs):
-			// both lists are exhausted, we are done
-			break loop
-
-		case aIdx == len(aAddrs):
-			// a is exhausted, pick from b and continue
-			res.Addresses = append(res.Addresses, bAddrs[bIdx])
-			bIdx++
-			continue
-
-		case bIdx == len(bAddrs):
-			// b is exhausted, pick from a and continue
-			res.Addresses = append(res.Addresses, aAddrs[aIdx])
-			aIdx++
-			continue
-		}
-
-		// We have values left on both sides.
-		aVal := aAddrs[aIdx]
-		bVal := bAddrs[bIdx]
-
-		switch {
-		case aVal.Address == bVal.Address:
-			// update for same address, pick newer
-			if aVal.Expires > bVal.Expires {
-				res.Addresses = append(res.Addresses, aVal)
-			} else {
-				res.Addresses = append(res.Addresses, bVal)
-			}
+	for aIdx < len(a.Addresses) && bIdx < len(b.Addresses) {
+		switch cmp.Compare(a.Addresses[aIdx].Address, b.Addresses[bIdx].Address) {
+		case 0:
+			// a == b, choose the newer expiry time
+			a.Addresses[aIdx].Expires = max(a.Addresses[aIdx].Expires, b.Addresses[bIdx].Expires)
 			aIdx++
 			bIdx++
-
-		case aVal.Address < bVal.Address:
-			// a is smallest, pick it and continue
-			res.Addresses = append(res.Addresses, aVal)
+		case -1:
+			// a < b, keep a and move on
 			aIdx++
-
-		default:
-			// b is smallest, pick it and continue
-			res.Addresses = append(res.Addresses, bVal)
+		case 1:
+			// a > b, insert b before a
+			a.Addresses = append(a.Addresses[:aIdx], append([]DatabaseAddress{b.Addresses[bIdx]}, a.Addresses[aIdx:]...)...)
 			bIdx++
 		}
 	}
-	return res
+	if bIdx < len(b.Addresses) {
+		a.Addresses = append(a.Addresses, b.Addresses[bIdx:]...)
+	}
+
+	return a
 }
 
 // expire returns the list of addresses after removing expired entries.
@@ -414,10 +401,17 @@ func expire(addrs []DatabaseAddress, now time.Time) []DatabaseAddress {
 	cutoff := now.UnixNano()
 	naddrs := addrs[:0]
 	for i := range addrs {
+		if i > 0 && addrs[i].Address == addrs[i-1].Address {
+			// Skip duplicates
+			continue
+		}
 		if addrs[i].Expires >= cutoff {
 			naddrs = append(naddrs, addrs[i])
 		}
 	}
+	if len(naddrs) == 0 {
+		return nil
+	}
 	return naddrs
 }
 
@@ -427,3 +421,7 @@ func (d DatabaseAddress) Cmp(other DatabaseAddress) (n int) {
 	}
 	return cmp.Compare(d.Expires, other.Expires)
 }
+
+func (d DatabaseAddress) Equal(other DatabaseAddress) bool {
+	return d.Address == other.Address
+}

+ 82 - 0
cmd/stdiscosrv/database_test.go

@@ -167,6 +167,88 @@ func TestFilter(t *testing.T) {
 	}
 }
 
+func TestMerge(t *testing.T) {
+	cases := []struct {
+		a, b, res []DatabaseAddress
+	}{
+		{nil, nil, nil},
+		{
+			nil,
+			[]DatabaseAddress{{Address: "a", Expires: 10}},
+			[]DatabaseAddress{{Address: "a", Expires: 10}},
+		},
+		{
+			nil,
+			[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 10}, {Address: "c", Expires: 10}},
+			[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 10}, {Address: "c", Expires: 10}},
+		},
+		{
+			[]DatabaseAddress{{Address: "a", Expires: 10}},
+			[]DatabaseAddress{{Address: "a", Expires: 15}},
+			[]DatabaseAddress{{Address: "a", Expires: 15}},
+		},
+		{
+			[]DatabaseAddress{{Address: "a", Expires: 10}},
+			[]DatabaseAddress{{Address: "b", Expires: 15}},
+			[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}},
+		},
+		{
+			[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}},
+			[]DatabaseAddress{{Address: "a", Expires: 15}, {Address: "b", Expires: 15}},
+			[]DatabaseAddress{{Address: "a", Expires: 15}, {Address: "b", Expires: 15}},
+		},
+		{
+			[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}},
+			[]DatabaseAddress{{Address: "b", Expires: 15}, {Address: "c", Expires: 20}},
+			[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}, {Address: "c", Expires: 20}},
+		},
+		{
+			[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}},
+			[]DatabaseAddress{{Address: "b", Expires: 5}, {Address: "c", Expires: 20}},
+			[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}, {Address: "c", Expires: 20}},
+		},
+		{
+			[]DatabaseAddress{{Address: "y", Expires: 10}, {Address: "z", Expires: 10}},
+			[]DatabaseAddress{{Address: "a", Expires: 5}, {Address: "b", Expires: 15}},
+			[]DatabaseAddress{{Address: "a", Expires: 5}, {Address: "b", Expires: 15}, {Address: "y", Expires: 10}, {Address: "z", Expires: 10}},
+		},
+		{
+			[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}, {Address: "d", Expires: 10}},
+			[]DatabaseAddress{{Address: "b", Expires: 5}, {Address: "c", Expires: 20}},
+			[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}, {Address: "c", Expires: 20}, {Address: "d", Expires: 10}},
+		},
+	}
+
+	for _, tc := range cases {
+		rec := merge(DatabaseRecord{Addresses: tc.a}, DatabaseRecord{Addresses: tc.b})
+		if fmt.Sprint(rec.Addresses) != fmt.Sprint(tc.res) {
+			t.Errorf("Incorrect result %v, expected %v", rec.Addresses, tc.res)
+		}
+		rec = merge(DatabaseRecord{Addresses: tc.b}, DatabaseRecord{Addresses: tc.a})
+		if fmt.Sprint(rec.Addresses) != fmt.Sprint(tc.res) {
+			t.Errorf("Incorrect result %v, expected %v", rec.Addresses, tc.res)
+		}
+	}
+}
+
+func BenchmarkMergeEqual(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		ar := []DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}}
+		br := []DatabaseAddress{{Address: "a", Expires: 15}, {Address: "b", Expires: 10}}
+		res := merge(DatabaseRecord{Addresses: ar}, DatabaseRecord{Addresses: br})
+		if len(res.Addresses) != 2 {
+			b.Fatal("wrong length")
+		}
+		if res.Addresses[0].Address != "a" || res.Addresses[1].Address != "b" {
+			b.Fatal("wrong address")
+		}
+		if res.Addresses[0].Expires != 15 || res.Addresses[1].Expires != 15 {
+			b.Fatal("wrong expiry")
+		}
+	}
+	b.ReportAllocs() // should be zero per operation
+}
+
 type testClock struct {
 	now time.Time
 }