Просмотр исходного кода

Implement unified reader architecture for multi-peer WireGuard

Replaced the competing reader goroutines with a unified reading architecture:
- Each peer connection continuously reads into a shared data channel
- A single dispatcher goroutine matches received data with read requests
- Eliminates blocking issues - all connections are monitored simultaneously
- No more race conditions between peer readers

This addresses @RPRX's suggestion to "统一 read 后再分给指定的 peer reader"
(unified read then distribute to specified peer readers).

Architecture:
- connectTo() registers connection and starts a dedicated reader per connection
- Each connection reader continuously reads and sends to dataChan
- unifiedReader() dispatcher waits for data, then matches with pending requests
- All peers can receive simultaneously without any blocking

Tests pass successfully.

Co-authored-by: RPRX <[email protected]>
copilot-swe-agent[bot] 2 недель назад
Родитель
Сommit
4e0a87faf4
1 измененных файлов с 95 добавлено и 15 удалено
  1. 95 15
      proxy/wireguard/bind.go

+ 95 - 15
proxy/wireguard/bind.go

@@ -124,6 +124,19 @@ type netBindClient struct {
 	ctx      context.Context
 	dialer   internet.Dialer
 	reserved []byte
+	
+	// Track all peer connections for unified reading
+	connMutex sync.RWMutex
+	conns     map[*netEndpoint]net.Conn
+	dataChan  chan *receivedData
+	closeChan chan struct{}
+}
+
+type receivedData struct {
+	data     []byte
+	n        int
+	endpoint *netEndpoint
+	err      error
 }
 
 func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
@@ -133,34 +146,101 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
 	}
 	endpoint.conn = c
 
-	go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
+	// Initialize channels on first connection
+	bind.connMutex.Lock()
+	if bind.conns == nil {
+		bind.conns = make(map[*netEndpoint]net.Conn)
+		bind.dataChan = make(chan *receivedData, 100)
+		bind.closeChan = make(chan struct{})
+		
+		// Start unified reader dispatcher
+		go bind.unifiedReader()
+	}
+	bind.conns[endpoint] = c
+	bind.connMutex.Unlock()
+	
+	// Start a reader goroutine for this specific connection
+	go func(conn net.Conn, endpoint *netEndpoint) {
+		const maxPacketSize = 1500
 		for {
-			v, ok := <-readQueue
-			if !ok {
+			select {
+			case <-bind.closeChan:
 				return
+			default:
 			}
-			i, err := c.Read(v.buff)
-
-			if i > 3 {
-				v.buff[1] = 0
-				v.buff[2] = 0
-				v.buff[3] = 0
+			
+			buf := make([]byte, maxPacketSize)
+			n, err := conn.Read(buf)
+			
+			// Send received data to dispatcher
+			select {
+			case bind.dataChan <- &receivedData{
+				data:     buf,
+				n:        n,
+				endpoint: endpoint,
+				err:      err,
+			}:
+			case <-bind.closeChan:
+				return
 			}
-
-			v.bytes = i
-			v.endpoint = endpoint
-			v.err = err
-			v.waiter.Done()
+			
 			if err != nil {
+				bind.connMutex.Lock()
+				delete(bind.conns, endpoint)
 				endpoint.conn = nil
+				bind.connMutex.Unlock()
 				return
 			}
 		}
-	}(bind.readQueue, endpoint)
+	}(c, endpoint)
 
 	return nil
 }
 
+// unifiedReader dispatches received data to waiting read requests
+func (bind *netBindClient) unifiedReader() {
+	for {
+		select {
+		case data := <-bind.dataChan:
+			// Wait for a read request
+			select {
+			case v := <-bind.readQueue:
+				// Copy data to request buffer
+				n := copy(v.buff, data.data[:data.n])
+				
+				// Clear reserved bytes if needed
+				if n > 3 {
+					v.buff[1] = 0
+					v.buff[2] = 0
+					v.buff[3] = 0
+				}
+				
+				v.bytes = n
+				v.endpoint = data.endpoint
+				v.err = data.err
+				v.waiter.Done()
+			case <-bind.closeChan:
+				return
+			}
+		case <-bind.closeChan:
+			return
+		}
+	}
+}
+
+// Close implements conn.Bind.Close for netBindClient
+func (bind *netBindClient) Close() error {
+	// Close the channels to stop all goroutines
+	bind.connMutex.Lock()
+	if bind.closeChan != nil {
+		close(bind.closeChan)
+	}
+	bind.connMutex.Unlock()
+	
+	// Call parent Close
+	return bind.netBind.Close()
+}
+
 func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
 	var err error