Browse Source

NKAI: parallel capture objects

Andrii Danylchenko 1 năm trước cách đây
mục cha
commit
d6f1a5c2b3

+ 4 - 4
AI/Nullkiller/AIUtility.cpp

@@ -171,7 +171,7 @@ bool CDistanceSorter::operator()(const CGObjectInstance * lhs, const CGObjectIns
 	return ln->getCost() < rn->getCost();
 }
 
-bool isSafeToVisit(HeroPtr h, const CCreatureSet * heroArmy, uint64_t dangerStrength)
+bool isSafeToVisit(const CGHeroInstance * h, const CCreatureSet * heroArmy, uint64_t dangerStrength)
 {
 	const ui64 heroStrength = h->getFightingStrength() * heroArmy->getArmyStrength();
 
@@ -183,9 +183,9 @@ bool isSafeToVisit(HeroPtr h, const CCreatureSet * heroArmy, uint64_t dangerStre
 	return true; //there's no danger
 }
 
-bool isSafeToVisit(HeroPtr h, uint64_t dangerStrength)
+bool isSafeToVisit(const CGHeroInstance * h, uint64_t dangerStrength)
 {
-	return isSafeToVisit(h, h.get(), dangerStrength);
+	return isSafeToVisit(h, h, dangerStrength);
 }
 
 bool isObjectRemovable(const CGObjectInstance * obj)
@@ -285,7 +285,7 @@ creInfo infoFromDC(const dwellingContent & dc)
 	return ci;
 }
 
-bool compareHeroStrength(HeroPtr h1, HeroPtr h2)
+bool compareHeroStrength(const CGHeroInstance * h1, const CGHeroInstance * h2)
 {
 	return h1->getTotalStrength() < h2->getTotalStrength();
 }

+ 3 - 3
AI/Nullkiller/AIUtility.h

@@ -232,10 +232,10 @@ bool isBlockVisitObj(const int3 & pos);
 bool isWeeklyRevisitable(const CGObjectInstance * obj);
 
 bool isObjectRemovable(const CGObjectInstance * obj); //FIXME FIXME: move logic to object property!
-bool isSafeToVisit(HeroPtr h, uint64_t dangerStrength);
-bool isSafeToVisit(HeroPtr h, const CCreatureSet *, uint64_t dangerStrength);
+bool isSafeToVisit(const CGHeroInstance * h, uint64_t dangerStrength);
+bool isSafeToVisit(const CGHeroInstance * h, const CCreatureSet *, uint64_t dangerStrength);
 
-bool compareHeroStrength(HeroPtr h1, HeroPtr h2);
+bool compareHeroStrength(const CGHeroInstance * h1, const CGHeroInstance * h2);
 bool compareArmyStrength(const CArmedInstance * a1, const CArmedInstance * a2);
 bool compareArtifacts(const CArtifactInstance * a1, const CArtifactInstance * a2);
 bool townHasFreeTavern(const CGTownInstance * town);

+ 49 - 34
AI/Nullkiller/Behaviors/CaptureObjectsBehavior.cpp

@@ -47,7 +47,7 @@ bool CaptureObjectsBehavior::operator==(const CaptureObjectsBehavior & other) co
 		&& vectorEquals(objectSubTypes, other.objectSubTypes);
 }
 
-Goals::TGoalVec CaptureObjectsBehavior::getVisitGoals(const std::vector<AIPath> & paths, const CGObjectInstance * objToVisit)
+Goals::TGoalVec CaptureObjectsBehavior::getVisitGoals(const std::vector<AIPath> & paths, Nullkiller * nullkiller, const CGObjectInstance * objToVisit)
 {
 	Goals::TGoalVec tasks;
 
@@ -64,7 +64,7 @@ Goals::TGoalVec CaptureObjectsBehavior::getVisitGoals(const std::vector<AIPath>
 		logAi->trace("Path found %s", path.toString());
 #endif
 
-		if(ai->nullkiller->dangerHitMap->enemyCanKillOurHeroesAlongThePath(path))
+		if(nullkiller->dangerHitMap->enemyCanKillOurHeroesAlongThePath(path))
 		{
 #if NKAI_TRACE_LEVEL >= 2
 			logAi->trace("Ignore path. Target hero can be killed by enemy. Our power %lld", path.getHeroStrength());
@@ -72,7 +72,7 @@ Goals::TGoalVec CaptureObjectsBehavior::getVisitGoals(const std::vector<AIPath>
 			continue;
 		}
 
-		if(objToVisit && !shouldVisit(ai->nullkiller.get(), path.targetHero, objToVisit))
+		if(objToVisit && !shouldVisit(nullkiller, path.targetHero, objToVisit))
 		{
 #if NKAI_TRACE_LEVEL >= 2
 			logAi->trace("Ignore path. Hero %s should not visit obj %s", path.targetHero->getNameTranslated(), objToVisit->getObjectName());
@@ -83,7 +83,7 @@ Goals::TGoalVec CaptureObjectsBehavior::getVisitGoals(const std::vector<AIPath>
 		auto hero = path.targetHero;
 		auto danger = path.getTotalDanger();
 
-		if(ai->nullkiller->heroManager->getHeroRole(hero) == HeroRole::SCOUT
+		if(nullkiller->heroManager->getHeroRole(hero) == HeroRole::SCOUT
 			&& (path.getTotalDanger() == 0 || path.turn() > 0)
 			&& path.exchangeCount > 1)
 		{
@@ -135,7 +135,7 @@ Goals::TGoalVec CaptureObjectsBehavior::getVisitGoals(const std::vector<AIPath>
 
 			sharedPtr.reset(newWay);
 
-			auto heroRole = ai->nullkiller->heroManager->getHeroRole(path.targetHero);
+			auto heroRole = nullkiller->heroManager->getHeroRole(path.targetHero);
 
 			auto & closestWay = closestWaysByRole[heroRole];
 
@@ -144,7 +144,7 @@ Goals::TGoalVec CaptureObjectsBehavior::getVisitGoals(const std::vector<AIPath>
 				closestWay = &path;
 			}
 
-			if(!ai->nullkiller->arePathHeroesLocked(path))
+			if(!nullkiller->arePathHeroesLocked(path))
 			{
 				waysToVisitObj.push_back(newWay);
 				tasks[tasks.size() - 1] = sharedPtr;
@@ -154,7 +154,7 @@ Goals::TGoalVec CaptureObjectsBehavior::getVisitGoals(const std::vector<AIPath>
 
 	for(auto way : waysToVisitObj)
 	{
-		auto heroRole = ai->nullkiller->heroManager->getHeroRole(way->getPath().targetHero);
+		auto heroRole = nullkiller->heroManager->getHeroRole(way->getPath().targetHero);
 		auto closestWay = closestWaysByRole[heroRole];
 
 		if(closestWay)
@@ -170,8 +170,11 @@ Goals::TGoalVec CaptureObjectsBehavior::getVisitGoals(const std::vector<AIPath>
 Goals::TGoalVec CaptureObjectsBehavior::decompose() const
 {
 	Goals::TGoalVec tasks;
+	std::mutex sync;
 
-	auto captureObjects = [&](const std::vector<const CGObjectInstance*> & objs) -> void
+	Nullkiller * nullkiller = ai->nullkiller.get();
+
+	auto captureObjects = [&](const std::vector<const CGObjectInstance *> & objs) -> void
 	{
 		if(objs.empty())
 		{
@@ -180,36 +183,48 @@ Goals::TGoalVec CaptureObjectsBehavior::decompose() const
 
 		logAi->debug("Scanning objects, count %d", objs.size());
 
-		std::vector<AIPath> paths;
+		tbb::parallel_for(
+			tbb::blocked_range<size_t>(0, objs.size()),
+			[this, &objs, &sync, &tasks, nullkiller](const tbb::blocked_range<size_t> & r)
+			{
+				std::vector<AIPath> paths;
+				Goals::TGoalVec tasksLocal;
+
+				for(auto i = r.begin(); i != r.end(); i++)
+				{
+					auto objToVisit = objs[i];
+
+					if(!objectMatchesFilter(objToVisit))
+						continue;
 
-		for(auto objToVisit : objs)
-		{
-			if(!objectMatchesFilter(objToVisit))
-				continue;
-	
 #if NKAI_TRACE_LEVEL >= 1
-			logAi->trace("Checking object %s, %s", objToVisit->getObjectName(), objToVisit->visitablePos().toString());
+					logAi->trace("Checking object %s, %s", objToVisit->getObjectName(), objToVisit->visitablePos().toString());
 #endif
 
-			const int3 pos = objToVisit->visitablePos();
-			bool useObjectGraph = ai->nullkiller->settings->isObjectGraphAllowed()
-				&& ai->nullkiller->getScanDepth() != ScanDepth::SMALL;
+					const int3 pos = objToVisit->visitablePos();
+					bool useObjectGraph = nullkiller->settings->isObjectGraphAllowed()
+						&& nullkiller->getScanDepth() != ScanDepth::SMALL;
 
-			ai->nullkiller->pathfinder->calculatePathInfo(paths, pos, useObjectGraph);
+					nullkiller->pathfinder->calculatePathInfo(paths, pos, useObjectGraph);
+
+					std::vector<std::shared_ptr<ExecuteHeroChain>> waysToVisitObj;
+					std::shared_ptr<ExecuteHeroChain> closestWay;
 
-			std::vector<std::shared_ptr<ExecuteHeroChain>> waysToVisitObj;
-			std::shared_ptr<ExecuteHeroChain> closestWay;
-					
 #if NKAI_TRACE_LEVEL >= 1
-			logAi->trace("Found %d paths", paths.size());
+					logAi->trace("Found %d paths", paths.size());
 #endif
-			vstd::concatenate(tasks, getVisitGoals(paths, objToVisit));
-		}
+					vstd::concatenate(tasksLocal, getVisitGoals(paths, nullkiller, objToVisit));
+				}
 
-		vstd::erase_if(tasks, [](TSubgoal task) -> bool
-		{
-			return task->invalid();
-		});
+				vstd::erase_if(tasksLocal, [](TSubgoal task) -> bool
+				{
+					return task->invalid();
+				});
+
+				std::lock_guard<std::mutex> lock(sync);
+
+				vstd::concatenate(tasks, tasksLocal);
+			});
 	};
 
 	if(specificObjects)
@@ -220,15 +235,15 @@ Goals::TGoalVec CaptureObjectsBehavior::decompose() const
 	{
 		captureObjects(
 			std::vector<const CGObjectInstance *>(
-				ai->nullkiller->memory->visitableObjs.begin(),
-				ai->nullkiller->memory->visitableObjs.end()));
+				nullkiller->memory->visitableObjs.begin(),
+				nullkiller->memory->visitableObjs.end()));
 	}
 	else
 	{
-		captureObjects(ai->nullkiller->objectClusterizer->getNearbyObjects());
+		captureObjects(nullkiller->objectClusterizer->getNearbyObjects());
 
-		if(tasks.empty() || ai->nullkiller->getScanDepth() != ScanDepth::SMALL)
-			captureObjects(ai->nullkiller->objectClusterizer->getFarObjects());
+		if(tasks.empty() || nullkiller->getScanDepth() != ScanDepth::SMALL)
+			captureObjects(nullkiller->objectClusterizer->getFarObjects());
 	}
 
 	return tasks;

+ 1 - 1
AI/Nullkiller/Behaviors/CaptureObjectsBehavior.h

@@ -67,7 +67,7 @@ namespace Goals
 
 		bool operator==(const CaptureObjectsBehavior & other) const override;
 
-		static Goals::TGoalVec getVisitGoals(const std::vector<AIPath> & paths, const CGObjectInstance * objToVisit = nullptr);
+		static Goals::TGoalVec getVisitGoals(const std::vector<AIPath> & paths, Nullkiller * nullkiller, const CGObjectInstance * objToVisit = nullptr);
 
 	private:
 		bool objectMatchesFilter(const CGObjectInstance * obj) const;

+ 1 - 1
AI/Nullkiller/Behaviors/ClusterBehavior.cpp

@@ -100,7 +100,7 @@ Goals::TGoalVec ClusterBehavior::decomposeCluster(std::shared_ptr<ObjectCluster>
 	logAi->trace("Decompose unlock paths");
 #endif
 
-	auto unlockTasks = CaptureObjectsBehavior::getVisitGoals(blockerPaths);
+	auto unlockTasks = CaptureObjectsBehavior::getVisitGoals(blockerPaths, ai->nullkiller.get());
 
 	for(int i = 0; i < paths.size(); i++)
 	{

+ 1 - 1
AI/Nullkiller/Behaviors/DefenceBehavior.cpp

@@ -88,7 +88,7 @@ void handleCounterAttack(
 		&& (threat.danger == maximumDanger.danger || threat.turn < maximumDanger.turn))
 	{
 		auto heroCapturingPaths = ai->nullkiller->pathfinder->getPathInfo(threat.hero->visitablePos());
-		auto goals = CaptureObjectsBehavior::getVisitGoals(heroCapturingPaths, threat.hero.get());
+		auto goals = CaptureObjectsBehavior::getVisitGoals(heroCapturingPaths, ai->nullkiller.get(), threat.hero.get());
 
 		for(int i = 0; i < heroCapturingPaths.size(); i++)
 		{

+ 1 - 1
AI/Nullkiller/Behaviors/GatherArmyBehavior.cpp

@@ -232,7 +232,7 @@ Goals::TGoalVec GatherArmyBehavior::upgradeArmy(const CGTownInstance * upgrader)
 #endif
 	
 	auto paths = ai->nullkiller->pathfinder->getPathInfo(pos, ai->nullkiller->settings->isObjectGraphAllowed());
-	auto goals = CaptureObjectsBehavior::getVisitGoals(paths);
+	auto goals = CaptureObjectsBehavior::getVisitGoals(paths, ai->nullkiller.get());
 
 	std::vector<std::shared_ptr<ExecuteHeroChain>> waysToVisitObj;
 

+ 19 - 20
AI/Nullkiller/Engine/Nullkiller.cpp

@@ -188,7 +188,10 @@ void Nullkiller::updateAiState(int pass, bool fast)
 
 		if(settings->isObjectGraphAllowed())
 		{
-			pathfinder->updateGraphs(activeHeroes);
+			pathfinder->updateGraphs(
+				activeHeroes,
+				scanDepth == ScanDepth::SMALL ? 255 : 10,
+				scanDepth == ScanDepth::ALL_FULL ? 255 : 3);
 		}
 
 		boost::this_thread::interruption_point();
@@ -303,8 +306,7 @@ void Nullkiller::makeTurn()
 
 		// TODO: better to check turn distance here instead of priority
 		if((heroRole != HeroRole::MAIN || bestTask->priority < SMALL_SCAN_MIN_PRIORITY)
-			&& scanDepth == ScanDepth::MAIN_FULL
-			&& !settings->isObjectGraphAllowed())
+			&& scanDepth == ScanDepth::MAIN_FULL)
 		{
 			useHeroChain = false;
 			scanDepth = ScanDepth::SMALL;
@@ -317,25 +319,22 @@ void Nullkiller::makeTurn()
 
 		if(bestTask->priority < MIN_PRIORITY)
 		{
-			if(!settings->isObjectGraphAllowed())
+			auto heroes = cb->getHeroesInfo();
+			auto hasMp = vstd::contains_if(heroes, [](const CGHeroInstance * h) -> bool
+				{
+					return h->movementPointsRemaining() > 100;
+				});
+
+			if(hasMp && scanDepth != ScanDepth::ALL_FULL)
 			{
-				auto heroes = cb->getHeroesInfo();
-				auto hasMp = vstd::contains_if(heroes, [](const CGHeroInstance * h) -> bool
-					{
-						return h->movementPointsRemaining() > 100;
-					});
+				logAi->trace(
+					"Goal %s has too low priority %f so increasing scan depth to full.",
+					taskDescription,
+					bestTask->priority);
 
-				if(hasMp && scanDepth != ScanDepth::ALL_FULL)
-				{
-					logAi->trace(
-						"Goal %s has too low priority %f so increasing scan depth to full.",
-						taskDescription,
-						bestTask->priority);
-
-					scanDepth = ScanDepth::ALL_FULL;
-					useHeroChain = false;
-					continue;
-				}
+				scanDepth = ScanDepth::ALL_FULL;
+				useHeroChain = false;
+				continue;
 			}
 
 			logAi->trace("Goal %s has too low priority. It is not worth doing it. Ending turn.", taskDescription);

+ 2 - 2
AI/Nullkiller/Goals/CompleteQuest.cpp

@@ -112,7 +112,7 @@ TGoalVec CompleteQuest::tryCompleteQuest() const
 		return !q.quest->checkQuest(path.targetHero);
 	});
 	
-	return CaptureObjectsBehavior::getVisitGoals(paths, q.obj);
+	return CaptureObjectsBehavior::getVisitGoals(paths, ai->nullkiller.get(), q.obj);
 }
 
 TGoalVec CompleteQuest::missionArt() const
@@ -154,7 +154,7 @@ TGoalVec CompleteQuest::missionArmy() const
 		return !CQuest::checkMissionArmy(q.quest, path.heroArmy);
 	});
 
-	return CaptureObjectsBehavior::getVisitGoals(paths, q.obj);
+	return CaptureObjectsBehavior::getVisitGoals(paths, ai->nullkiller.get(), q.obj);
 }
 
 TGoalVec CompleteQuest::missionIncreasePrimaryStat() const

+ 11 - 3
AI/Nullkiller/Pathfinding/AIPathfinder.cpp

@@ -138,7 +138,10 @@ void AIPathfinder::updatePaths(const std::map<const CGHeroInstance *, HeroRole>
 	logAi->trace("Recalculated paths in %ld", timeElapsed(start));
 }
 
-void AIPathfinder::updateGraphs(const std::map<const CGHeroInstance *, HeroRole> & heroes)
+void AIPathfinder::updateGraphs(
+	const std::map<const CGHeroInstance *, HeroRole> & heroes,
+	uint8_t mainScanDepth,
+	uint8_t scoutScanDepth)
 {
 	auto start = std::chrono::high_resolution_clock::now();
 	std::vector<const CGHeroInstance *> heroesVector;
@@ -154,10 +157,15 @@ void AIPathfinder::updateGraphs(const std::map<const CGHeroInstance *, HeroRole>
 		}
 	}
 
-	tbb::parallel_for(tbb::blocked_range<size_t>(0, heroesVector.size()), [this, &heroesVector](const tbb::blocked_range<size_t> & r)
+	tbb::parallel_for(tbb::blocked_range<size_t>(0, heroesVector.size()), [this, &heroesVector, &heroes, mainScanDepth, scoutScanDepth](const tbb::blocked_range<size_t> & r)
 		{
 			for(auto i = r.begin(); i != r.end(); i++)
-				heroGraphs.at(heroesVector[i]->id)->calculatePaths(heroesVector[i], ai);
+			{
+				auto role = heroes.at(heroesVector[i]);
+				auto scanLimit = role == HeroRole::MAIN ? mainScanDepth : scoutScanDepth;
+
+				heroGraphs.at(heroesVector[i]->id)->calculatePaths(heroesVector[i], ai, scanLimit);
+			}
 		});
 
 	if(NKAI_GRAPH_TRACE_LEVEL >= 1)

+ 1 - 1
AI/Nullkiller/Pathfinding/AIPathfinder.h

@@ -47,7 +47,7 @@ public:
 	void calculatePathInfo(std::vector<AIPath> & paths, const int3 & tile, bool includeGraph = false) const;
 	bool isTileAccessible(const HeroPtr & hero, const int3 & tile) const;
 	void updatePaths(const std::map<const CGHeroInstance *, HeroRole> & heroes, PathfinderSettings pathfinderSettings);
-	void updateGraphs(const std::map<const CGHeroInstance *, HeroRole> & heroes);
+	void updateGraphs(const std::map<const CGHeroInstance *, HeroRole> & heroes, uint8_t mainScanDepth, uint8_t scoutScanDepth);
 	void calculateQuickPathsWithBlocker(std::vector<AIPath> & result, const std::vector<const CGHeroInstance *> & heroes, const int3 & tile);
 	void init();
 

+ 7 - 2
AI/Nullkiller/Pathfinding/ObjectGraph.cpp

@@ -562,7 +562,7 @@ std::shared_ptr<SpecialAction> getCompositeAction(
 	return std::make_shared<CompositeAction>(actionsArray);
 }
 
-void GraphPaths::calculatePaths(const CGHeroInstance * targetHero, const Nullkiller * ai)
+void GraphPaths::calculatePaths(const CGHeroInstance * targetHero, const Nullkiller * ai, uint8_t scanDepth)
 {
 	graph.copyFrom(*ai->baseGraph);
 	graph.connectHeroes(ai);
@@ -611,7 +611,7 @@ void GraphPaths::calculatePaths(const CGHeroInstance * targetHero, const Nullkil
 
 		node.isInQueue = false;
 
-		graph.iterateConnections(pos.coord, [this, ai, &pos, &node, &transitionAction, &pq](int3 target, const ObjectLink & o)
+		graph.iterateConnections(pos.coord, [this, ai, &pos, &node, &transitionAction, &pq, scanDepth](int3 target, const ObjectLink & o)
 			{
 				auto compositeAction = getCompositeAction(ai, o.specialAction, transitionAction);
 				auto targetNodeType = o.danger || compositeAction ? GrapthPathNodeType::BATTLE : pos.nodeType;
@@ -620,6 +620,11 @@ void GraphPaths::calculatePaths(const CGHeroInstance * targetHero, const Nullkil
 
 				if(targetNode.tryUpdate(pos, node, o))
 				{
+					if(targetNode.cost > scanDepth)
+					{
+						return;
+					}
+
 					targetNode.specialAction = compositeAction;
 
 					auto targetGraphNode = graph.getNode(target);

+ 1 - 1
AI/Nullkiller/Pathfinding/ObjectGraph.h

@@ -187,7 +187,7 @@ class GraphPaths
 
 public:
 	GraphPaths();
-	void calculatePaths(const CGHeroInstance * targetHero, const Nullkiller * ai);
+	void calculatePaths(const CGHeroInstance * targetHero, const Nullkiller * ai, uint8_t scanDepth);
 	void addChainInfo(std::vector<AIPath> & paths, int3 tile, const CGHeroInstance * hero, const Nullkiller * ai) const;
 	void quickAddChainInfoWithBlocker(std::vector<AIPath> & paths, int3 tile, const CGHeroInstance * hero, const Nullkiller * ai) const;
 	void dumpToLog() const;