Browse Source

Fix wireguard close

世界 2 years ago
parent
commit
0ad1bbea11
1 changed files with 20 additions and 10 deletions
  1. 20 10
      transport/wireguard/device_stack.go

+ 20 - 10
transport/wireguard/device_stack.go

@@ -34,6 +34,7 @@ type StackDevice struct {
 	mtu        uint32
 	mtu        uint32
 	events     chan tun.Event
 	events     chan tun.Event
 	outbound   chan *stack.PacketBuffer
 	outbound   chan *stack.PacketBuffer
+	done       chan struct{}
 	dispatcher stack.NetworkDispatcher
 	dispatcher stack.NetworkDispatcher
 	addr4      tcpip.Address
 	addr4      tcpip.Address
 	addr6      tcpip.Address
 	addr6      tcpip.Address
@@ -50,6 +51,7 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er
 		mtu:      mtu,
 		mtu:      mtu,
 		events:   make(chan tun.Event, 1),
 		events:   make(chan tun.Event, 1),
 		outbound: make(chan *stack.PacketBuffer, 256),
 		outbound: make(chan *stack.PacketBuffer, 256),
+		done:     make(chan struct{}),
 	}
 	}
 	err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
 	err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
 	if err != nil {
 	if err != nil {
@@ -140,16 +142,20 @@ func (w *StackDevice) File() *os.File {
 }
 }
 
 
 func (w *StackDevice) Read(p []byte, offset int) (n int, err error) {
 func (w *StackDevice) Read(p []byte, offset int) (n int, err error) {
-	packetBuffer, ok := <-w.outbound
-	if !ok {
+	select {
+	case packetBuffer, ok := <-w.outbound:
+		if !ok {
+			return 0, os.ErrClosed
+		}
+		defer packetBuffer.DecRef()
+		p = p[offset:]
+		for _, slice := range packetBuffer.AsSlices() {
+			n += copy(p[n:], slice)
+		}
+		return
+	case <-w.done:
 		return 0, os.ErrClosed
 		return 0, os.ErrClosed
 	}
 	}
-	defer packetBuffer.DecRef()
-	p = p[offset:]
-	for _, slice := range packetBuffer.AsSlices() {
-		n += copy(p[n:], slice)
-	}
-	return
 }
 }
 
 
 func (w *StackDevice) Write(p []byte, offset int) (n int, err error) {
 func (w *StackDevice) Write(p []byte, offset int) (n int, err error) {
@@ -201,7 +207,7 @@ func (w *StackDevice) Close() error {
 		endpoint.Abort()
 		endpoint.Abort()
 	}
 	}
 	w.stack.Wait()
 	w.stack.Wait()
-	close(w.outbound)
+	close(w.done)
 	return nil
 	return nil
 }
 }
 
 
@@ -246,7 +252,11 @@ func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) {
 func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
 func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
 	for _, packetBuffer := range list.AsSlice() {
 	for _, packetBuffer := range list.AsSlice() {
 		packetBuffer.IncRef()
 		packetBuffer.IncRef()
-		ep.outbound <- packetBuffer
+		select {
+		case <-ep.done:
+			return 0, &tcpip.ErrClosedForSend{}
+		case ep.outbound <- packetBuffer:
+		}
 	}
 	}
 	return list.Len(), nil
 	return list.Len(), nil
 }
 }