|
|
@@ -6,7 +6,6 @@ import (
|
|
|
"net/netip"
|
|
|
"strconv"
|
|
|
"sync"
|
|
|
- "time"
|
|
|
|
|
|
"golang.zx2c4.com/wireguard/conn"
|
|
|
|
|
|
@@ -125,6 +124,26 @@ type netBindClient struct {
|
|
|
ctx context.Context
|
|
|
dialer internet.Dialer
|
|
|
reserved []byte
|
|
|
+
|
|
|
+ // Track all peer connections for unified reading
|
|
|
+ connMutex sync.RWMutex
|
|
|
+ conns map[*netEndpoint]net.Conn
|
|
|
+ dataChan chan *receivedData
|
|
|
+ closeChan chan struct{}
|
|
|
+ closeOnce sync.Once
|
|
|
+}
|
|
|
+
|
|
|
+const (
|
|
|
+ // Buffer size for dataChan - allows some buffering of received packets
|
|
|
+ // while dispatcher matches them with read requests
|
|
|
+ dataChannelBufferSize = 100
|
|
|
+)
|
|
|
+
|
|
|
+type receivedData struct {
|
|
|
+ data []byte
|
|
|
+ n int
|
|
|
+ endpoint *netEndpoint
|
|
|
+ err error
|
|
|
}
|
|
|
|
|
|
func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
|
|
|
@@ -134,45 +153,114 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
|
|
|
}
|
|
|
endpoint.conn = c
|
|
|
|
|
|
- go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
|
|
|
+ // Initialize channels on first connection
|
|
|
+ bind.connMutex.Lock()
|
|
|
+ if bind.conns == nil {
|
|
|
+ bind.conns = make(map[*netEndpoint]net.Conn)
|
|
|
+ bind.dataChan = make(chan *receivedData, dataChannelBufferSize)
|
|
|
+ bind.closeChan = make(chan struct{})
|
|
|
+
|
|
|
+ // Start unified reader dispatcher
|
|
|
+ go bind.unifiedReader()
|
|
|
+ }
|
|
|
+ bind.conns[endpoint] = c
|
|
|
+ bind.connMutex.Unlock()
|
|
|
+
|
|
|
+ // Start a reader goroutine for this specific connection
|
|
|
+ go func(conn net.Conn, endpoint *netEndpoint) {
|
|
|
+ const maxPacketSize = 1500
|
|
|
for {
|
|
|
- v, ok := <-readQueue
|
|
|
- if !ok {
|
|
|
+ select {
|
|
|
+ case <-bind.closeChan:
|
|
|
return
|
|
|
+ default:
|
|
|
}
|
|
|
|
|
|
- // Set read deadline to prevent indefinite blocking
|
|
|
- if conn, ok := c.(interface{ SetReadDeadline(time.Time) error }); ok {
|
|
|
- conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
|
|
|
- }
|
|
|
-
|
|
|
- i, err := c.Read(v.buff)
|
|
|
+ buf := make([]byte, maxPacketSize)
|
|
|
+ n, err := conn.Read(buf)
|
|
|
|
|
|
- // Clear read deadline
|
|
|
- if conn, ok := c.(interface{ SetReadDeadline(time.Time) error }); ok {
|
|
|
- conn.SetReadDeadline(time.Time{})
|
|
|
+ // Send only the valid data portion to dispatcher
|
|
|
+ dataToSend := buf
|
|
|
+ if n > 0 && n < len(buf) {
|
|
|
+ dataToSend = buf[:n]
|
|
|
}
|
|
|
-
|
|
|
- if i > 3 {
|
|
|
- v.buff[1] = 0
|
|
|
- v.buff[2] = 0
|
|
|
- v.buff[3] = 0
|
|
|
+
|
|
|
+ // Send received data to dispatcher
|
|
|
+ select {
|
|
|
+ case bind.dataChan <- &receivedData{
|
|
|
+ data: dataToSend,
|
|
|
+ n: n,
|
|
|
+ endpoint: endpoint,
|
|
|
+ err: err,
|
|
|
+ }:
|
|
|
+ case <-bind.closeChan:
|
|
|
+ return
|
|
|
}
|
|
|
-
|
|
|
- v.bytes = i
|
|
|
- v.endpoint = endpoint
|
|
|
- v.err = err
|
|
|
- v.waiter.Done()
|
|
|
+
|
|
|
if err != nil {
|
|
|
+ bind.connMutex.Lock()
|
|
|
+ delete(bind.conns, endpoint)
|
|
|
endpoint.conn = nil
|
|
|
+ bind.connMutex.Unlock()
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
- }(bind.readQueue, endpoint)
|
|
|
+ }(c, endpoint)
|
|
|
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+// unifiedReader dispatches received data to waiting read requests
|
|
|
+func (bind *netBindClient) unifiedReader() {
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case data := <-bind.dataChan:
|
|
|
+ // Bounds check to prevent panic
|
|
|
+ if data.n > len(data.data) {
|
|
|
+ data.n = len(data.data)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Wait for a read request with timeout to prevent blocking forever
|
|
|
+ select {
|
|
|
+ case v := <-bind.readQueue:
|
|
|
+ // Copy data to request buffer
|
|
|
+ n := copy(v.buff, data.data[:data.n])
|
|
|
+
|
|
|
+ // Clear reserved bytes if needed
|
|
|
+ if n > 3 {
|
|
|
+ v.buff[1] = 0
|
|
|
+ v.buff[2] = 0
|
|
|
+ v.buff[3] = 0
|
|
|
+ }
|
|
|
+
|
|
|
+ v.bytes = n
|
|
|
+ v.endpoint = data.endpoint
|
|
|
+ v.err = data.err
|
|
|
+ v.waiter.Done()
|
|
|
+ case <-bind.closeChan:
|
|
|
+ return
|
|
|
+ }
|
|
|
+ case <-bind.closeChan:
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Close implements conn.Bind.Close for netBindClient
|
|
|
+func (bind *netBindClient) Close() error {
|
|
|
+ // Use sync.Once to prevent double-close panic
|
|
|
+ bind.closeOnce.Do(func() {
|
|
|
+ bind.connMutex.Lock()
|
|
|
+ if bind.closeChan != nil {
|
|
|
+ close(bind.closeChan)
|
|
|
+ }
|
|
|
+ bind.connMutex.Unlock()
|
|
|
+ })
|
|
|
+
|
|
|
+ // Call parent Close
|
|
|
+ return bind.netBind.Close()
|
|
|
+}
|
|
|
+
|
|
|
func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
|
|
|
var err error
|
|
|
|