Browse Source

Fix possible memory leak (circular shared_ptr) in networking

Ivan Savenko 5 months ago
parent
commit
cd2837a84e

+ 5 - 5
lib/network/NetworkConnection.cpp

@@ -12,9 +12,9 @@
 
 VCMI_LIB_NAMESPACE_BEGIN
 
-NetworkConnection::NetworkConnection(INetworkConnectionListener & listener, const std::shared_ptr<NetworkSocket> & socket, const std::shared_ptr<NetworkContext> & context)
+NetworkConnection::NetworkConnection(INetworkConnectionListener & listener, const std::shared_ptr<NetworkSocket> & socket, NetworkContext & context)
 	: socket(socket)
-	, timer(std::make_shared<NetworkTimer>(*context))
+	, timer(std::make_shared<NetworkTimer>(context))
 	, listener(listener)
 {
 	socket->set_option(boost::asio::ip::tcp::no_delay(true));
@@ -208,7 +208,7 @@ void NetworkConnection::close()
 	//NOTE: ignoring error code, intended
 }
 
-InternalConnection::InternalConnection(INetworkConnectionListener & listener, const std::shared_ptr<NetworkContext> & context)
+InternalConnection::InternalConnection(INetworkConnectionListener & listener, NetworkContext & context)
 	: io(context)
 	, listener(listener)
 {
@@ -216,7 +216,7 @@ InternalConnection::InternalConnection(INetworkConnectionListener & listener, co
 
 void InternalConnection::receivePacket(const std::vector<std::byte> & message)
 {
-	boost::asio::post(*io, [self = std::static_pointer_cast<InternalConnection>(shared_from_this()), message](){
+	boost::asio::post(io, [self = std::static_pointer_cast<InternalConnection>(shared_from_this()), message](){
 		if (self->connectionActive)
 			self->listener.onPacketReceived(self, message);
 	});
@@ -224,7 +224,7 @@ void InternalConnection::receivePacket(const std::vector<std::byte> & message)
 
 void InternalConnection::disconnect()
 {
-	boost::asio::post(*io, [self = std::static_pointer_cast<InternalConnection>(shared_from_this())](){
+	boost::asio::post(io, [self = std::static_pointer_cast<InternalConnection>(shared_from_this())](){
 		self->listener.onDisconnected(self, "Internal connection has been terminated");
 		self->otherSideWeak.reset();
 		self->connectionActive = false;

+ 3 - 3
lib/network/NetworkConnection.h

@@ -38,7 +38,7 @@ class NetworkConnection final : public INetworkConnection, public std::enable_sh
 	void onDataSent(const boost::system::error_code & ec);
 
 public:
-	NetworkConnection(INetworkConnectionListener & listener, const std::shared_ptr<NetworkSocket> & socket, const std::shared_ptr<NetworkContext> & context);
+	NetworkConnection(INetworkConnectionListener & listener, const std::shared_ptr<NetworkSocket> & socket, NetworkContext & context);
 
 	void start();
 	void close() override;
@@ -49,11 +49,11 @@ public:
 class InternalConnection final : public IInternalConnection, public std::enable_shared_from_this<InternalConnection>
 {
 	std::weak_ptr<IInternalConnection> otherSideWeak;
-	std::shared_ptr<NetworkContext> io;
+	NetworkContext & io;
 	INetworkConnectionListener & listener;
 	bool connectionActive = false;
 public:
-	InternalConnection(INetworkConnectionListener & listener, const std::shared_ptr<NetworkContext> & context);
+	InternalConnection(INetworkConnectionListener & listener, NetworkContext & context);
 
 	void receivePacket(const std::vector<std::byte> & message) override;
 	void disconnect() override;

+ 4 - 4
lib/network/NetworkHandler.cpp

@@ -21,12 +21,12 @@ std::unique_ptr<INetworkHandler> INetworkHandler::createHandler()
 }
 
 NetworkHandler::NetworkHandler()
-	: io(std::make_shared<NetworkContext>())
+	: io(std::make_unique<NetworkContext>())
 {}
 
 std::unique_ptr<INetworkServer> NetworkHandler::createServerTCP(INetworkServerListener & listener)
 {
-	return std::make_unique<NetworkServer>(listener, io);
+	return std::make_unique<NetworkServer>(listener, *io);
 }
 
 void NetworkHandler::connectToRemote(INetworkClientListener & listener, const std::string & host, uint16_t port)
@@ -50,7 +50,7 @@ void NetworkHandler::connectToRemote(INetworkClientListener & listener, const st
 				listener.onConnectionFailed(error.message());
 				return;
 			}
-			auto connection = std::make_shared<NetworkConnection>(listener, socket, io);
+			auto connection = std::make_shared<NetworkConnection>(listener, socket, *io);
 			connection->start();
 
 			listener.onConnectionEstablished(connection);
@@ -75,7 +75,7 @@ void NetworkHandler::createTimer(INetworkTimerListener & listener, std::chrono::
 
 void NetworkHandler::createInternalConnection(INetworkClientListener & listener, INetworkServer & server)
 {
-	auto localConnection = std::make_shared<InternalConnection>(listener, io);
+	auto localConnection = std::make_shared<InternalConnection>(listener, *io);
 
 	server.receiveInternalConnection(localConnection);
 

+ 1 - 1
lib/network/NetworkHandler.h

@@ -15,7 +15,7 @@ VCMI_LIB_NAMESPACE_BEGIN
 
 class NetworkHandler : public INetworkHandler
 {
-	std::shared_ptr<NetworkContext> io;
+	std::unique_ptr<NetworkContext> io;
 
 public:
 	NetworkHandler();

+ 3 - 3
lib/network/NetworkServer.cpp

@@ -13,7 +13,7 @@
 
 VCMI_LIB_NAMESPACE_BEGIN
 
-NetworkServer::NetworkServer(INetworkServerListener & listener, const std::shared_ptr<NetworkContext> & context)
+NetworkServer::NetworkServer(INetworkServerListener & listener, NetworkContext & context)
 	: io(context)
 	, listener(listener)
 {
@@ -21,13 +21,13 @@ NetworkServer::NetworkServer(INetworkServerListener & listener, const std::share
 
 uint16_t NetworkServer::start(uint16_t port)
 {
-	acceptor = std::make_shared<NetworkAcceptor>(*io, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), port));
+	acceptor = std::make_shared<NetworkAcceptor>(io, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), port));
 	return startAsyncAccept();
 }
 
 uint16_t NetworkServer::startAsyncAccept()
 {
-	auto upcomingConnection = std::make_shared<NetworkSocket>(*io);
+	auto upcomingConnection = std::make_shared<NetworkSocket>(io);
 	acceptor->async_accept(*upcomingConnection, [this, upcomingConnection](const auto & ec) { connectionAccepted(upcomingConnection, ec); });
 	return acceptor->local_endpoint().port();
 }

+ 2 - 2
lib/network/NetworkServer.h

@@ -15,7 +15,7 @@ VCMI_LIB_NAMESPACE_BEGIN
 
 class NetworkServer : public INetworkConnectionListener, public INetworkServer
 {
-	std::shared_ptr<NetworkContext> io;
+	NetworkContext & io;
 	std::shared_ptr<NetworkAcceptor> acceptor;
 	std::set<std::shared_ptr<INetworkConnection>> connections;
 
@@ -27,7 +27,7 @@ class NetworkServer : public INetworkConnectionListener, public INetworkServer
 	void onDisconnected(const std::shared_ptr<INetworkConnection> & connection, const std::string & errorMessage) override;
 	void onPacketReceived(const std::shared_ptr<INetworkConnection> & connection, const std::vector<std::byte> & message) override;
 public:
-	NetworkServer(INetworkServerListener & listener, const std::shared_ptr<NetworkContext> & context);
+	NetworkServer(INetworkServerListener & listener, NetworkContext & context);
 
 	void receiveInternalConnection(std::shared_ptr<IInternalConnection> remoteConnection) override;