Quellcode durchsuchen

Reverse: portal-worker should not be closed before making sure there is at least one other active worker (#4869)

patterniha vor 5 Monaten
Ursprung
Commit
b065595f58
3 geänderte Dateien mit 21 neuen und 9 gelöschten Zeilen
  1. 1 0
      app/reverse/config.go
  2. 12 6
      app/reverse/portal.go
  3. 8 3
      common/mux/client.go

+ 1 - 0
app/reverse/config.go

@@ -9,6 +9,7 @@ import (
 
 func (c *Control) FillInRandom() {
 	randomLength := dice.Roll(64)
+	randomLength++
 	c.Random = make([]byte, randomLength)
 	io.ReadFull(rand.Reader, c.Random)
 }

+ 12 - 6
app/reverse/portal.go

@@ -170,7 +170,7 @@ func (p *StaticMuxPicker) PickAvailable() (*mux.ClientWorker, error) {
 		if w.draining {
 			continue
 		}
-		if w.client.Closed() {
+		if w.IsFull() {
 			continue
 		}
 		if w.client.ActiveConnections() < minConn {
@@ -211,6 +211,7 @@ type PortalWorker struct {
 	writer   buf.Writer
 	reader   buf.Reader
 	draining bool
+	counter  uint32
 }
 
 func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
@@ -244,7 +245,7 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
 }
 
 func (w *PortalWorker) heartbeat() error {
-	if w.client.Closed() {
+	if w.Closed() {
 		return errors.New("client worker stopped")
 	}
 
@@ -260,16 +261,21 @@ func (w *PortalWorker) heartbeat() error {
 		msg.State = Control_DRAIN
 
 		defer func() {
+			w.client.GetTimer().Reset(time.Second * 16)
 			common.Close(w.writer)
 			common.Interrupt(w.reader)
 			w.writer = nil
 		}()
 	}
 
-	b, err := proto.Marshal(msg)
-	common.Must(err)
-	mb := buf.MergeBytes(nil, b)
-	return w.writer.WriteMultiBuffer(mb)
+	w.counter = (w.counter + 1) % 5
+	if w.draining || w.counter == 1 {
+		b, err := proto.Marshal(msg)
+		common.Must(err)
+		mb := buf.MergeBytes(nil, b)
+		return w.writer.WriteMultiBuffer(mb)
+	}
+	return nil
 }
 
 func (w *PortalWorker) IsFull() bool {

+ 8 - 3
common/mux/client.go

@@ -173,6 +173,7 @@ type ClientWorker struct {
 	sessionManager *SessionManager
 	link           transport.Link
 	done           *done.Instance
+	timer          *time.Ticker
 	strategy       ClientStrategy
 }
 
@@ -187,6 +188,7 @@ func NewClientWorker(stream transport.Link, s ClientStrategy) (*ClientWorker, er
 		sessionManager: NewSessionManager(),
 		link:           stream,
 		done:           done.New(),
+		timer:          time.NewTicker(time.Second * 16),
 		strategy:       s,
 	}
 
@@ -209,9 +211,12 @@ func (m *ClientWorker) Closed() bool {
 	return m.done.Done()
 }
 
+func (m *ClientWorker) GetTimer() *time.Ticker {
+	return m.timer
+}
+
 func (m *ClientWorker) monitor() {
-	timer := time.NewTicker(time.Second * 16)
-	defer timer.Stop()
+	defer m.timer.Stop()
 
 	for {
 		select {
@@ -220,7 +225,7 @@ func (m *ClientWorker) monitor() {
 			common.Close(m.link.Writer)
 			common.Interrupt(m.link.Reader)
 			return
-		case <-timer.C:
+		case <-m.timer.C:
 			size := m.sessionManager.Size()
 			if size == 0 && m.sessionManager.CloseIfNoSession() {
 				common.Must(m.done.Close())