瀏覽代碼

Always track already serialized pointers to avoid infinite recursion on
sending complex objects

Ivan Savenko 1 年之前
父節點
當前提交
31738e8f90
共有 4 個文件被更改,包括 58 次插入55 次删除
  1. 0 3
      lib/networkPacks/PacksForLobby.h
  2. 33 28
      lib/serializer/BinaryDeserializer.h
  3. 22 22
      lib/serializer/BinarySerializer.h
  4. 3 2
      lib/serializer/Connection.cpp

+ 0 - 3
lib/networkPacks/PacksForLobby.h

@@ -157,10 +157,7 @@ struct DLL_LINKAGE LobbyStartGame : public CLobbyPackToPropagate
 	{
 		h & clientId;
 		h & initializedStartInfo;
-		bool sps = h.smartPointerSerialization;
-		h.smartPointerSerialization = true;
 		h & initializedGameState;
-		h.smartPointerSerialization = sps;
 	}
 };
 

+ 33 - 28
lib/serializer/BinaryDeserializer.h

@@ -37,39 +37,44 @@ public:
 /// Effectively revesed version of BinarySerializer
 class BinaryDeserializer : public CLoaderBase
 {
-	template<typename Ser,typename T>
-	struct LoadIfStackInstance
+	template<typename Fake, typename T>
+	static bool loadIfStackInstance(T &data)
 	{
-		static bool invoke(Ser &s, T &data)
-		{
+		return false;
+	}
+
+	template<typename Fake>
+	bool loadIfStackInstance(const CStackInstance* &data)
+	{
+		CArmedInstance * armyPtr = nullptr;
+		ObjectInstanceID armyID;
+		SlotID slot;
+		load(armyID);
+		load(slot);
+
+		if (armyID == ObjectInstanceID::NONE)
 			return false;
+
+		if(reader->smartVectorMembersSerialization)
+		{
+			if(const auto *info = reader->getVectorizedTypeInfo<CArmedInstance, ObjectInstanceID>())
+				armyPtr = static_cast<CArmedInstance *>(reader->getVectorItemFromId<CArmedInstance, ObjectInstanceID>(*info, armyID));
 		}
-	};
 
-	template<typename Ser>
-	struct LoadIfStackInstance<Ser, CStackInstance *>
-	{
-		static bool invoke(Ser &s, CStackInstance* &data)
+		if(slot != SlotID::COMMANDER_SLOT_PLACEHOLDER)
 		{
-			CArmedInstance *armedObj;
-			SlotID slot;
-			s.load(armedObj);
-			s.load(slot);
-			if(slot != SlotID::COMMANDER_SLOT_PLACEHOLDER)
-			{
-				assert(armedObj->hasStackAtSlot(slot));
-				data = armedObj->stacks[slot];
-			}
-			else
-			{
-				auto * hero = dynamic_cast<CGHeroInstance *>(armedObj);
-				assert(hero);
-				assert(hero->commander);
-				data = hero->commander;
-			}
-			return true;
+			assert(armyPtr->hasStackAtSlot(slot));
+			data = armyPtr->stacks[slot];
 		}
-	};
+		else
+		{
+			auto * hero = dynamic_cast<CGHeroInstance *>(armyPtr);
+			assert(hero);
+			assert(hero->commander);
+			data = hero->commander;
+		}
+		return true;
+	}
 
 	template <typename T, typename Enable = void>
 	struct ClassObjectCreator
@@ -331,7 +336,7 @@ public:
 
 		if(reader->sendStackInstanceByIds)
 		{
-			bool gotLoaded = LoadIfStackInstance<BinaryDeserializer,T>::invoke(* this, data);
+			bool gotLoaded = loadIfStackInstance<void>(data);
 			if(gotLoaded)
 				return;
 		}

+ 22 - 22
lib/serializer/BinarySerializer.h

@@ -52,33 +52,33 @@ class BinarySerializer : public CSaverBase
 		}
 	};
 
-	template<typename Ser,typename T>
-	struct SaveIfStackInstance
+	template<typename Fake, typename T>
+	bool saveIfStackInstance(const T &data)
 	{
-		static bool invoke(Ser &s, const T &data)
-		{
-			return false;
-		}
-	};
+		return false;
+	}
 
-	template<typename Ser>
-	struct SaveIfStackInstance<Ser, CStackInstance *>
+	template<typename Fake>
+	bool saveIfStackInstance(const CStackInstance* const &data)
 	{
-		static bool invoke(Ser &s, const CStackInstance* const &data)
-		{
-			assert(data->armyObj);
-			SlotID slot;
+		assert(data->armyObj);
 
-			if(data->getNodeType() == CBonusSystemNode::COMMANDER)
-				slot = SlotID::COMMANDER_SLOT_PLACEHOLDER;
-			else
-				slot = data->armyObj->findStack(data);
+		SlotID slot;
 
-			assert(slot != SlotID());
-			s & data->armyObj & slot;
+		if(data->getNodeType() == CBonusSystemNode::COMMANDER)
+			slot = SlotID::COMMANDER_SLOT_PLACEHOLDER;
+		else
+			slot = data->armyObj->findStack(data);
+
+		assert(slot != SlotID());
+		save(data->armyObj->id);
+		save(slot);
+
+		if (data->armyObj->id != ObjectInstanceID::NONE)
 			return true;
-		}
-	};
+		else
+			return false;
+	}
 
 	template <typename T> class CPointerSaver;
 
@@ -252,7 +252,7 @@ public:
 
 		if(writer->sendStackInstanceByIds)
 		{
-			const bool gotSaved = SaveIfStackInstance<BinarySerializer,T>::invoke(*this, data);
+			const bool gotSaved = saveIfStackInstance<void>(data);
 			if(gotSaved)
 				return;
 		}

+ 3 - 2
lib/serializer/Connection.cpp

@@ -84,6 +84,7 @@ void CConnection::sendPack(const CPack * pack)
 
 	connectionPtr->sendPacket(packWriter->buffer);
 	packWriter->buffer.clear();
+	serializer->savedPointers.clear();
 }
 
 CPack * CConnection::retrievePack(const std::vector<std::byte> & data)
@@ -102,6 +103,8 @@ CPack * CConnection::retrievePack(const std::vector<std::byte> & data)
 		throw std::runtime_error("Failed to retrieve pack! Not all data has been read!");
 
 	logNetwork->trace("Received CPack of type %s", typeid(*result).name());
+	deserializer->loadedPointers.clear();
+	deserializer->loadedSharedPointers.clear();
 	return result;
 }
 
@@ -132,7 +135,6 @@ void CConnection::enterLobbyConnectionMode()
 	deserializer->loadedPointers.clear();
 	serializer->savedPointers.clear();
 	disableSmartVectorMemberSerialization();
-	disableSmartPointerSerialization();
 	disableStackSendingByID();
 }
 
@@ -144,7 +146,6 @@ void CConnection::setCallback(IGameCallback * cb)
 void CConnection::enterGameplayConnectionMode(CGameState * gs)
 {
 	enableStackSendingByID();
-	disableSmartPointerSerialization();
 
 	setCallback(gs->callback);
 	enableSmartVectorMemberSerializatoin(gs);