|
|
@@ -40,6 +40,7 @@ import (
|
|
|
"tailscale.com/disco"
|
|
|
"tailscale.com/envknob"
|
|
|
"tailscale.com/metrics"
|
|
|
+ "tailscale.com/syncs"
|
|
|
"tailscale.com/types/key"
|
|
|
"tailscale.com/types/logger"
|
|
|
"tailscale.com/version"
|
|
|
@@ -1560,22 +1561,20 @@ func (s *Server) AddPacketForwarder(dst key.NodePublic, fwd PacketForwarder) {
|
|
|
// Duplicate registration of same forwarder. Ignore.
|
|
|
return
|
|
|
}
|
|
|
- if m, ok := prev.(multiForwarder); ok {
|
|
|
- if _, ok := m[fwd]; ok {
|
|
|
+ if m, ok := prev.(*multiForwarder); ok {
|
|
|
+ if _, ok := m.all[fwd]; ok {
|
|
|
// Duplicate registration of same forwarder in set; ignore.
|
|
|
return
|
|
|
}
|
|
|
- m[fwd] = m.maxVal() + 1
|
|
|
+ m.add(fwd)
|
|
|
return
|
|
|
}
|
|
|
if prev != nil {
|
|
|
// Otherwise, the existing value is not a set,
|
|
|
// not a dup, and not local-only (nil) so make
|
|
|
- // it a set.
|
|
|
- fwd = multiForwarder{
|
|
|
- prev: 1, // existed 1st, higher priority
|
|
|
- fwd: 2, // the passed in fwd is in 2nd place
|
|
|
- }
|
|
|
+ // it a set. `prev` existed first, so will have higher
|
|
|
+ // priority.
|
|
|
+ fwd = newMultiForwarder(prev, fwd)
|
|
|
s.multiForwarderCreated.Add(1)
|
|
|
}
|
|
|
}
|
|
|
@@ -1591,19 +1590,14 @@ func (s *Server) RemovePacketForwarder(dst key.NodePublic, fwd PacketForwarder)
|
|
|
if !ok {
|
|
|
return
|
|
|
}
|
|
|
- if m, ok := v.(multiForwarder); ok {
|
|
|
- if len(m) < 2 {
|
|
|
+ if m, ok := v.(*multiForwarder); ok {
|
|
|
+ if len(m.all) < 2 {
|
|
|
panic("unexpected")
|
|
|
}
|
|
|
- delete(m, fwd)
|
|
|
- // If fwd was in m and we no longer need to be a
|
|
|
- // multiForwarder, replace the entry with the
|
|
|
- // remaining PacketForwarder.
|
|
|
- if len(m) == 1 {
|
|
|
- var remain PacketForwarder
|
|
|
- for k := range m {
|
|
|
- remain = k
|
|
|
- }
|
|
|
+ if remain, isLast := m.deleteLocked(fwd); isLast {
|
|
|
+ // If fwd was in m and we no longer need to be a
|
|
|
+ // multiForwarder, replace the entry with the
|
|
|
+ // remaining PacketForwarder.
|
|
|
s.clientsMesh[dst] = remain
|
|
|
s.multiForwarderDeleted.Add(1)
|
|
|
}
|
|
|
@@ -1635,27 +1629,65 @@ func (s *Server) RemovePacketForwarder(dst key.NodePublic, fwd PacketForwarder)
|
|
|
// client is. The map value is unique connection number; the lowest
|
|
|
// one has been seen the longest. It's used to make sure we forward
|
|
|
// packets consistently to the same node and don't pick randomly.
|
|
|
-type multiForwarder map[PacketForwarder]uint8
|
|
|
+type multiForwarder struct {
|
|
|
+ fwd syncs.AtomicValue[PacketForwarder] // preferred forwarder.
|
|
|
+ all map[PacketForwarder]uint8 // all forwarders, protected by s.mu.
|
|
|
+}
|
|
|
+
|
|
|
+// newMultiForwarder creates a new multiForwarder.
|
|
|
+// The first PacketForwarder passed to this function will be the preferred one.
|
|
|
+func newMultiForwarder(fwds ...PacketForwarder) *multiForwarder {
|
|
|
+ f := &multiForwarder{all: make(map[PacketForwarder]uint8)}
|
|
|
+ f.fwd.Store(fwds[0])
|
|
|
+ for idx, fwd := range fwds {
|
|
|
+ f.all[fwd] = uint8(idx)
|
|
|
+ }
|
|
|
+ return f
|
|
|
+}
|
|
|
|
|
|
-func (m multiForwarder) maxVal() (max uint8) {
|
|
|
- for _, v := range m {
|
|
|
+// add adds a new forwarder to the map with a connection number that
|
|
|
+// is higher than the existing ones.
|
|
|
+func (f *multiForwarder) add(fwd PacketForwarder) {
|
|
|
+ var max uint8
|
|
|
+ for _, v := range f.all {
|
|
|
if v > max {
|
|
|
max = v
|
|
|
}
|
|
|
}
|
|
|
- return
|
|
|
+ f.all[fwd] = max + 1
|
|
|
}
|
|
|
|
|
|
-func (m multiForwarder) ForwardPacket(src, dst key.NodePublic, payload []byte) error {
|
|
|
- var fwd PacketForwarder
|
|
|
- var lowest uint8
|
|
|
- for k, v := range m {
|
|
|
- if fwd == nil || v < lowest {
|
|
|
- fwd = k
|
|
|
- lowest = v
|
|
|
+// deleteLocked removes a packet forwarder from the map. It expects Server.mu to be held.
|
|
|
+// If only one forwarder remains after the removal, it will be returned alongside a `true` boolean value.
|
|
|
+func (f *multiForwarder) deleteLocked(fwd PacketForwarder) (_ PacketForwarder, isLast bool) {
|
|
|
+ delete(f.all, fwd)
|
|
|
+
|
|
|
+ if fwd == f.fwd.Load() {
|
|
|
+ // The preferred forwarder has been removed, choose a new one
|
|
|
+ // based on the lowest index.
|
|
|
+ var lowestfwd PacketForwarder
|
|
|
+ var lowest uint8
|
|
|
+ for k, v := range f.all {
|
|
|
+ if lowestfwd == nil || v < lowest {
|
|
|
+ lowestfwd = k
|
|
|
+ lowest = v
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if lowestfwd != nil {
|
|
|
+ f.fwd.Store(lowestfwd)
|
|
|
}
|
|
|
}
|
|
|
- return fwd.ForwardPacket(src, dst, payload)
|
|
|
+
|
|
|
+ if len(f.all) == 1 {
|
|
|
+ for k := range f.all {
|
|
|
+ return k, true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return nil, false
|
|
|
+}
|
|
|
+
|
|
|
+func (f *multiForwarder) ForwardPacket(src, dst key.NodePublic, payload []byte) error {
|
|
|
+ return f.fwd.Load().ForwardPacket(src, dst, payload)
|
|
|
}
|
|
|
|
|
|
func (s *Server) expVarFunc(f func() any) expvar.Func {
|