Browse Source

ThreadPool implementation. It runs, but dies from race conditions.

Tomasz Zieliński 2 years ago
parent
commit
a8545935c3

+ 3 - 0
cmake_modules/VCMI_lib.cmake

@@ -425,6 +425,9 @@ macro(add_main_lib TARGET_NAME LIBRARY_TYPE)
 		${MAIN_LIB_DIR}/rmg/RiverPlacer.h
 		${MAIN_LIB_DIR}/rmg/RiverPlacer.h
 		${MAIN_LIB_DIR}/rmg/TerrainPainter.h
 		${MAIN_LIB_DIR}/rmg/TerrainPainter.h
 		${MAIN_LIB_DIR}/rmg/float3.h
 		${MAIN_LIB_DIR}/rmg/float3.h
+		${MAIN_LIB_DIR}/rmg/threadpool/BlockingQueue.h
+		${MAIN_LIB_DIR}/rmg/threadpool/ThreadPool.h
+		${MAIN_LIB_DIR}/rmg/threadpool/JobProvider.h
 
 
 		${MAIN_LIB_DIR}/serializer/BinaryDeserializer.h
 		${MAIN_LIB_DIR}/serializer/BinaryDeserializer.h
 		${MAIN_LIB_DIR}/serializer/BinarySerializer.h
 		${MAIN_LIB_DIR}/serializer/BinarySerializer.h

+ 61 - 7
lib/rmg/CMapGenerator.cpp

@@ -23,6 +23,7 @@
 #include "Zone.h"
 #include "Zone.h"
 #include "Functions.h"
 #include "Functions.h"
 #include "RmgMap.h"
 #include "RmgMap.h"
+#include "threadpool/ThreadPool.h"
 #include "ObjectManager.h"
 #include "ObjectManager.h"
 #include "TreasurePlacer.h"
 #include "TreasurePlacer.h"
 #include "RoadPlacer.h"
 #include "RoadPlacer.h"
@@ -294,9 +295,12 @@ void CMapGenerator::fillZones()
 
 
 	logGlobal->info("Started filling zones");
 	logGlobal->info("Started filling zones");
 
 
+	size_t numZones = map->getZones().size();
+
 	//we need info about all town types to evaluate dwellings and pandoras with creatures properly
 	//we need info about all town types to evaluate dwellings and pandoras with creatures properly
 	//place main town in the middle
 	//place main town in the middle
-	Load::Progress::setupStepsTill(map->getZones().size(), 50);
+	
+	Load::Progress::setupStepsTill(numZones, 50);
 	for (const auto& it : map->getZones())
 	for (const auto& it : map->getZones())
 	{
 	{
 		it.second->initFreeTiles();
 		it.second->initFreeTiles();
@@ -304,16 +308,40 @@ void CMapGenerator::fillZones()
 		Progress::Progress::step();
 		Progress::Progress::step();
 	}
 	}
 
 
-	Load::Progress::setupStepsTill(map->getZones().size(), 240);
+	//TODO: multiply by the number of modificators
+	Load::Progress::setupStepsTill(numZones, 240);
 	std::vector<std::shared_ptr<Zone>> treasureZones;
 	std::vector<std::shared_ptr<Zone>> treasureZones;
-	for (const auto& it : map->getZones())
+
+	ThreadPool pool;
+
+	std::vector<boost::future<void>> futures;
+	//At most one Modificator can run for every zone
+	pool.init(std::min<int>(std::thread::hardware_concurrency(), numZones));
+
+	while (hasJobs())
 	{
 	{
-		it.second->processModificators();
+		auto job = getNextJob();
+		if (job)
+		{
+			futures.push_back(pool.async([this, job]() -> void
+				{
+					job.value()();
+					Progress::Progress::step(); //Update progress bar
+				}
+			));
+		}
+	}
 
 
+	//Wait for all the tasks
+	for (auto& fut : futures)
+	{
+		fut.get();
+	}
+
+	for (const auto& it : map->getZones())
+	{
 		if (it.second->getType() == ETemplateZoneType::TREASURE)
 		if (it.second->getType() == ETemplateZoneType::TREASURE)
 			treasureZones.push_back(it.second);
 			treasureZones.push_back(it.second);
-
-		Progress::Progress::step();
 	}
 	}
 
 
 	//find place for Grail
 	//find place for Grail
@@ -381,7 +409,7 @@ const std::vector<ArtifactID> & CMapGenerator::getAllPossibleQuestArtifacts() co
 	return questArtifacts;
 	return questArtifacts;
 }
 }
 
 
-const std::vector<HeroTypeID>& CMapGenerator::getAllPossibleHeroes() const
+const std::vector<HeroTypeID> CMapGenerator::getAllPossibleHeroes() const
 {
 {
 	//Skip heroes that were banned, including the ones placed in prisons
 	//Skip heroes that were banned, including the ones placed in prisons
 	std::vector<HeroTypeID> ret;
 	std::vector<HeroTypeID> ret;
@@ -395,11 +423,13 @@ const std::vector<HeroTypeID>& CMapGenerator::getAllPossibleHeroes() const
 
 
 void CMapGenerator::banQuestArt(const ArtifactID & id)
 void CMapGenerator::banQuestArt(const ArtifactID & id)
 {
 {
+	//TODO: Protect with mutex
 	map->map().allowedArtifact[id] = false;
 	map->map().allowedArtifact[id] = false;
 }
 }
 
 
 void CMapGenerator::banHero(const HeroTypeID & id)
 void CMapGenerator::banHero(const HeroTypeID & id)
 {
 {
+	//TODO: Protect with mutex
 	map->map().allowedHeroes[id] = false;
 	map->map().allowedHeroes[id] = false;
 }
 }
 
 
@@ -411,4 +441,28 @@ Zone * CMapGenerator::getZoneWater() const
 	return nullptr;
 	return nullptr;
 }
 }
 
 
+bool CMapGenerator::hasJobs()
+{
+	for (auto zone : map->getZones())
+	{
+		if (zone.second->hasJobs())
+		{
+			return true;
+		}
+	}
+	return false;
+}
+
+TRMGJob CMapGenerator::getNextJob()
+{
+	for (auto zone : map->getZones())
+	{
+		if (zone.second->hasJobs())
+		{
+			return zone.second->getNextJob();
+		}
+	}
+	return TRMGJob();
+}
+
 VCMI_LIB_NAMESPACE_END
 VCMI_LIB_NAMESPACE_END

+ 7 - 5
lib/rmg/CMapGenerator.h

@@ -15,6 +15,7 @@
 #include "CMapGenOptions.h"
 #include "CMapGenOptions.h"
 #include "../int3.h"
 #include "../int3.h"
 #include "CRmgTemplate.h"
 #include "CRmgTemplate.h"
+#include "threadpool/JobProvider.h"
 #include "../LoadProgress.h"
 #include "../LoadProgress.h"
 
 
 VCMI_LIB_NAMESPACE_BEGIN
 VCMI_LIB_NAMESPACE_BEGIN
@@ -30,7 +31,7 @@ class CZonePlacer;
 using JsonVector = std::vector<JsonNode>;
 using JsonVector = std::vector<JsonNode>;
 
 
 /// The map generator creates a map randomly.
 /// The map generator creates a map randomly.
-class DLL_LINKAGE CMapGenerator: public Load::Progress
+class DLL_LINKAGE CMapGenerator: public Load::Progress, public IJobProvider
 {
 {
 public:
 public:
 	struct Config
 	struct Config
@@ -64,7 +65,7 @@ public:
 	int getPrisonsRemaning() const;
 	int getPrisonsRemaning() const;
 	std::shared_ptr<CZonePlacer> getZonePlacer() const;
 	std::shared_ptr<CZonePlacer> getZonePlacer() const;
 	const std::vector<ArtifactID> & getAllPossibleQuestArtifacts() const;
 	const std::vector<ArtifactID> & getAllPossibleQuestArtifacts() const;
-	const std::vector<HeroTypeID>& getAllPossibleHeroes() const;
+	const std::vector<HeroTypeID> getAllPossibleHeroes() const;
 	void banQuestArt(const ArtifactID & id);
 	void banQuestArt(const ArtifactID & id);
 	void banHero(const HeroTypeID& id);
 	void banHero(const HeroTypeID& id);
 
 
@@ -82,11 +83,9 @@ private:
 	
 	
 	std::vector<rmg::ZoneConnection> connectionsLeft;
 	std::vector<rmg::ZoneConnection> connectionsLeft;
 	
 	
-	//std::pair<Zones::key_type, Zones::mapped_type> zoneWater;
-
 	int allowedPrisons;
 	int allowedPrisons;
 	int monolithIndex;
 	int monolithIndex;
-	std::vector<ArtifactID> questArtifacts; //TODO: Protect with mutex
+	std::vector<ArtifactID> questArtifacts;
 
 
 	/// Generation methods
 	/// Generation methods
 	void loadConfig();
 	void loadConfig();
@@ -100,6 +99,9 @@ private:
 	void genZones();
 	void genZones();
 	void fillZones();
 	void fillZones();
 
 
+	TRMGJob getNextJob() override;
+	bool hasJobs() override;
+
 };
 };
 
 
 VCMI_LIB_NAMESPACE_END
 VCMI_LIB_NAMESPACE_END

+ 24 - 0
lib/rmg/Zone.cpp

@@ -179,6 +179,30 @@ rmg::Path Zone::searchPath(const int3 & src, bool onlyStraight, const std::funct
 	return searchPath(rmg::Area({src}), onlyStraight, areafilter);
 	return searchPath(rmg::Area({src}), onlyStraight, areafilter);
 }
 }
 
 
+TRMGJob Zone::getNextJob()
+{
+	for (auto& modificator : modificators)
+	{
+		if (modificator->hasJobs())
+		{
+			return modificator->getNextJob();
+		}
+	}
+	return TRMGJob();
+}
+
+bool Zone::hasJobs()
+{
+	for (auto& modificator : modificators)
+	{
+		if (modificator->hasJobs())
+		{
+			return true;
+		}
+	}
+	return false;
+}
+
 void Zone::connectPath(const rmg::Path & path)
 void Zone::connectPath(const rmg::Path & path)
 ///connect current tile to any other free tile within zone
 ///connect current tile to any other free tile within zone
 {
 {

+ 7 - 1
lib/rmg/Zone.h

@@ -13,6 +13,7 @@
 #include "../GameConstants.h"
 #include "../GameConstants.h"
 #include "float3.h"
 #include "float3.h"
 #include "../int3.h"
 #include "../int3.h"
+#include "threadpool/JobProvider.h"
 #include "CRmgTemplate.h"
 #include "CRmgTemplate.h"
 #include "RmgArea.h"
 #include "RmgArea.h"
 #include "RmgPath.h"
 #include "RmgPath.h"
@@ -30,7 +31,7 @@ class Modificator;
 
 
 extern std::function<bool(const int3 &)> AREA_NO_FILTER;
 extern std::function<bool(const int3 &)> AREA_NO_FILTER;
 
 
-class Zone : public rmg::ZoneOptions
+class Zone : public rmg::ZoneOptions, public IJobProvider
 {
 {
 public:
 public:
 	Zone(RmgMap & map, CMapGenerator & generator);
 	Zone(RmgMap & map, CMapGenerator & generator);
@@ -63,9 +64,14 @@ public:
 	rmg::Path searchPath(const rmg::Area & src, bool onlyStraight, const std::function<bool(const int3 &)> & areafilter = AREA_NO_FILTER) const;
 	rmg::Path searchPath(const rmg::Area & src, bool onlyStraight, const std::function<bool(const int3 &)> & areafilter = AREA_NO_FILTER) const;
 	rmg::Path searchPath(const int3 & src, bool onlyStraight, const std::function<bool(const int3 &)> & areafilter = AREA_NO_FILTER) const;
 	rmg::Path searchPath(const int3 & src, bool onlyStraight, const std::function<bool(const int3 &)> & areafilter = AREA_NO_FILTER) const;
 
 
+	TRMGJob getNextJob() override;
+	bool hasJobs() override;
+
 	template<class T>
 	template<class T>
 	T* getModificator()
 	T* getModificator()
 	{
 	{
+		//TODO: Protect with recursive mutex?
+
 		for(auto & m : modificators)
 		for(auto & m : modificators)
 			if(auto * mm = dynamic_cast<T*>(m.get()))
 			if(auto * mm = dynamic_cast<T*>(m.get()))
 				return mm;
 				return mm;

+ 91 - 0
lib/rmg/threadpool/BlockingQueue.h

@@ -0,0 +1,91 @@
+/*
+ * BlockingQueue.h, part of VCMI engine
+ *
+ * Authors: listed in file AUTHORS in main folder
+ *
+ * License: GNU General Public License v2.0 or later
+ * Full text of license available in license.txt file, in main folder
+ *
+ */
+
+#pragma once
+
+#include "StdInc.h"
+
+VCMI_LIB_NAMESPACE_BEGIN
+
+//Credit to https://github.com/Liam0205/toy-threadpool/tree/master/yuuki
+
+template <typename T>
+class DLL_LINKAGE BlockingQueue : protected std::queue<T>
+{
+	using WriteLock = boost::unique_lock<boost::shared_mutex>;
+	using Readlock = boost::shared_lock<boost::shared_mutex>;
+
+public:
+	BlockingQueue() = default;
+	~BlockingQueue()
+	{
+		clear();
+  	}
+	BlockingQueue(const BlockingQueue&) = delete;
+	BlockingQueue(BlockingQueue&&) = delete;
+	BlockingQueue& operator=(const BlockingQueue&) = delete;
+	BlockingQueue& operator=(BlockingQueue&&) = delete;
+
+public:
+	bool empty() const
+	{
+		Readlock lock(mx);
+		return std::queue<T>::empty();
+	}
+
+	size_t size() const
+	{
+		Readlock lock(mx);
+		return std::queue<T>::size();
+	}
+
+public:
+	void clear()
+	{
+		WriteLock lock(mx);
+		while (!std::queue<T>::empty())
+		{
+			std::queue<T>::pop();
+		}
+	}
+
+	void push(const T& obj)
+	{
+		WriteLock lock(mx);
+		std::queue<T>::push(obj);
+	}
+
+	template <typename... Args>
+	void emplace(Args&&... args)
+	{
+		WriteLock lock(mx);
+		std::queue<T>::emplace(std::forward<Args>(args)...);
+	}
+
+	bool pop(T& holder)
+	{
+		WriteLock lock(mx);
+		if (std::queue<T>::empty())
+		{
+			return false;
+		}
+		else
+		{
+			holder = std::move(std::queue<T>::front());
+			std::queue<T>::pop();
+			return true;
+		}
+	}
+
+private:
+	mutable boost::shared_mutex mx;
+};
+
+VCMI_LIB_NAMESPACE_END

+ 30 - 0
lib/rmg/threadpool/JobProvider.h

@@ -0,0 +1,30 @@
+/*
+ * JobProvider.h, part of VCMI engine
+ *
+ * Authors: listed in file AUTHORS in main folder
+ *
+ * License: GNU General Public License v2.0 or later
+ * Full text of license available in license.txt file, in main folder
+ *
+ */
+
+#pragma once
+
+#include "StdInc.h"
+#include "../../GameConstants.h"
+
+VCMI_LIB_NAMESPACE_BEGIN
+
+typedef std::function<void()> TRMGfunction ;
+typedef std::optional<TRMGfunction> TRMGJob;
+
+class DLL_LINKAGE IJobProvider
+{
+public:
+	//TODO: Think about some mutex protection
+
+	virtual TRMGJob getNextJob() = 0;
+	virtual bool hasJobs() = 0;
+};
+
+VCMI_LIB_NAMESPACE_END

+ 0 - 0
lib/rmg/threadpool/JobProvoider.cpp


+ 190 - 0
lib/rmg/threadpool/ThreadPool.h

@@ -0,0 +1,190 @@
+/*
+ * ThreadPool.h, part of VCMI engine
+ *
+ * Authors: listed in file AUTHORS in main folder
+ *
+ * License: GNU General Public License v2.0 or later
+ * Full text of license available in license.txt file, in main folder
+ *
+ */
+
+#pragma once
+
+#include "BlockingQueue.h"
+#include "JobProvider.h"
+#include <boost/thread/future.hpp>
+#include <boost/thread/condition_variable.hpp>
+
+VCMI_LIB_NAMESPACE_BEGIN
+
+//Credit to https://github.com/Liam0205/toy-threadpool/tree/master/yuuki
+
+class DLL_LINKAGE ThreadPool
+{
+private:
+	using Lock = boost::unique_lock<boost::shared_mutex>;
+	mutable boost::shared_mutex mx;
+	mutable boost::condition_variable_any cv;
+	mutable boost::once_flag once;
+
+	bool isInitialized = false;
+	bool stopping = false;
+	bool canceling = false;
+public:
+	ThreadPool();
+	~ThreadPool();
+
+	void init(size_t numThreads);
+	void spawn();
+	void terminate();
+	void cancel();
+
+public:
+	bool initialized() const;
+	bool running() const;
+	int size() const;
+private:
+	bool isRunning() const;
+
+public:
+	auto async(std::function<void()>&& f) const -> boost::future<void>;
+
+private:
+	std::vector<boost::thread> workers;
+	mutable BlockingQueue<TRMGfunction> tasks;
+};
+
+ThreadPool::ThreadPool() :
+	once(BOOST_ONCE_INIT)
+{};
+
+ThreadPool::~ThreadPool()
+{
+	terminate();
+}
+
+inline void ThreadPool::init(size_t numThreads)
+{
+	boost::call_once(once, [this, numThreads]()
+	{
+		Lock lock(mx);
+		stopping = false;
+		canceling = false;
+		workers.reserve(numThreads);
+		for (size_t i = 0; i < numThreads; ++i)
+		{
+			workers.emplace_back(std::bind(&ThreadPool::spawn, this));
+		}
+		isInitialized = true;
+	});
+}
+
+bool ThreadPool::isRunning() const
+{
+	return isInitialized && !stopping && !canceling;
+}
+
+inline bool ThreadPool::initialized() const
+{
+	Lock lock(mx);
+	return isInitialized;
+}
+
+inline bool ThreadPool::running() const
+{
+	Lock lock(mx);
+	return isRunning();
+}
+
+inline int ThreadPool::size() const
+{
+	Lock lock(mx);
+	return workers.size();
+}
+
+inline void ThreadPool::spawn()
+{
+	while(true)
+	{
+		bool pop = false;
+		TRMGfunction task;
+		{
+			Lock lock(mx);
+			cv.wait(lock, [this, &pop, &task]
+			{
+				pop = tasks.pop(task);
+				return canceling || stopping || pop;
+			});
+		}
+		if (canceling || (stopping && !pop))
+		{
+			return;
+		}
+		task();
+	}
+}
+
+inline void ThreadPool::terminate()
+{
+	{
+		Lock lock(mx);
+		if (running())
+		{
+			stopping = true;
+		}
+		else
+		{
+			return;
+		}
+	}
+	cv.notify_all();
+	for (auto& worker : workers)
+	{
+		worker.join();
+	}
+}
+
+inline void ThreadPool::cancel()
+{
+	{
+		Lock lock(mx);
+		if (running())
+		{
+			canceling = true;
+		}
+		else
+		{
+			return;
+		}
+	}
+	tasks.clear();
+	cv.notify_all();
+	for (auto& worker : workers)
+	{
+		worker.join();
+	}
+}
+
+auto ThreadPool::async(std::function<void()>&& f) const -> boost::future<void>
+{
+    using TaskT = boost::packaged_task<void>;
+
+    {
+        Lock lock(mx);
+        if (stopping || canceling)
+        {
+            throw std::runtime_error("Delegating task to a threadpool that has been terminated or canceled.");
+        }
+    }
+
+    std::shared_ptr<TaskT> task = std::make_shared<TaskT>(f);
+    boost::future<void> fut = task->get_future();
+    tasks.emplace([task]() -> void
+    {
+        (*task)();
+    });
+    cv.notify_one();
+    return fut;
+}
+
+VCMI_LIB_NAMESPACE_END