浏览代码

Add callClose to UDP Dispatcher

Fixes https://github.com/XTLS/Xray-core/issues/1611
RPRX 2 年之前
父节点
当前提交
9046eda5ce
共有 1 个文件被更改,包括 28 次插入11 次删除
  1. 28 11
      transport/internet/udp/dispatcher.go

+ 28 - 11
transport/internet/udp/dispatcher.go

@@ -31,6 +31,7 @@ type Dispatcher struct {
 	conns      map[net.Destination]*connEntry
 	dispatcher routing.Dispatcher
 	callback   ResponseCallback
+	callClose  func() error
 }
 
 func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
@@ -79,7 +80,7 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*
 		cancel: removeRay,
 	}
 	v.conns[dest] = entry
-	go handleInput(ctx, entry, dest, v.callback)
+	go handleInput(ctx, entry, dest, v.callback, v.callClose)
 	return entry, nil
 }
 
@@ -102,8 +103,13 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
 	}
 }
 
-func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback) {
-	defer conn.cancel()
+func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) {
+	defer func() {
+		conn.cancel()
+		if callClose != nil {
+			callClose()
+		}
+	}()
 
 	input := conn.link.Reader
 	timer := conn.timer
@@ -144,7 +150,12 @@ func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.Pac
 		done:  done.New(),
 	}
 
-	d := NewDispatcher(dispatcher, c.callback)
+	d := &Dispatcher{
+		conns:      make(map[net.Destination]*connEntry),
+		dispatcher: dispatcher,
+		callback:   c.callback,
+		callClose:  c.Close,
+	}
 	c.dispatcher = d
 	return c, nil
 }
@@ -162,16 +173,22 @@ func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) {
 }
 
 func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) {
+	var packet *udp.Packet
+s:
 	select {
 	case <-c.done.Wait():
-		return 0, nil, io.EOF
-	case packet := <-c.cache:
-		n := copy(p, packet.Payload.Bytes())
-		return n, &net.UDPAddr{
-			IP:   packet.Source.Address.IP(),
-			Port: int(packet.Source.Port),
-		}, nil
+		select {
+		case packet = <-c.cache:
+			break s
+		default:
+			return 0, nil, io.EOF
+		}
+	case packet = <-c.cache:
 	}
+	return copy(p, packet.Payload.Bytes()), &net.UDPAddr{
+		IP:   packet.Source.Address.IP(),
+		Port: int(packet.Source.Port),
+	}, nil
 }
 
 func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) {