Browse Source

fix(strelaysrv): make the session limiter session-dependent (fixes #10072) (#10073)

### Purpose

Make the session limiter only apply to current session.

### Testing

Relay 2 or more sessions and check if the sum of the connection speed
can exceed the specified per-session rate.

2 sessions (-global-rate=50000000 and -per-session-rate=6250000):


![图片](https://github.com/user-attachments/assets/133e531a-ed49-4890-aef7-821c628bcfc8)

1 session (-global-rate=50000000 and -per-session-rate=6250000):


![图片](https://github.com/user-attachments/assets/ac89ea53-2d8e-4347-9bbc-4780d85e38d7)
domain 6 months ago
parent
commit
0bf21d9db2
3 changed files with 8 additions and 6 deletions
  1. 1 1
      cmd/strelaysrv/listener.go
  2. 0 4
      cmd/strelaysrv/main.go
  3. 7 1
      cmd/strelaysrv/session.go

+ 1 - 1
cmd/strelaysrv/listener.go

@@ -184,7 +184,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config, token strin
 					continue
 				}
 				// requestedPeer is the server, id is the client
-				ses := newSession(requestedPeer, id, sessionLimiter, globalLimiter)
+				ses := newSession(requestedPeer, id, sessionLimitBps, globalLimiter)
 
 				go ses.Serve()
 

+ 0 - 4
cmd/strelaysrv/main.go

@@ -51,7 +51,6 @@ var (
 	globalLimitBps    int
 	overLimit         atomic.Bool
 	descriptorLimit   int64
-	sessionLimiter    *rate.Limiter
 	globalLimiter     *rate.Limiter
 	networkBufferSize int
 
@@ -228,9 +227,6 @@ func main() {
 		}
 	}
 
-	if sessionLimitBps > 0 {
-		sessionLimiter = rate.NewLimiter(rate.Limit(sessionLimitBps), 2*sessionLimitBps)
-	}
 	if globalLimitBps > 0 {
 		globalLimiter = rate.NewLimiter(rate.Limit(globalLimitBps), 2*globalLimitBps)
 	}

+ 7 - 1
cmd/strelaysrv/session.go

@@ -27,7 +27,7 @@ var (
 	bytesProxied    atomic.Int64
 )
 
-func newSession(serverid, clientid syncthingprotocol.DeviceID, sessionRateLimit, globalRateLimit *rate.Limiter) *session {
+func newSession(serverid, clientid syncthingprotocol.DeviceID, sessionLimitBps int, globalRateLimit *rate.Limiter) *session {
 	serverkey := make([]byte, 32)
 	_, err := rand.Read(serverkey)
 	if err != nil {
@@ -40,12 +40,17 @@ func newSession(serverid, clientid syncthingprotocol.DeviceID, sessionRateLimit,
 		return nil
 	}
 
+	var sessionRateLimit *rate.Limiter
+	if sessionLimitBps > 0 {
+		sessionRateLimit = rate.NewLimiter(rate.Limit(sessionLimitBps), 2*sessionLimitBps)
+	}
 	ses := &session{
 		serverkey: serverkey,
 		serverid:  serverid,
 		clientkey: clientkey,
 		clientid:  clientid,
 		rateLimit: makeRateLimitFunc(sessionRateLimit, globalRateLimit),
+		limiter:   sessionRateLimit,
 		connsChan: make(chan net.Conn),
 		conns:     make([]net.Conn, 0, 2),
 	}
@@ -109,6 +114,7 @@ type session struct {
 	clientid  syncthingprotocol.DeviceID
 
 	rateLimit func(bytes int)
+	limiter   *rate.Limiter
 
 	connsChan chan net.Conn
 	conns     []net.Conn