Browse Source

Restore unified reader architecture per @RPRX request

Reverted the timeout-based solution and restored the unified reader architecture (a minimal sketch of the pattern follows below):
- Each peer connection continuously reads and queues incoming packets to dataChan
- A single unifiedReader() dispatcher matches queued data with read requests
- No blocking: all connections are monitored simultaneously
- Addresses @RPRX's request for a unified reader instead of timeouts

Architecture benefits:
- True concurrent reading from all peer connections
- Clean separation between reading and dispatching
- No timeout delays or retry loops
- Scalable to any number of peers
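
For reference, here is a minimal, self-contained sketch of the same producer/dispatcher pattern. The identifiers (packet, readReq, startReader, dispatch) are illustrative only, not the names used in bind.go; the buffer sizes simply mirror the ones in this change:

package main

import "net"

// packet is what each per-connection reader produces.
type packet struct {
	data []byte
	src  net.Conn
	err  error
}

// readReq is a pending read request waiting to be matched with data.
type readReq struct {
	buf  []byte
	done chan int // receives the number of bytes copied
}

// startReader launches one producer goroutine for a connection:
// it reads continuously and queues everything to the shared channel.
func startReader(c net.Conn, packets chan<- packet, closed <-chan struct{}) {
	go func() {
		for {
			buf := make([]byte, 1500)
			n, err := c.Read(buf)
			select {
			case packets <- packet{data: buf[:n], src: c, err: err}:
			case <-closed:
				return
			}
			if err != nil {
				return
			}
		}
	}()
}

// dispatch is the single unified reader: it blocks until both a packet
// and a read request are available, then pairs them. No deadlines, no
// polling, no per-connection blocking.
func dispatch(packets <-chan packet, requests <-chan readReq, closed <-chan struct{}) {
	for {
		select {
		case p := <-packets:
			select {
			case r := <-requests:
				r.done <- copy(r.buf, p.data)
			case <-closed:
				return
			}
		case <-closed:
			return
		}
	}
}

func main() {
	packets := make(chan packet, 100) // mirrors dataChannelBufferSize
	requests := make(chan readReq)
	closed := make(chan struct{})

	a, b := net.Pipe() // stand-in for one peer connection
	startReader(b, packets, closed)
	go dispatch(packets, requests, closed)

	go a.Write([]byte("hello")) // net.Pipe writes block until read

	req := readReq{buf: make([]byte, 1500), done: make(chan int)}
	requests <- req
	println("read", <-req.done, "bytes") // prints: read 5 bytes
	close(closed)
}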

Tests pass.

Co-authored-by: RPRX <[email protected]>
copilot-swe-agent[bot] 1 week ago
Parent
Commit
31d10f3544
2 changed files with 119 additions and 25 deletions
  1. proxy/wireguard/bind.go (+112 -24)
  2. proxy/wireguard/client.go (+7 -1)

proxy/wireguard/bind.go (+112 -24)

@@ -6,7 +6,6 @@ import (
 	"net/netip"
 	"strconv"
 	"sync"
-	"time"
 
 	"golang.zx2c4.com/wireguard/conn"
 
@@ -125,6 +124,26 @@ type netBindClient struct {
 	ctx      context.Context
 	dialer   internet.Dialer
 	reserved []byte
+	
+	// Track all peer connections for unified reading
+	connMutex sync.RWMutex
+	conns     map[*netEndpoint]net.Conn
+	dataChan  chan *receivedData
+	closeChan chan struct{}
+	closeOnce sync.Once
+}
+
+const (
+	// Buffer size for dataChan - allows some buffering of received packets
+	// while dispatcher matches them with read requests
+	dataChannelBufferSize = 100
+)
+
+type receivedData struct {
+	data     []byte
+	n        int
+	endpoint *netEndpoint
+	err      error
 }
 
 func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
@@ -134,45 +153,114 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
 	}
 	endpoint.conn = c
 
-	go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
+	// Initialize channels on first connection
+	bind.connMutex.Lock()
+	if bind.conns == nil {
+		bind.conns = make(map[*netEndpoint]net.Conn)
+		bind.dataChan = make(chan *receivedData, dataChannelBufferSize)
+		bind.closeChan = make(chan struct{})
+		
+		// Start unified reader dispatcher
+		go bind.unifiedReader()
+	}
+	bind.conns[endpoint] = c
+	bind.connMutex.Unlock()
+	
+	// Start a reader goroutine for this specific connection
+	go func(conn net.Conn, endpoint *netEndpoint) {
+		const maxPacketSize = 1500
 		for {
-			v, ok := <-readQueue
-			if !ok {
+			select {
+			case <-bind.closeChan:
 				return
+			default:
 			}
 			
-			// Set read deadline to prevent indefinite blocking
-			if conn, ok := c.(interface{ SetReadDeadline(time.Time) error }); ok {
-				conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
-			}
-			
-			i, err := c.Read(v.buff)
+			buf := make([]byte, maxPacketSize)
+			n, err := conn.Read(buf)
 			
-			// Clear read deadline
-			if conn, ok := c.(interface{ SetReadDeadline(time.Time) error }); ok {
-				conn.SetReadDeadline(time.Time{})
+			// Send only the valid data portion to dispatcher
+			dataToSend := buf
+			if n > 0 && n < len(buf) {
+				dataToSend = buf[:n]
 			}
-
-			if i > 3 {
-				v.buff[1] = 0
-				v.buff[2] = 0
-				v.buff[3] = 0
+			
+			// Send received data to dispatcher
+			select {
+			case bind.dataChan <- &receivedData{
+				data:     dataToSend,
+				n:        n,
+				endpoint: endpoint,
+				err:      err,
+			}:
+			case <-bind.closeChan:
+				return
 			}
-
-			v.bytes = i
-			v.endpoint = endpoint
-			v.err = err
-			v.waiter.Done()
+			
 			if err != nil {
+				bind.connMutex.Lock()
+				delete(bind.conns, endpoint)
 				endpoint.conn = nil
+				bind.connMutex.Unlock()
 				return
 			}
 		}
-	}(bind.readQueue, endpoint)
+	}(c, endpoint)
 
 	return nil
 }
 
+// unifiedReader dispatches received data to waiting read requests
+func (bind *netBindClient) unifiedReader() {
+	for {
+		select {
+		case data := <-bind.dataChan:
+			// Bounds check to prevent panic
+			if data.n > len(data.data) {
+				data.n = len(data.data)
+			}
+			
+			// Block until a read request arrives or the bind is closed
+			select {
+			case v := <-bind.readQueue:
+				// Copy data to request buffer
+				n := copy(v.buff, data.data[:data.n])
+				
+				// Clear reserved bytes if needed
+				if n > 3 {
+					v.buff[1] = 0
+					v.buff[2] = 0
+					v.buff[3] = 0
+				}
+				
+				v.bytes = n
+				v.endpoint = data.endpoint
+				v.err = data.err
+				v.waiter.Done()
+			case <-bind.closeChan:
+				return
+			}
+		case <-bind.closeChan:
+			return
+		}
+	}
+}
+
+// Close implements conn.Bind.Close for netBindClient
+func (bind *netBindClient) Close() error {
+	// Use sync.Once to prevent double-close panic
+	bind.closeOnce.Do(func() {
+		bind.connMutex.Lock()
+		if bind.closeChan != nil {
+			close(bind.closeChan)
+		}
+		bind.connMutex.Unlock()
+	})
+	
+	// Call parent Close
+	return bind.netBind.Close()
+}
+
 func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
 	var err error
 
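
Two implementation notes on the bind.go code above. First, dataChan is bounded (dataChannelBufferSize = 100), so once it fills, the per-connection readers block on the send until the dispatcher drains the channel or closeChan is closed; nothing is silently dropped. Second, the sync.Once guard in Close makes shutdown idempotent. Stripped to its essence (a standalone sketch, not code from this change):

package main

import "sync"

type closer struct {
	once sync.Once
	done chan struct{}
}

// Close may be called any number of times from any goroutine;
// the channel is closed exactly once and every later call is a no-op.
func (c *closer) Close() {
	c.once.Do(func() { close(c.done) })
}

func main() {
	c := &closer{done: make(chan struct{})}
	c.Close()
	c.Close() // safe: a bare close(c.done) would panic here
}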

proxy/wireguard/client.go (+7 -1)

@@ -114,6 +114,12 @@ func (h *Handler) processWireGuard(ctx context.Context, dialer internet.Dialer)
 	}
 
 	// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
+	// Set workers to number of peers if not explicitly configured
+	// This allows concurrent packet reception from multiple peers
+	workers := int(h.conf.NumWorkers)
+	if workers <= 0 && len(h.conf.Peers) > 0 {
+		workers = len(h.conf.Peers)
+	}
 	h.bind = &netBindClient{
 		netBind: netBind{
 			dns: h.dns,
@@ -121,7 +127,7 @@ func (h *Handler) processWireGuard(ctx context.Context, dialer internet.Dialer)
 				IPv4Enable: h.hasIPv4,
 				IPv6Enable: h.hasIPv6,
 			},
-			workers: int(h.conf.NumWorkers),
+			workers: workers,
 		},
 		ctx:      core.ToBackgroundDetachedContext(ctx),
 		dialer:   dialer,
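
The worker defaulting above, restated as a hypothetical standalone helper (effectiveWorkers is not a name from this change; it only spells out the logic):

// effectiveWorkers returns an explicitly configured worker count when
// set, and otherwise falls back to one receive worker per peer.
func effectiveWorkers(numWorkers int32, peers int) int {
	w := int(numWorkers)
	if w <= 0 && peers > 0 {
		w = peers
	}
	return w
}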