Quellcode durchsuchen

wgengine/magicsock: add an addrLatency type to combine an IPPort+time.Duration

Updates #1566 (but no behavior changes as of this change)

Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick vor 5 Jahren
Ursprung
Commit
9643d8b34d
2 geänderte Dateien mit 64 neuen und 17 gelöschten Zeilen
  1. 34 17
      wgengine/magicsock/magicsock.go
  2. 30 0
      wgengine/magicsock/magicsock_test.go

+ 34 - 17
wgengine/magicsock/magicsock.go

@@ -3079,10 +3079,9 @@ type discoEndpoint struct {
 	lastFullPing   time.Time      // last time we pinged all endpoints
 	derpAddr       netaddr.IPPort // fallback/bootstrap path, if non-zero (non-zero for well-behaved clients)
 
-	bestAddr           netaddr.IPPort // best non-DERP path; zero if none
-	bestAddrLatency    time.Duration
-	bestAddrAt         time.Time // time best address re-confirmed
-	trustBestAddrUntil time.Time // time when bestAddr expires
+	bestAddr           addrLatency // best non-DERP path; zero if none
+	bestAddrAt         time.Time   // time best address re-confirmed
+	trustBestAddrUntil time.Time   // time when bestAddr expires
 	sentPing           map[stun.TxID]sentPing
 	endpointState      map[netaddr.IPPort]*endpointState
 	isCallMeMaybeEP    map[netaddr.IPPort]bool
@@ -3187,8 +3186,8 @@ func (st *endpointState) shouldDeleteLocked() bool {
 
 func (de *discoEndpoint) deleteEndpointLocked(ep netaddr.IPPort) {
 	delete(de.endpointState, ep)
-	if de.bestAddr == ep {
-		de.bestAddr = netaddr.IPPort{}
+	if de.bestAddr.IPPort == ep {
+		de.bestAddr = addrLatency{}
 	}
 }
 
@@ -3256,7 +3255,7 @@ func (de *discoEndpoint) DstToBytes() []byte  { return packIPPort(de.fakeWGAddr)
 //
 // de.mu must be held.
 func (de *discoEndpoint) addrForSendLocked(now time.Time) (udpAddr, derpAddr netaddr.IPPort) {
-	udpAddr = de.bestAddr
+	udpAddr = de.bestAddr.IPPort
 	if udpAddr.IsZero() || now.After(de.trustBestAddrUntil) {
 		// We had a bestAddr but it expired so send both to it
 		// and DERP.
@@ -3309,7 +3308,7 @@ func (de *discoEndpoint) wantFullPingLocked(now time.Time) bool {
 	if now.After(de.trustBestAddrUntil) {
 		return true
 	}
-	if de.bestAddrLatency <= goodEnoughLatency {
+	if de.bestAddr.latency <= goodEnoughLatency {
 		return false
 	}
 	if now.Sub(de.lastFullPing) >= upgradeInterval {
@@ -3641,20 +3640,39 @@ func (de *discoEndpoint) handlePongConnLocked(m *disco.Pong, src netaddr.IPPort)
 	// Promote this pong response to our current best address if it's lower latency.
 	// TODO(bradfitz): decide how latency vs. preference order affects decision
 	if !isDerp {
-		if de.bestAddr.IsZero() || latency < de.bestAddrLatency {
-			if de.bestAddr != sp.to {
-				de.c.logf("magicsock: disco: node %v %v now using %v", de.publicKey.ShortString(), de.discoShort, sp.to)
-				de.bestAddr = sp.to
-			}
+		thisPong := addrLatency{sp.to, latency}
+		if betterAddr(thisPong, de.bestAddr) {
+			de.c.logf("magicsock: disco: node %v %v now using %v", de.publicKey.ShortString(), de.discoShort, sp.to)
+			de.bestAddr = thisPong
 		}
-		if de.bestAddr == sp.to {
-			de.bestAddrLatency = latency
+		if de.bestAddr.IPPort == thisPong.IPPort {
+			de.bestAddr.latency = latency
 			de.bestAddrAt = now
 			de.trustBestAddrUntil = now.Add(trustUDPAddrDuration)
 		}
 	}
 }
 
+// addrLatency is an IPPort with an associated latency.
+type addrLatency struct {
+	netaddr.IPPort
+	latency time.Duration
+}
+
+// betterAddr reports whether a is a better addr to use than b.
+func betterAddr(a, b addrLatency) bool {
+	if a.IPPort == b.IPPort {
+		return false
+	}
+	if b.IsZero() {
+		return true
+	}
+	if a.IsZero() {
+		return false
+	}
+	return a.latency < b.latency
+}
+
 // discoEndpoint.mu must be held.
 func (st *endpointState) addPongReplyLocked(r pongReply) {
 	if n := len(st.recentPongs); n < pongHistoryCount {
@@ -3761,8 +3779,7 @@ func (de *discoEndpoint) stopAndReset() {
 	// state isn't a mix of before & after two sessions.
 	de.lastSend = time.Time{}
 	de.lastFullPing = time.Time{}
-	de.bestAddr = netaddr.IPPort{}
-	de.bestAddrLatency = 0
+	de.bestAddr = addrLatency{}
 	de.bestAddrAt = time.Time{}
 	de.trustBestAddrUntil = time.Time{}
 	for _, es := range de.endpointState {

+ 30 - 0
wgengine/magicsock/magicsock_test.go

@@ -1851,3 +1851,33 @@ func TestStringSetsEqual(t *testing.T) {
 	}
 
 }
+
+func TestBetterAddr(t *testing.T) {
+	const ms = time.Millisecond
+	al := func(ipps string, d time.Duration) addrLatency {
+		return addrLatency{netaddr.MustParseIPPort(ipps), d}
+	}
+	zero := addrLatency{}
+	tests := []struct {
+		a, b addrLatency
+		want bool
+	}{
+		{a: zero, b: zero, want: false},
+		{a: al("10.0.0.2:123", 5*ms), b: zero, want: true},
+		{a: zero, b: al("10.0.0.2:123", 5*ms), want: false},
+		{a: al("10.0.0.2:123", 5*ms), b: al("1.2.3.4:555", 6*ms), want: true},
+		{a: al("10.0.0.2:123", 5*ms), b: al("10.0.0.2:123", 10*ms), want: false}, // same IPPort
+	}
+	for _, tt := range tests {
+		got := betterAddr(tt.a, tt.b)
+		if got != tt.want {
+			t.Errorf("betterAddr(%+v, %+v) = %v; want %v", tt.a, tt.b, got, tt.want)
+			continue
+		}
+		gotBack := betterAddr(tt.b, tt.a)
+		if got && gotBack {
+			t.Errorf("betterAddr(%+v, %+v) and betterAddr(%+v, %+v) both unexpectedly true", tt.a, tt.b, tt.b, tt.a)
+		}
+	}
+
+}