瀏覽代碼

fix nullptr thread_local gw and cc: introduce SET_GLOBAL_STATE_TBB for tbb:parallel_for

Mircea TheHonestCTO 3 月之前
父節點
當前提交
1b0b15063f

+ 1 - 29
AI/Nullkiller2/AIGateway.cpp

@@ -39,35 +39,7 @@
 namespace NK2AI
 {
 
-//one thread may be turn of AI and another will be handling a side effect for AI2
-thread_local CCallback * ccTl = nullptr;
-thread_local AIGateway * aiGwTl = nullptr;
-
-//helper RAII to manage global ai/cb ptrs
-struct SetGlobalState
-{
-	SetGlobalState(AIGateway * aiGw)
-	{
-		assert(!aiGwTl);
-		assert(!ccTl);
-
-		aiGwTl = aiGw;
-		ccTl = aiGw->cc.get();
-	}
-	~SetGlobalState()
-	{
-		//TODO: how to handle rm? shouldn't be called after ai is destroyed, hopefully
-		//TODO: to ensure that, make rm unique_ptr
-		aiGwTl = nullptr;
-		ccTl = nullptr;
-	}
-};
-
-
-#define SET_GLOBAL_STATE(aiGw) SetGlobalState _hlpSetState(aiGw)
-
 #define NET_EVENT_HANDLER SET_GLOBAL_STATE(this)
-#define MAKING_TURN SET_GLOBAL_STATE(this)
 
 AIGateway::AIGateway()
 	:status(this)
@@ -845,7 +817,7 @@ bool AIGateway::makePossibleUpgrades(const CArmedInstance * obj)
 
 void AIGateway::makeTurn()
 {
-	MAKING_TURN;
+	SET_GLOBAL_STATE(this);
 
 	auto day = cc->getDate(Date::DAY);
 	logAi->info("Player %d (%s) starting turn, day %d", playerID, playerID.toString(), day);

+ 42 - 0
AI/Nullkiller2/AIGateway.h

@@ -28,6 +28,48 @@ VCMI_LIB_NAMESPACE_END
 namespace NK2AI
 {
 
+// one thread may be turn of AI and another will be handling a side effect for AI2
+inline thread_local CCallback * ccTl = nullptr;
+inline thread_local AIGateway * aiGwTl = nullptr;
+
+// helper RAII to manage global ai/cb ptrs
+struct SetGlobalState
+{
+	AIGateway * previousAiGw;
+	CCallback * previousCc;
+	bool wasAlreadySet;
+
+	SetGlobalState(AIGateway * aiGw, CCallback * cc)
+		: previousAiGw(aiGwTl), previousCc(ccTl), wasAlreadySet(aiGwTl != nullptr)
+	{
+		aiGwTl = aiGw;
+		ccTl = cc;
+#if NK2AI_TRACE_LEVEL >= 2
+		if(wasAlreadySet)
+		{
+			logAi->trace("SetGlobalState constructed (was already set)");
+		}
+		else
+		{
+			logAi->trace("SetGlobalState constructed");
+		}
+#endif
+	}
+
+	~SetGlobalState()
+	{
+		// Restore previous values instead of always setting to nullptr
+		aiGwTl = previousAiGw;
+		ccTl = previousCc;
+#if NK2AI_TRACE_LEVEL >= 2
+		logAi->trace("SetGlobalState destroyed");
+#endif
+	}
+};
+
+#define SET_GLOBAL_STATE(aiGw) SetGlobalState _hlpSetState(aiGw, aiGw->cc.get())
+#define SET_GLOBAL_STATE_TBB(aiGw) SET_GLOBAL_STATE(aiGw)
+
 class AIStatus
 {
 	AIGateway * aiGw;

+ 13 - 14
AI/Nullkiller2/Analyzers/ObjectClusterizer.cpp

@@ -332,22 +332,21 @@ void ObjectClusterizer::clusterize()
 			aiNk->memory->visitableObjs.end());
 	}
 
-#if NK2AI_TRACE_LEVEL == 0
-	tbb::parallel_for(tbb::blocked_range<size_t>(0, objs.size()), [&](const tbb::blocked_range<size_t> & r) {
-#else
-	tbb::blocked_range<size_t> r(0, objs.size());
-#endif
-		auto priorityEvaluator = aiNk->priorityEvaluators->acquire();
-		auto heroes = aiNk->cc->getHeroesInfo();
-		std::vector<AIPath> pathCache;
-
-		for(int i = r.begin(); i != r.end(); i++)
+	tbb::parallel_for(
+		tbb::blocked_range<size_t>(0, objs.size()),
+		[&](const tbb::blocked_range<size_t> & r)
 		{
-			clusterizeObject(objs[i], priorityEvaluator.get(), pathCache, heroes);
+			SET_GLOBAL_STATE_TBB(aiNk->aiGw);
+			auto priorityEvaluator = aiNk->priorityEvaluators->acquire();
+			auto heroes = aiNk->cc->getHeroesInfo();
+			std::vector<AIPath> pathCache;
+
+			for(int i = r.begin(); i != r.end(); i++)
+			{
+				clusterizeObject(objs[i], priorityEvaluator.get(), pathCache, heroes);
+			}
 		}
-#if NK2AI_TRACE_LEVEL == 0
-	});
-#endif
+	);
 
 	logAi->trace("Near objects count: %i", nearObjects.objects.size());
 	logAi->trace("Far objects count: %i", farObjects.objects.size());

+ 3 - 2
AI/Nullkiller2/Behaviors/CaptureObjectsBehavior.cpp

@@ -182,11 +182,11 @@ void CaptureObjectsBehavior::decomposeObjects(
 
 	logAi->debug("Scanning objects, count %d", objs.size());
 
-	// tbb::blocked_range<size_t> r(0, objs.size());
 	tbb::parallel_for(
 		tbb::blocked_range<size_t>(0, objs.size()),
 		[this, &objs, &sync, &result, nullkiller](const tbb::blocked_range<size_t> & r)
 		{
+			SET_GLOBAL_STATE_TBB(nullkiller->aiGw);
 			std::vector<AIPath> paths;
 			Goals::TGoalVec tasksLocal;
 
@@ -211,7 +211,8 @@ void CaptureObjectsBehavior::decomposeObjects(
 
 			std::lock_guard lock(sync); // FIXME: consider using tbb::parallel_reduce instead to avoid mutex overhead
 			vstd::concatenate(result, tasksLocal);
-		});
+		}
+	);
 }
 
 Goals::TGoalVec CaptureObjectsBehavior::decompose(const Nullkiller * aiNk) const

+ 13 - 9
AI/Nullkiller2/Engine/Nullkiller.cpp

@@ -32,6 +32,9 @@
 namespace NK2AI
 {
 
+// extern thread_local CCallback * ccTl;
+// extern thread_local AIGateway * aiGwTl;
+
 using namespace Goals;
 
 // while we play vcmieagles graph can be shared
@@ -189,16 +192,17 @@ Goals::TTaskVec Nullkiller::buildPlan(TGoalVec & tasks, int priorityTier) const
 	TaskPlan taskPlan;
 
 	tbb::parallel_for(tbb::blocked_range<size_t>(0, tasks.size()), [this, &tasks, priorityTier](const tbb::blocked_range<size_t> & r)
-		{
-			auto evaluator = this->priorityEvaluators->acquire();
+	{
+		SET_GLOBAL_STATE_TBB(this->aiGw);
+		auto evaluator = this->priorityEvaluators->acquire();
 
-			for(size_t i = r.begin(); i != r.end(); i++)
-			{
-				const auto & task = tasks[i];
-				if (task->asTask()->priority <= 0 || priorityTier != PriorityEvaluator::PriorityTier::BUILDINGS)
-					task->asTask()->priority = evaluator->evaluate(task, priorityTier);
-			}
-		});
+		for(size_t i = r.begin(); i != r.end(); i++)
+		{
+			const auto & task = tasks[i];
+			if(task->asTask()->priority <= 0 || priorityTier != PriorityEvaluator::PriorityTier::BUILDINGS)
+				task->asTask()->priority = evaluator->evaluate(task, priorityTier);
+		}
+	});
 
 	boost::range::sort(tasks, [](const TSubgoal& g1, const TSubgoal& g2) -> bool
 		{

+ 1 - 1
AI/Nullkiller2/Engine/Nullkiller.h

@@ -84,7 +84,6 @@ private:
 	ScanDepth scanDepth;
 	TResources lockedResources;
 	bool useHeroChain;
-	AIGateway * aiGw;
 	bool openMap;
 	bool useObjectGraph;
 	bool pathfinderInvalidated;
@@ -107,6 +106,7 @@ public:
 	std::unique_ptr<Settings> settings;
 	/// Same value as AIGateway->playerID
 	PlayerColor playerID;
+	AIGateway * aiGw;
 	/// Same value as AIGateway->cc
 	std::shared_ptr<CCallback> cc;
 	std::mutex aiStateMutex;

+ 11 - 6
AI/Nullkiller2/Pathfinding/AINodeStorage.cpp

@@ -9,18 +9,20 @@
 */
 #include "StdInc.h"
 #include "AINodeStorage.h"
-#include "Actions/TownPortalAction.h"
-#include "Actions/WhirlpoolAction.h"
-#include "../Engine/Nullkiller.h"
+
+#include "../../../lib/CPlayerState.h"
+#include "../../../lib/IGameSettings.h"
 #include "../../../lib/callback/IGameInfoCallback.h"
 #include "../../../lib/mapping/CMap.h"
 #include "../../../lib/pathfinder/CPathfinder.h"
-#include "../../../lib/pathfinder/PathfinderUtil.h"
 #include "../../../lib/pathfinder/PathfinderOptions.h"
+#include "../../../lib/pathfinder/PathfinderUtil.h"
 #include "../../../lib/spells/ISpellMechanics.h"
 #include "../../../lib/spells/adventure/TownPortalEffect.h"
-#include "../../../lib/IGameSettings.h"
-#include "../../../lib/CPlayerState.h"
+#include "../Engine/Nullkiller.h"
+#include "../AIGateway.h"
+#include "Actions/TownPortalAction.h"
+#include "Actions/WhirlpoolAction.h"
 
 namespace NK2AI
 {
@@ -127,6 +129,7 @@ void AINodeStorage::initialize(const PathfinderOptions & options, const IGameInf
 
 	tbb::parallel_for(tbb::blocked_range<size_t>(0, sizes.x), [&](const tbb::blocked_range<size_t>& r)
 	{
+		SET_GLOBAL_STATE_TBB(aiNk->aiGw);
 		int3 pos;
 
 		for(pos.z = 0; pos.z < sizes.z; ++pos.z)
@@ -594,6 +597,7 @@ bool AINodeStorage::calculateHeroChain()
 
 	tbb::parallel_for(tbb::blocked_range<size_t>(0, data.size()), [&](const tbb::blocked_range<size_t>& r)
 	{
+		SET_GLOBAL_STATE_TBB(aiNk->aiGw);
 		HeroChainCalculationTask task(*this, data, chainMask, heroChainTurn);
 
 		int ourThread = tbb::this_task_arena::current_thread_index();
@@ -1252,6 +1256,7 @@ void AINodeStorage::calculateTownPortalTeleportations(std::vector<CGPathNode *>
 	{
 		tbb::parallel_for(tbb::blocked_range<size_t>(0, actorsVector.size()), [&](const tbb::blocked_range<size_t> & r)
 			{
+				SET_GLOBAL_STATE_TBB(aiNk->aiGw);
 				for(int i = r.begin(); i != r.end(); i++)
 				{
 					calculateTownPortal(actorsVector[i], maskMap, initialNodes, output);

+ 5 - 1
AI/Nullkiller2/Pathfinding/AIPathfinder.cpp

@@ -9,9 +9,11 @@
 */
 #include "StdInc.h"
 #include "AIPathfinder.h"
-#include "AIPathfinderConfig.h"
+
 #include "../../../lib/mapping/CMap.h"
 #include "../Engine/Nullkiller.h"
+#include "../AIGateway.h"
+#include "AIPathfinderConfig.h"
 
 namespace NK2AI
 {
@@ -164,6 +166,8 @@ void AIPathfinder::updateGraphs(
 
 	tbb::parallel_for(tbb::blocked_range<size_t>(0, heroesVector.size()), [this, &heroesVector, &heroes, mainScanDepth, scoutScanDepth](const tbb::blocked_range<size_t> & r)
 		{
+			SET_GLOBAL_STATE_TBB(aiNk->aiGw);
+
 			for(auto i = r.begin(); i != r.end(); i++)
 			{
 				auto role = heroes.at(heroesVector[i]);

+ 32 - 29
AI/Nullkiller2/pforeach.h

@@ -1,29 +1,31 @@
 #pragma once
 
+#include "AIGateway.h"
 #include "Engine/Nullkiller.h"
 
 namespace NK2AI
 {
 
-template<typename TFunc>
-void pforeachTilePos(const int3 & mapSize, TFunc fn)
-{
-	for(int z = 0; z < mapSize.z; ++z)
-	{
-		tbb::parallel_for(tbb::blocked_range<size_t>(0, mapSize.x), [&](const tbb::blocked_range<size_t> & r)
-			{
-				int3 pos(0, 0, z);
-
-				for(pos.x = r.begin(); pos.x != r.end(); ++pos.x)
-				{
-					for(pos.y = 0; pos.y < mapSize.y; ++pos.y)
-					{
-						fn(pos);
-					}
-				}
-			});
-	}
-}
+// template<typename TFunc>
+// void pforeachTilePos(const int3 & mapSize, TFunc fn)
+// {
+// 	for(int z = 0; z < mapSize.z; ++z)
+// 	{
+// 		tbb::parallel_for(tbb::blocked_range<size_t>(0, mapSize.x), [&](const tbb::blocked_range<size_t> & r)
+// 			{
+//				SET_GLOBAL_STATE_TBB(this->aiGw);
+// 				int3 pos(0, 0, z);
+//
+// 				for(pos.x = r.begin(); pos.x != r.end(); ++pos.x)
+// 				{
+// 					for(pos.y = 0; pos.y < mapSize.y; ++pos.y)
+// 					{
+// 						fn(pos);
+// 					}
+// 				}
+// 			});
+// 	}
+// }
 
 template<typename TFunc>
 void pforeachTilePaths(const int3 & mapSize, const Nullkiller * aiNk, TFunc fn)
@@ -31,19 +33,20 @@ void pforeachTilePaths(const int3 & mapSize, const Nullkiller * aiNk, TFunc fn)
 	for(int z = 0; z < mapSize.z; ++z)
 	{
 		tbb::parallel_for(tbb::blocked_range<size_t>(0, mapSize.x), [&](const tbb::blocked_range<size_t> & r)
-			{
-				int3 pos(0, 0, z);
-				std::vector<AIPath> paths;
+		{
+			SET_GLOBAL_STATE_TBB(aiNk->aiGw);
+			int3 pos(0, 0, z);
+			std::vector<AIPath> paths;
 
-				for(pos.x = r.begin(); pos.x != r.end(); ++pos.x)
+			for(pos.x = r.begin(); pos.x != r.end(); ++pos.x)
+			{
+				for(pos.y = 0; pos.y < mapSize.y; ++pos.y)
 				{
-					for(pos.y = 0; pos.y < mapSize.y; ++pos.y)
-					{
-						aiNk->pathfinder->calculatePathInfo(paths, pos);
-						fn(pos, paths);
-					}
+					aiNk->pathfinder->calculatePathInfo(paths, pos);
+					fn(pos, paths);
 				}
-			});
+			}
+		});
 	}
 }