浏览代码

Better handling of disconnects, code cleanup

Ivan Savenko 1 年之前
父节点
当前提交
f97ffd8e9a

+ 13 - 4
lib/network/NetworkConnection.cpp

@@ -26,7 +26,7 @@ void NetworkConnection::start()
 	boost::asio::async_read(*socket,
 							readBuffer,
 							boost::asio::transfer_exactly(messageHeaderSize),
-							[this](const auto & ec, const auto & endpoint) { onHeaderReceived(ec); });
+							[self = shared_from_this()](const auto & ec, const auto & endpoint) { self->onHeaderReceived(ec); });
 }
 
 void NetworkConnection::onHeaderReceived(const boost::system::error_code & ec)
@@ -42,7 +42,7 @@ void NetworkConnection::onHeaderReceived(const boost::system::error_code & ec)
 	boost::asio::async_read(*socket,
 							readBuffer,
 							boost::asio::transfer_exactly(messageSize),
-							[this, messageSize](const auto & ec, const auto & endpoint) { onPacketReceived(ec, messageSize); });
+							[self = shared_from_this(), messageSize](const auto & ec, const auto & endpoint) { self->onPacketReceived(ec, messageSize); });
 }
 
 uint32_t NetworkConnection::readPacketSize()
@@ -54,7 +54,7 @@ uint32_t NetworkConnection::readPacketSize()
 	readBuffer.sgetn(reinterpret_cast<char *>(&messageSize), sizeof(messageSize));
 
 	if (messageSize > messageMaxSize)
-		throw std::runtime_error("Invalid packet size!");
+		listener.onDisconnected(shared_from_this(), "Invalid packet size!");
 
 	return messageSize;
 }
@@ -88,7 +88,16 @@ void NetworkConnection::sendPacket(const std::vector<std::byte> & message)
 	boost::asio::write(*socket, boost::asio::buffer(messageSize), ec );
 	boost::asio::write(*socket, boost::asio::buffer(message), ec );
 
-	// FIXME: handle error?
+	if (ec)
+		listener.onDisconnected(shared_from_this(), ec.message());
+}
+
+void NetworkConnection::close()
+{
+	boost::system::error_code ec;
+	socket->close(ec);
+
+	//NOTE: ignoring error code
 }
 
 VCMI_LIB_NAMESPACE_END

+ 2 - 1
lib/network/NetworkConnection.h

@@ -13,7 +13,7 @@
 
 VCMI_LIB_NAMESPACE_BEGIN
 
-class NetworkConnection : public INetworkConnection, public std::enable_shared_from_this<NetworkConnection>
+class NetworkConnection : public INetworkConnection, std::enable_shared_from_this<NetworkConnection>
 {
 	static const int messageHeaderSize = sizeof(uint32_t);
 	static const int messageMaxSize = 64 * 1024 * 1024; // arbitrary size to prevent potential massive allocation if we receive garbage input
@@ -31,6 +31,7 @@ public:
 	NetworkConnection(INetworkConnectionListener & listener, const std::shared_ptr<NetworkSocket> & socket);
 
 	void start();
+	void close() override;
 	void sendPacket(const std::vector<std::byte> & message) override;
 };
 

+ 1 - 2
lib/network/NetworkInterface.h

@@ -17,6 +17,7 @@ class DLL_LINKAGE INetworkConnection : boost::noncopyable
 public:
 	virtual ~INetworkConnection() = default;
 	virtual void sendPacket(const std::vector<std::byte> & message) = 0;
+	virtual void close() = 0;
 };
 
 using NetworkConnectionPtr = std::shared_ptr<INetworkConnection>;
@@ -38,8 +39,6 @@ class DLL_LINKAGE INetworkServer : boost::noncopyable
 public:
 	virtual ~INetworkServer() = default;
 
-	virtual void sendPacket(const std::shared_ptr<INetworkConnection> &, const std::vector<std::byte> & message) = 0;
-	virtual void closeConnection(const std::shared_ptr<INetworkConnection> &) = 0;
 	virtual void start(uint16_t port) = 0;
 };
 

+ 3 - 18
lib/network/NetworkServer.cpp

@@ -28,7 +28,7 @@ void NetworkServer::start(uint16_t port)
 void NetworkServer::startAsyncAccept()
 {
 	auto upcomingConnection = std::make_shared<NetworkSocket>(*io);
-	acceptor->async_accept(*upcomingConnection, std::bind(&NetworkServer::connectionAccepted, this, upcomingConnection, _1));
+	acceptor->async_accept(*upcomingConnection, [this, upcomingConnection](const auto & ec) { connectionAccepted(upcomingConnection, ec); });
 }
 
 void NetworkServer::connectionAccepted(std::shared_ptr<NetworkSocket> upcomingConnection, const boost::system::error_code & ec)
@@ -46,27 +46,12 @@ void NetworkServer::connectionAccepted(std::shared_ptr<NetworkSocket> upcomingCo
 	startAsyncAccept();
 }
 
-void NetworkServer::sendPacket(const std::shared_ptr<INetworkConnection> & connection, const std::vector<std::byte> & message)
-{
-	connection->sendPacket(message);
-}
-
-void NetworkServer::closeConnection(const std::shared_ptr<INetworkConnection> & connection)
-{
-	logNetwork->info("Closing connection!");
-	assert(connections.count(connection));
-	connections.erase(connection);
-}
-
 void NetworkServer::onDisconnected(const std::shared_ptr<INetworkConnection> & connection, const std::string & errorMessage)
 {
 	logNetwork->info("Connection lost! Reason: %s", errorMessage);
 	assert(connections.count(connection));
-	if (connections.count(connection)) // how? Connection was explicitly closed before?
-	{
-		connections.erase(connection);
-		listener.onDisconnected(connection, errorMessage);
-	}
+	connections.erase(connection);
+	listener.onDisconnected(connection, errorMessage);
 }
 
 void NetworkServer::onPacketReceived(const std::shared_ptr<INetworkConnection> & connection, const std::vector<std::byte> & message)

+ 0 - 3
lib/network/NetworkServer.h

@@ -29,9 +29,6 @@ class NetworkServer : public INetworkConnectionListener, public INetworkServer
 public:
 	NetworkServer(INetworkServerListener & listener, const std::shared_ptr<NetworkContext> & context);
 
-	void sendPacket(const std::shared_ptr<INetworkConnection> &, const std::vector<std::byte> & message) override;
-	void closeConnection(const std::shared_ptr<INetworkConnection> &) override;
-
 	void start(uint16_t port) override;
 };
 

+ 2 - 2
lib/serializer/CMemorySerializer.cpp

@@ -17,7 +17,7 @@ int CMemorySerializer::read(std::byte * data, unsigned size)
 	if(buffer.size() < readPos + size)
 		throw std::runtime_error(boost::str(boost::format("Cannot read past the buffer (accessing index %d, while size is %d)!") % (readPos + size - 1) % buffer.size()));
 
-	std::memcpy(data, buffer.data() + readPos, size);
+	std::copy_n(buffer.data() + readPos, size, data);
 	readPos += size;
 	return size;
 }
@@ -26,7 +26,7 @@ int CMemorySerializer::write(const std::byte * data, unsigned size)
 {
 	auto oldSize = buffer.size(); //and the pos to write from
 	buffer.resize(oldSize + size);
-	std::memcpy(buffer.data() + oldSize, data, size);
+	std::copy_n(data, size, buffer.data() + oldSize);
 	return size;
 }
 

+ 1 - 1
lib/serializer/CMemorySerializer.h

@@ -18,7 +18,7 @@ VCMI_LIB_NAMESPACE_BEGIN
 class DLL_LINKAGE CMemorySerializer
 	: public IBinaryReader, public IBinaryWriter
 {
-	std::vector<ui8> buffer;
+	std::vector<std::byte> buffer;
 
 	size_t readPos; //index of the next byte to be read
 public:

+ 12 - 24
lobby/LobbyDatabase.cpp

@@ -149,12 +149,6 @@ void LobbyDatabase::prepareStatements()
 		WHERE roomID = ?
 	)";
 
-	static const std::string setGameRoomPlayerLimitText = R"(
-		UPDATE gameRooms
-		SET playerLimit = ?
-		WHERE roomID = ?
-	)";
-
 	// SELECT FROM
 
 	static const std::string getRecentMessageHistoryText = R"(
@@ -221,7 +215,7 @@ void LobbyDatabase::prepareStatements()
 	static const std::string isAccountCookieValidText = R"(
 		SELECT COUNT(accountID)
 		FROM accountCookies
-		WHERE accountID = ? AND cookieUUID = ? AND strftime('%s',CURRENT_TIMESTAMP)- strftime('%s',creationTime) < ?
+		WHERE accountID = ? AND cookieUUID = ?
 	)";
 
 	static const std::string isGameRoomCookieValidText = R"(
@@ -269,7 +263,6 @@ void LobbyDatabase::prepareStatements()
 
 	setAccountOnlineStatement = database->prepare(setAccountOnlineText);
 	setGameRoomStatusStatement = database->prepare(setGameRoomStatusText);
-	setGameRoomPlayerLimitStatement = database->prepare(setGameRoomPlayerLimitText);
 
 	getRecentMessageHistoryStatement = database->prepare(getRecentMessageHistoryText);
 	getIdleGameRoomStatement = database->prepare(getIdleGameRoomText);
@@ -352,11 +345,6 @@ void LobbyDatabase::setGameRoomStatus(const std::string & roomID, LobbyRoomState
 	setGameRoomStatusStatement->executeOnce(vstd::to_underlying(roomStatus), roomID);
 }
 
-void LobbyDatabase::setGameRoomPlayerLimit(const std::string & roomID, uint32_t playerLimit)
-{
-	setGameRoomPlayerLimitStatement->executeOnce(playerLimit, roomID);
-}
-
 void LobbyDatabase::insertPlayerIntoGameRoom(const std::string & accountID, const std::string & roomID)
 {
 	insertGameRoomPlayersStatement->executeOnce(roomID, accountID);
@@ -392,11 +380,10 @@ void LobbyDatabase::insertAccessCookie(const std::string & accountID, const std:
 	insertAccessCookieStatement->executeOnce(accountID, accessCookieUUID);
 }
 
-void LobbyDatabase::updateAccessCookie(const std::string & accountID, const std::string & accessCookieUUID) {}
-
-void LobbyDatabase::updateAccountLoginTime(const std::string & accountID) {}
-
-void LobbyDatabase::updateActiveAccount(const std::string & accountID, bool isActive) {}
+void LobbyDatabase::updateAccountLoginTime(const std::string & accountID)
+{
+	assert(0);
+}
 
 std::string LobbyDatabase::getAccountDisplayName(const std::string & accountID)
 {
@@ -410,16 +397,16 @@ std::string LobbyDatabase::getAccountDisplayName(const std::string & accountID)
 	return result;
 }
 
-LobbyCookieStatus LobbyDatabase::getGameRoomCookieStatus(const std::string & accountID, const std::string & accessCookieUUID, std::chrono::seconds cookieLifetime)
-{
-	return {};
-}
+//LobbyCookieStatus LobbyDatabase::getGameRoomCookieStatus(const std::string & accountID, const std::string & accessCookieUUID)
+//{
+//	return {};
+//}
 
-LobbyCookieStatus LobbyDatabase::getAccountCookieStatus(const std::string & accountID, const std::string & accessCookieUUID, std::chrono::seconds cookieLifetime)
+LobbyCookieStatus LobbyDatabase::getAccountCookieStatus(const std::string & accountID, const std::string & accessCookieUUID)
 {
 	bool result = false;
 
-	isAccountCookieValidStatement->setBinds(accountID, accessCookieUUID, cookieLifetime.count());
+	isAccountCookieValidStatement->setBinds(accountID, accessCookieUUID);
 	if(isAccountCookieValidStatement->execute())
 		isAccountCookieValidStatement->getColumns(result);
 	isAccountCookieValidStatement->reset();
@@ -429,6 +416,7 @@ LobbyCookieStatus LobbyDatabase::getAccountCookieStatus(const std::string & acco
 
 LobbyInviteStatus LobbyDatabase::getAccountInviteStatus(const std::string & accountID, const std::string & roomID)
 {
+	assert(0);
 	return {};
 }
 

+ 3 - 6
lobby/LobbyDatabase.h

@@ -62,7 +62,6 @@ public:
 
 	void setAccountOnline(const std::string & accountID, bool isOnline);
 	void setGameRoomStatus(const std::string & roomID, LobbyRoomState roomStatus);
-	void setGameRoomPlayerLimit(const std::string & roomID, uint32_t playerLimit);
 
 	void insertPlayerIntoGameRoom(const std::string & accountID, const std::string & roomID);
 	void deletePlayerFromGameRoom(const std::string & accountID, const std::string & roomID);
@@ -75,21 +74,19 @@ public:
 	void insertAccessCookie(const std::string & accountID, const std::string & accessCookieUUID);
 	void insertChatMessage(const std::string & sender, const std::string & roomType, const std::string & roomID, const std::string & messageText);
 
-	void updateAccessCookie(const std::string & accountID, const std::string & accessCookieUUID);
 	void updateAccountLoginTime(const std::string & accountID);
-	void updateActiveAccount(const std::string & accountID, bool isActive);
 
 	std::vector<LobbyGameRoom> getActiveGameRooms();
 	std::vector<LobbyAccount> getActiveAccounts();
-	std::vector<LobbyAccount> getAccountsInRoom(const std::string & roomID);
+//	std::vector<LobbyAccount> getAccountsInRoom(const std::string & roomID);
 	std::vector<LobbyChatMessage> getRecentMessageHistory();
 
 	std::string getIdleGameRoom(const std::string & hostAccountID);
 	std::string getAccountGameRoom(const std::string & accountID);
 	std::string getAccountDisplayName(const std::string & accountID);
 
-	LobbyCookieStatus getGameRoomCookieStatus(const std::string & accountID, const std::string & accessCookieUUID, std::chrono::seconds cookieLifetime);
-	LobbyCookieStatus getAccountCookieStatus(const std::string & accountID, const std::string & accessCookieUUID, std::chrono::seconds cookieLifetime);
+//	LobbyCookieStatus getGameRoomCookieStatus(const std::string & accountID, const std::string & accessCookieUUID);
+	LobbyCookieStatus getAccountCookieStatus(const std::string & accountID, const std::string & accessCookieUUID);
 	LobbyInviteStatus getAccountInviteStatus(const std::string & accountID, const std::string & roomID);
 	LobbyRoomState getGameRoomStatus(const std::string & roomID);
 	uint32_t getGameRoomFreeSlots(const std::string & roomID);

+ 0 - 1
lobby/LobbyDefines.h

@@ -36,7 +36,6 @@ struct LobbyChatMessage
 enum class LobbyCookieStatus : int32_t
 {
 	INVALID,
-	EXPIRED,
 	VALID
 };
 

+ 26 - 24
lobby/LobbyServer.cpp

@@ -17,8 +17,6 @@
 #include <boost/uuid/uuid_generators.hpp>
 #include <boost/uuid/uuid_io.hpp>
 
-static const auto accountCookieLifetime = std::chrono::hours(24 * 7);
-
 bool LobbyServer::isAccountNameValid(const std::string & accountName) const
 {
 	if(accountName.size() < 4)
@@ -60,7 +58,7 @@ NetworkConnectionPtr LobbyServer::findGameRoom(const std::string & gameRoomID) c
 
 void LobbyServer::sendMessage(const NetworkConnectionPtr & target, const JsonNode & json)
 {
-	networkServer->sendPacket(target, json.toBytes(true));
+	target->sendPacket(json.toBytes(true));
 }
 
 void LobbyServer::sendAccountCreated(const NetworkConnectionPtr & target, const std::string & accountID, const std::string & accountCookie)
@@ -206,16 +204,27 @@ void LobbyServer::onNewConnection(const NetworkConnectionPtr & connection)
 void LobbyServer::onDisconnected(const NetworkConnectionPtr & connection, const std::string & errorMessage)
 {
 	if(activeAccounts.count(connection))
+	{
 		database->setAccountOnline(activeAccounts.at(connection), false);
+		activeAccounts.erase(connection);
+	}
 
 	if(activeGameRooms.count(connection))
+	{
 		database->setGameRoomStatus(activeGameRooms.at(connection), LobbyRoomState::CLOSED);
+		activeGameRooms.erase(connection);
+	}
+
+	if(activeProxies.count(connection))
+	{
+		auto & otherConnection = activeProxies.at(connection);
+
+		if (otherConnection)
+			otherConnection->close();
 
-	// NOTE: lost connection can be in only one of these lists (or in none of them)
-	// calling on all possible containers since calling std::map::erase() with non-existing key is legal
-	activeAccounts.erase(connection);
-	activeProxies.erase(connection);
-	activeGameRooms.erase(connection);
+		activeProxies.erase(connection);
+		activeProxies.erase(otherConnection);
+	}
 
 	broadcastActiveAccounts();
 	broadcastActiveGameRooms();
@@ -226,7 +235,7 @@ void LobbyServer::onPacketReceived(const NetworkConnectionPtr & connection, cons
 	// proxy connection - no processing, only redirect
 	if(activeProxies.count(connection))
 	{
-		auto lockedPtr = activeProxies.at(connection).lock();
+		auto lockedPtr = activeProxies.at(connection);
 		if(lockedPtr)
 			return lockedPtr->sendPacket(message);
 
@@ -296,9 +305,7 @@ void LobbyServer::onPacketReceived(const NetworkConnectionPtr & connection, cons
 	if(messageType == "serverProxyLogin")
 		return receiveServerProxyLogin(connection, json);
 
-	// TODO: add logging of suspicious connections.
-	networkServer->closeConnection(connection);
-
+	connection->close();
 	logGlobal->info("(unauthorised): Unknown message type %s", messageType);
 }
 
@@ -348,13 +355,11 @@ void LobbyServer::receiveClientLogin(const NetworkConnectionPtr & connection, co
 	if(!database->isAccountIDExists(accountID))
 		return sendOperationFailed(connection, "Account not found");
 
-	auto clientCookieStatus = database->getAccountCookieStatus(accountID, accountCookie, accountCookieLifetime);
+	auto clientCookieStatus = database->getAccountCookieStatus(accountID, accountCookie);
 
 	if(clientCookieStatus == LobbyCookieStatus::INVALID)
 		return sendOperationFailed(connection, "Authentification failure");
 
-	// prolong existing cookie
-	database->updateAccessCookie(accountID, accountCookie);
 	database->updateAccountLoginTime(accountID);
 	database->setAccountOnline(accountID, true);
 
@@ -365,8 +370,8 @@ void LobbyServer::receiveClientLogin(const NetworkConnectionPtr & connection, co
 	sendLoginSuccess(connection, accountCookie, displayName);
 	sendChatHistory(connection, database->getRecentMessageHistory());
 
-	// send active accounts list to new account
-	// and update acount list to everybody else
+	// send active game rooms list to new account
+	// and update acount list to everybody else including new account
 	broadcastActiveAccounts();
 	sendMessage(connection, prepareActiveGameRooms());
 }
@@ -378,7 +383,7 @@ void LobbyServer::receiveServerLogin(const NetworkConnectionPtr & connection, co
 	std::string accountCookie = json["accountCookie"].String();
 	std::string version = json["version"].String();
 
-	auto clientCookieStatus = database->getAccountCookieStatus(accountID, accountCookie, accountCookieLifetime);
+	auto clientCookieStatus = database->getAccountCookieStatus(accountID, accountCookie);
 
 	if(clientCookieStatus == LobbyCookieStatus::INVALID)
 	{
@@ -399,7 +404,7 @@ void LobbyServer::receiveClientProxyLogin(const NetworkConnectionPtr & connectio
 	std::string accountID = json["accountID"].String();
 	std::string accountCookie = json["accountCookie"].String();
 
-	auto clientCookieStatus = database->getAccountCookieStatus(accountID, accountCookie, accountCookieLifetime);
+	auto clientCookieStatus = database->getAccountCookieStatus(accountID, accountCookie);
 
 	if(clientCookieStatus != LobbyCookieStatus::INVALID)
 	{
@@ -424,7 +429,7 @@ void LobbyServer::receiveClientProxyLogin(const NetworkConnectionPtr & connectio
 	}
 
 	sendOperationFailed(connection, "Invalid credentials");
-	networkServer->closeConnection(connection);
+	connection->close();
 }
 
 void LobbyServer::receiveServerProxyLogin(const NetworkConnectionPtr & connection, const JsonNode & json)
@@ -456,7 +461,7 @@ void LobbyServer::receiveServerProxyLogin(const NetworkConnectionPtr & connectio
 		return;
 	}
 
-	//networkServer->closeConnection(connection);
+	//connection->close();
 }
 
 void LobbyServer::receiveOpenGameRoom(const NetworkConnectionPtr & connection, const JsonNode & json)
@@ -480,9 +485,6 @@ void LobbyServer::receiveOpenGameRoom(const NetworkConnectionPtr & connection, c
 	if(roomType == "private")
 		database->setGameRoomStatus(gameRoomID, LobbyRoomState::PRIVATE);
 
-	// TODO: additional flags / initial settings, e.g. allowCheats
-	// TODO: connection mode: direct or proxy. For now direct is assumed. Proxy might be needed later, for hosted servers
-
 	database->insertPlayerIntoGameRoom(accountID, gameRoomID);
 	broadcastActiveGameRooms();
 	sendJoinRoomSuccess(connection, gameRoomID, false);

+ 1 - 1
lobby/LobbyServer.h

@@ -29,7 +29,7 @@ class LobbyServer final : public INetworkServerListener
 	};
 
 	/// list of connected proxies. All messages received from (key) will be redirected to (value) connection
-	std::map<NetworkConnectionPtr, NetworkConnectionWeakPtr> activeProxies;
+	std::map<NetworkConnectionPtr, NetworkConnectionPtr> activeProxies;
 
 	/// list of half-established proxies from server that are still waiting for client to connect
 	std::vector<AwaitingProxyState> awaitingProxies;

+ 2 - 2
server/CVCMIServer.cpp

@@ -160,7 +160,7 @@ void CVCMIServer::onNewConnection(const std::shared_ptr<INetworkConnection> & co
 	}
 	else
 	{
-		networkServer->closeConnection(connection);
+		connection->close();
 	}
 }
 
@@ -445,7 +445,7 @@ void CVCMIServer::clientConnected(std::shared_ptr<CConnection> c, std::vector<st
 
 void CVCMIServer::clientDisconnected(std::shared_ptr<CConnection> c)
 {
-	networkServer->closeConnection(c->getConnection());
+	c->getConnection()->close();
 	vstd::erase(activeConnections, c);
 
 	if(activeConnections.empty() || hostClientId == c->connectionID)

+ 10 - 2
server/GlobalLobbyProcessor.cpp

@@ -29,7 +29,15 @@ void GlobalLobbyProcessor::establishNewConnection()
 
 void GlobalLobbyProcessor::onDisconnected(const std::shared_ptr<INetworkConnection> & connection, const std::string & errorMessage)
 {
-	throw std::runtime_error("Lost connection to a lobby server!");
+	if (connection == controlConnection)
+	{
+		throw std::runtime_error("Lost connection to a lobby server!");
+	}
+	else
+	{
+		// player disconnected
+		owner.onDisconnected(connection, errorMessage);
+	}
 }
 
 void GlobalLobbyProcessor::onPacketReceived(const std::shared_ptr<INetworkConnection> & connection, const std::vector<std::byte> & message)
@@ -47,7 +55,7 @@ void GlobalLobbyProcessor::onPacketReceived(const std::shared_ptr<INetworkConnec
 		if(json["type"].String() == "accountJoinsRoom")
 			return receiveAccountJoinsRoom(json);
 
-		throw std::runtime_error("Received unexpected message from lobby server: " + json["type"].String());
+		logGlobal->error("Received unexpected message from lobby server of type '%s' ", json["type"].String());
 	}
 	else
 	{