Преглед изворни кода

tstest/natlab: fix conn.Close race with conn.ReadFromUDPAddrPort (#16710)

If a conn.Close call raced conn.ReadFromUDPAddrPort before it could
"register" itself as an active read, the conn.ReadFromUDPAddrPort would
never return.

This commit replaces all the activeRead and breakActiveReads machinery
with a channel. These constructs were only depended upon by
SetReadDeadline, and SetReadDeadline was unused.

Updates #16707

Signed-off-by: Jordan Whited <[email protected]>
Jordan Whited пре 7 месеци
родитељ
комит
3d1e4f147a
1 измењених фајлова са 23 додато и 89 уклоњено
  1. 23 89
      tstest/natlab/natlab.go

+ 23 - 89
tstest/natlab/natlab.go

@@ -684,10 +684,11 @@ func (m *Machine) ListenPacket(ctx context.Context, network, address string) (ne
 	ipp := netip.AddrPortFrom(ip, port)
 
 	c := &conn{
-		m:   m,
-		fam: fam,
-		ipp: ipp,
-		in:  make(chan *Packet, 100), // arbitrary
+		m:        m,
+		fam:      fam,
+		ipp:      ipp,
+		closedCh: make(chan struct{}),
+		in:       make(chan *Packet, 100), // arbitrary
 	}
 	switch c.fam {
 	case 0:
@@ -716,70 +717,28 @@ type conn struct {
 	fam uint8 // 0, 4, or 6
 	ipp netip.AddrPort
 
-	mu           sync.Mutex
-	closed       bool
-	readDeadline time.Time
-	activeReads  map[*activeRead]bool
-	in           chan *Packet
-}
+	closeOnce sync.Once
+	closedCh  chan struct{} // closed by Close
 
-type activeRead struct {
-	cancel context.CancelFunc
-}
-
-// canRead reports whether we can do a read.
-func (c *conn) canRead() error {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	if c.closed {
-		return net.ErrClosed
-	}
-	if !c.readDeadline.IsZero() && c.readDeadline.Before(time.Now()) {
-		return errors.New("read deadline exceeded")
-	}
-	return nil
-}
-
-func (c *conn) registerActiveRead(ar *activeRead, active bool) {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	if c.activeReads == nil {
-		c.activeReads = make(map[*activeRead]bool)
-	}
-	if active {
-		c.activeReads[ar] = true
-	} else {
-		delete(c.activeReads, ar)
-	}
+	in chan *Packet
 }
 
 func (c *conn) Close() error {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	if c.closed {
-		return nil
-	}
-	c.closed = true
-	switch c.fam {
-	case 0:
-		c.m.unregisterConn4(c)
-		c.m.unregisterConn6(c)
-	case 4:
-		c.m.unregisterConn4(c)
-	case 6:
-		c.m.unregisterConn6(c)
-	}
-	c.breakActiveReadsLocked()
+	c.closeOnce.Do(func() {
+		switch c.fam {
+		case 0:
+			c.m.unregisterConn4(c)
+			c.m.unregisterConn6(c)
+		case 4:
+			c.m.unregisterConn4(c)
+		case 6:
+			c.m.unregisterConn6(c)
+		}
+		close(c.closedCh)
+	})
 	return nil
 }
 
-func (c *conn) breakActiveReadsLocked() {
-	for ar := range c.activeReads {
-		ar.cancel()
-	}
-	c.activeReads = nil
-}
-
 func (c *conn) LocalAddr() net.Addr {
 	return &net.UDPAddr{
 		IP:   c.ipp.Addr().AsSlice(),
@@ -809,25 +768,13 @@ func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
 }
 
 func (c *conn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) {
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	ar := &activeRead{cancel: cancel}
-
-	if err := c.canRead(); err != nil {
-		return 0, netip.AddrPort{}, err
-	}
-
-	c.registerActiveRead(ar, true)
-	defer c.registerActiveRead(ar, false)
-
 	select {
+	case <-c.closedCh:
+		return 0, netip.AddrPort{}, net.ErrClosed
 	case pkt := <-c.in:
 		n = copy(p, pkt.Payload)
 		pkt.Trace("PacketConn.ReadFrom")
 		return n, pkt.Src, nil
-	case <-ctx.Done():
-		return 0, netip.AddrPort{}, context.DeadlineExceeded
 	}
 }
 
@@ -857,18 +804,5 @@ func (c *conn) SetWriteDeadline(t time.Time) error {
 	panic("SetWriteDeadline unsupported; TODO when needed")
 }
 func (c *conn) SetReadDeadline(t time.Time) error {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-
-	now := time.Now()
-	if t.After(now) {
-		panic("SetReadDeadline in the future not yet supported; TODO?")
-	}
-
-	if !t.IsZero() && t.Before(now) {
-		c.breakActiveReadsLocked()
-	}
-	c.readDeadline = t
-
-	return nil
+	panic("SetReadDeadline unsupported; TODO when needed")
 }