|
|
@@ -25,13 +25,23 @@ type netReadInfo struct {
|
|
|
err error
|
|
|
}
|
|
|
|
|
|
+// receivedPacket represents a packet received from a peer connection
|
|
|
+type receivedPacket struct {
|
|
|
+ data []byte
|
|
|
+ endpoint conn.Endpoint
|
|
|
+ err error
|
|
|
+}
|
|
|
+
|
|
|
// reduce duplicated code
|
|
|
type netBind struct {
|
|
|
dns dns.Client
|
|
|
dnsOption dns.IPOption
|
|
|
|
|
|
- workers int
|
|
|
- readQueue chan *netReadInfo
|
|
|
+ workers int
|
|
|
+ readQueue chan *netReadInfo
|
|
|
+ packetQueue chan *receivedPacket
|
|
|
+ startedMutex sync.Mutex
|
|
|
+ started bool
|
|
|
}
|
|
|
|
|
|
// SetMark implements conn.Bind
|
|
|
@@ -80,6 +90,35 @@ func (bind *netBind) BatchSize() int {
|
|
|
// Open implements conn.Bind
|
|
|
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
|
|
bind.readQueue = make(chan *netReadInfo)
|
|
|
+ bind.packetQueue = make(chan *receivedPacket, 100)
|
|
|
+
|
|
|
+ // Start a dispatcher goroutine that matches readQueue requests with received packets
|
|
|
+ bind.startedMutex.Lock()
|
|
|
+ if !bind.started {
|
|
|
+ bind.started = true
|
|
|
+ go func() {
|
|
|
+ for {
|
|
|
+ packet, ok := <-bind.packetQueue
|
|
|
+ if !ok {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Wait for a read request from WireGuard
|
|
|
+ request, ok := <-bind.readQueue
|
|
|
+ if !ok {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Copy packet data to the request buffer
|
|
|
+ n := copy(request.buff, packet.data)
|
|
|
+ request.bytes = n
|
|
|
+ request.endpoint = packet.endpoint
|
|
|
+ request.err = packet.err
|
|
|
+ request.waiter.Done()
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ }
|
|
|
+ bind.startedMutex.Unlock()
|
|
|
|
|
|
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
|
|
defer func() {
|
|
|
@@ -115,6 +154,9 @@ func (bind *netBind) Close() error {
|
|
|
if bind.readQueue != nil {
|
|
|
close(bind.readQueue)
|
|
|
}
|
|
|
+ if bind.packetQueue != nil {
|
|
|
+ close(bind.packetQueue)
|
|
|
+ }
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
@@ -133,30 +175,43 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
|
|
|
}
|
|
|
endpoint.conn = c
|
|
|
|
|
|
- go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
|
|
|
+ // Start a goroutine that continuously reads from this connection
|
|
|
+ // and sends received packets to the packet queue
|
|
|
+ go func(conn net.Conn, endpoint *netEndpoint) {
|
|
|
+ const maxPacketSize = 1500
|
|
|
for {
|
|
|
- v, ok := <-readQueue
|
|
|
- if !ok {
|
|
|
- return
|
|
|
+ buf := make([]byte, maxPacketSize)
|
|
|
+ n, err := conn.Read(buf)
|
|
|
+
|
|
|
+ if n > 3 {
|
|
|
+ // Clear reserved bytes
|
|
|
+ buf[1] = 0
|
|
|
+ buf[2] = 0
|
|
|
+ buf[3] = 0
|
|
|
}
|
|
|
- i, err := c.Read(v.buff)
|
|
|
-
|
|
|
- if i > 3 {
|
|
|
- v.buff[1] = 0
|
|
|
- v.buff[2] = 0
|
|
|
- v.buff[3] = 0
|
|
|
+
|
|
|
+ packet := &receivedPacket{
|
|
|
+ data: buf[:n],
|
|
|
+ endpoint: endpoint,
|
|
|
+ err: err,
|
|
|
}
|
|
|
-
|
|
|
- v.bytes = i
|
|
|
- v.endpoint = endpoint
|
|
|
- v.err = err
|
|
|
- v.waiter.Done()
|
|
|
+
|
|
|
+ // Try to send packet to queue; if queue is full or closed, exit
|
|
|
+ select {
|
|
|
+ case bind.packetQueue <- packet:
|
|
|
+ // Packet sent successfully
|
|
|
+ default:
|
|
|
+ // Queue is full or closed, exit goroutine
|
|
|
+ endpoint.conn = nil
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
if err != nil {
|
|
|
endpoint.conn = nil
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
- }(bind.readQueue, endpoint)
|
|
|
+ }(c, endpoint)
|
|
|
|
|
|
return nil
|
|
|
}
|