浏览代码

lib/connections: Do not leak FDs, fix address copy (fixes #5767) (#5768)

* lib/connections: Do not leak FDs, fix address copy (fixes #5767)

* build

* Update quic_listen.go

* Update quic_listen.go
Audrius Butkevicius 6 年之前
父节点
当前提交
ee746263fb
共有 3 个文件被更改,包括 23 次插入10 次删除
  1. 10 7
      lib/connections/quic_dial.go
  2. 3 2
      lib/connections/quic_listen.go
  3. 10 1
      lib/connections/quic_misc.go

+ 10 - 7
lib/connections/quic_dial.go

@@ -45,23 +45,26 @@ func (d *quicDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, err
 	}
 
 	var conn net.PacketConn
-	closeConn := false
+	// We need to track who created the conn.
+	// Given we always pass the connection to quic, it assumes it's a remote connection it never closes it,
+	// So our wrapper around it needs to close it, but it only needs to close it if it's not the listening connection.
+	var createdConn net.PacketConn
 	if listenConn := registry.Get(uri.Scheme, packetConnLess); listenConn != nil {
 		conn = listenConn.(net.PacketConn)
 	} else {
 		if packetConn, err := net.ListenPacket("udp", ":0"); err != nil {
 			return internalConn{}, err
 		} else {
-			closeConn = true
 			conn = packetConn
+			createdConn = packetConn
 		}
 	}
 
 	ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
 	session, err := quic.DialContext(ctx, conn, addr, uri.Host, d.tlsCfg, quicConfig)
 	if err != nil {
-		if closeConn {
-			_ = conn.Close()
+		if createdConn != nil {
+			_ = createdConn.Close()
 		}
 		return internalConn{}, err
 	}
@@ -85,13 +88,13 @@ func (d *quicDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, err
 	if err != nil {
 		// It's ok to close these, this does not close the underlying packetConn.
 		_ = session.Close()
-		if closeConn {
-			_ = conn.Close()
+		if createdConn != nil {
+			_ = createdConn.Close()
 		}
 		return internalConn{}, err
 	}
 
-	return internalConn{&quicTlsConn{session, stream}, connTypeQUICClient, quicPriority}, nil
+	return internalConn{&quicTlsConn{session, stream, createdConn}, connTypeQUICClient, quicPriority}, nil
 }
 
 func (d *quicDialer) RedialFrequency() time.Duration {

+ 3 - 2
lib/connections/quic_listen.go

@@ -59,7 +59,8 @@ func (t *quicListener) OnNATTypeChanged(natType stun.NATType) {
 func (t *quicListener) OnExternalAddressChanged(address *stun.Host, via string) {
 	var uri *url.URL
 	if address != nil {
-		uri = &(*t.uri)
+		copy := *t.uri
+		uri = &copy
 		uri.Host = address.TransportAddr()
 	}
 
@@ -165,7 +166,7 @@ func (t *quicListener) Serve() {
 			continue
 		}
 
-		t.conns <- internalConn{&quicTlsConn{session, stream}, connTypeQUICServer, quicPriority}
+		t.conns <- internalConn{&quicTlsConn{session, stream, nil}, connTypeQUICServer, quicPriority}
 	}
 }
 

+ 10 - 1
lib/connections/quic_misc.go

@@ -24,15 +24,24 @@ var (
 type quicTlsConn struct {
 	quic.Session
 	quic.Stream
+	// If we created this connection, we should be the ones closing it.
+	createdConn net.PacketConn
 }
 
 func (q *quicTlsConn) Close() error {
 	sterr := q.Stream.Close()
 	seerr := q.Session.Close()
+	var pcerr error
+	if q.createdConn != nil {
+		pcerr = q.createdConn.Close()
+	}
 	if sterr != nil {
 		return sterr
 	}
-	return seerr
+	if seerr != nil {
+		return seerr
+	}
+	return pcerr
 }
 
 // Sort available packet connections by ip address, preferring unspecified local address.