Explorar o código

lib/connections: Slightly refactor limiter juggling

Two small behavior changes: don't "charge" the data to the global rate
limit until it's been accepted by the device specific limiter, and fix
the send/recv direction in the log print on per device rate limits.
Jakob Borg %!s(int64=7) %!d(string=hai) anos
pai
achega
c49d864f14
Modificáronse 2 ficheiros con 74 adicións e 54 borrados
  1. 9 9
      lib/config/config.go
  2. 65 45
      lib/connections/limiter.go

+ 9 - 9
lib/config/config.go

@@ -377,6 +377,15 @@ func (cfg *Configuration) clean() error {
 	return nil
 }
 
+// DeviceMap returns a map of device ID to device configuration for the given configuration.
+func (cfg *Configuration) DeviceMap() map[protocol.DeviceID]DeviceConfiguration {
+	m := make(map[protocol.DeviceID]DeviceConfiguration, len(cfg.Devices))
+	for _, dev := range cfg.Devices {
+		m[dev.DeviceID] = dev
+	}
+	return m
+}
+
 func convertV27V28(cfg *Configuration) {
 	// Show a notification about enabling filesystem watching
 	cfg.Options.UnackedNotificationIDs = append(cfg.Options.UnackedNotificationIDs, "fsWatcherNotification")
@@ -797,12 +806,3 @@ func filterURLSchemePrefix(addrs []string, prefix string) []string {
 	}
 	return addrs
 }
-
-// mapDeviceConfigs returns a map of device ID to device configuration for the given configuration.
-func (cfg *Configuration) DeviceMap() map[protocol.DeviceID]DeviceConfiguration {
-	m := make(map[protocol.DeviceID]DeviceConfiguration, len(cfg.Devices))
-	for _, dev := range cfg.Devices {
-		m[dev.DeviceID] = dev
-	}
-	return m
-}

+ 65 - 45
lib/connections/limiter.go

@@ -21,12 +21,17 @@ import (
 // limiter manages a read and write rate limit, reacting to config changes
 // as appropriate.
 type limiter struct {
+	mu                  sync.Mutex
 	write               *rate.Limiter
 	read                *rate.Limiter
 	limitsLAN           atomicBool
 	deviceReadLimiters  map[protocol.DeviceID]*rate.Limiter
 	deviceWriteLimiters map[protocol.DeviceID]*rate.Limiter
-	mu                  sync.Mutex
+}
+
+type waiter interface {
+	// This is the rate limiting operation
+	WaitN(ctx context.Context, n int) error
 }
 
 const limiterBurstSize = 4 * 128 << 10
@@ -96,7 +101,7 @@ func (lim *limiter) processDevicesConfigurationLocked(from, to config.Configurat
 				writeLimitStr = fmt.Sprintf("limit is %d KiB/s", dev.MaxSendKbps)
 			}
 
-			l.Infof("Device %s send rate %s, receive rate %s", dev.DeviceID, readLimitStr, writeLimitStr)
+			l.Infof("Device %s send rate %s, receive rate %s", dev.DeviceID, writeLimitStr, readLimitStr)
 		}
 	}
 
@@ -169,49 +174,76 @@ func (lim *limiter) String() string {
 	return "connections.limiter"
 }
 
-func (lim *limiter) getLimiters(remoteID protocol.DeviceID, c internalConn, isLAN bool) (io.Reader, io.Writer) {
+func (lim *limiter) getLimiters(remoteID protocol.DeviceID, rw io.ReadWriter, isLAN bool) (io.Reader, io.Writer) {
 	lim.mu.Lock()
-	wr := lim.newLimitedWriterLocked(remoteID, c, isLAN)
-	rd := lim.newLimitedReaderLocked(remoteID, c, isLAN)
+	wr := lim.newLimitedWriterLocked(remoteID, rw, isLAN)
+	rd := lim.newLimitedReaderLocked(remoteID, rw, isLAN)
 	lim.mu.Unlock()
 	return rd, wr
 }
 
 func (lim *limiter) newLimitedReaderLocked(remoteID protocol.DeviceID, r io.Reader, isLAN bool) io.Reader {
-	return &limitedReader{reader: r, limiter: lim, deviceLimiter: lim.getReadLimiterLocked(remoteID), isLAN: isLAN}
+	return &limitedReader{
+		reader:    r,
+		limitsLAN: &lim.limitsLAN,
+		waiter:    totalWaiter{lim.getReadLimiterLocked(remoteID), lim.read},
+		isLAN:     isLAN,
+	}
 }
 
 func (lim *limiter) newLimitedWriterLocked(remoteID protocol.DeviceID, w io.Writer, isLAN bool) io.Writer {
-	return &limitedWriter{writer: w, limiter: lim, deviceLimiter: lim.getWriteLimiterLocked(remoteID), isLAN: isLAN}
+	return &limitedWriter{
+		writer:    w,
+		limitsLAN: &lim.limitsLAN,
+		waiter:    totalWaiter{lim.getWriteLimiterLocked(remoteID), lim.write},
+		isLAN:     isLAN,
+	}
+}
+
+func (lim *limiter) getReadLimiterLocked(deviceID protocol.DeviceID) *rate.Limiter {
+	return getRateLimiter(lim.deviceReadLimiters, deviceID)
+}
+
+func (lim *limiter) getWriteLimiterLocked(deviceID protocol.DeviceID) *rate.Limiter {
+	return getRateLimiter(lim.deviceWriteLimiters, deviceID)
+}
+
+func getRateLimiter(m map[protocol.DeviceID]*rate.Limiter, deviceID protocol.DeviceID) *rate.Limiter {
+	limiter, ok := m[deviceID]
+	if !ok {
+		limiter = rate.NewLimiter(rate.Inf, limiterBurstSize)
+		m[deviceID] = limiter
+	}
+	return limiter
 }
 
 // limitedReader is a rate limited io.Reader
 type limitedReader struct {
-	reader        io.Reader
-	limiter       *limiter
-	deviceLimiter *rate.Limiter
-	isLAN         bool
+	reader    io.Reader
+	limitsLAN *atomicBool
+	waiter    waiter
+	isLAN     bool
 }
 
 func (r *limitedReader) Read(buf []byte) (int, error) {
 	n, err := r.reader.Read(buf)
-	if !r.isLAN || r.limiter.limitsLAN.get() {
-		take(r.limiter.read, r.deviceLimiter, n)
+	if !r.isLAN || r.limitsLAN.get() {
+		take(r.waiter, n)
 	}
 	return n, err
 }
 
 // limitedWriter is a rate limited io.Writer
 type limitedWriter struct {
-	writer        io.Writer
-	limiter       *limiter
-	deviceLimiter *rate.Limiter
-	isLAN         bool
+	writer    io.Writer
+	limitsLAN *atomicBool
+	waiter    waiter
+	isLAN     bool
 }
 
 func (w *limitedWriter) Write(buf []byte) (int, error) {
-	if !w.isLAN || w.limiter.limitsLAN.get() {
-		take(w.limiter.write, w.deviceLimiter, len(buf))
+	if !w.isLAN || w.limitsLAN.get() {
+		take(w.waiter, len(buf))
 	}
 	return w.writer.Write(buf)
 }
@@ -219,24 +251,21 @@ func (w *limitedWriter) Write(buf []byte) (int, error) {
 // take is a utility function to consume tokens from a overall rate.Limiter and deviceLimiter.
 // No call to WaitN can be larger than the limiter burst size so we split it up into
 // several calls when necessary.
-func take(overallLimiter, deviceLimiter *rate.Limiter, tokens int) {
+func take(waiter waiter, tokens int) {
 	if tokens < limiterBurstSize {
 		// This is the by far more common case so we get it out of the way
 		// early.
-		deviceLimiter.WaitN(context.TODO(), tokens)
-		overallLimiter.WaitN(context.TODO(), tokens)
+		waiter.WaitN(context.TODO(), tokens)
 		return
 	}
 
 	for tokens > 0 {
 		// Consume limiterBurstSize tokens at a time until we're done.
 		if tokens > limiterBurstSize {
-			deviceLimiter.WaitN(context.TODO(), limiterBurstSize)
-			overallLimiter.WaitN(context.TODO(), limiterBurstSize)
+			waiter.WaitN(context.TODO(), limiterBurstSize)
 			tokens -= limiterBurstSize
 		} else {
-			deviceLimiter.WaitN(context.TODO(), tokens)
-			overallLimiter.WaitN(context.TODO(), tokens)
+			waiter.WaitN(context.TODO(), tokens)
 			tokens = 0
 		}
 	}
@@ -256,25 +285,16 @@ func (b *atomicBool) get() bool {
 	return atomic.LoadInt32((*int32)(b)) != 0
 }
 
-// Utility functions for atomic operations on device limiters map
-func (lim *limiter) getWriteLimiterLocked(deviceID protocol.DeviceID) *rate.Limiter {
-	limiter, ok := lim.deviceWriteLimiters[deviceID]
-
-	if !ok {
-		limiter = rate.NewLimiter(rate.Inf, limiterBurstSize)
-		lim.deviceWriteLimiters[deviceID] = limiter
-	}
-
-	return limiter
-}
+// totalWaiter waits for all of the waiters
+type totalWaiter []waiter
 
-func (lim *limiter) getReadLimiterLocked(deviceID protocol.DeviceID) *rate.Limiter {
-	limiter, ok := lim.deviceReadLimiters[deviceID]
-
-	if !ok {
-		limiter = rate.NewLimiter(rate.Inf, limiterBurstSize)
-		lim.deviceReadLimiters[deviceID] = limiter
+func (tw totalWaiter) WaitN(ctx context.Context, n int) error {
+	for _, w := range tw {
+		if err := w.WaitN(ctx, n); err != nil {
+			// error here is context cancellation, most likely, so we abort
+			// early
+			return err
+		}
 	}
-
-	return limiter
+	return nil
 }