Procházet zdrojové kódy

Optimize computation of reachability map

Ivan Savenko před 9 měsíci
rodič
revize
2d5b5d94e7

+ 11 - 8
AI/BattleAI/BattleExchangeVariant.cpp

@@ -11,6 +11,7 @@
 #include "BattleExchangeVariant.h"
 #include "BattleEvaluator.h"
 #include "../../lib/CStack.h"
+#include "tbb/parallel_for.h"
 
 AttackerValue::AttackerValue()
 	: value(0),
@@ -518,14 +519,14 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
 	
 	for(auto hex : hexes)
 	{
-		vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap.at(hex) : getOneTurnReachableUnits(turn, hex));
+		vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap.at(hex.toInt()) : getOneTurnReachableUnits(turn, hex));
 	}
 
 	if(!ap.attack.attacker->isTurret())
 	{
 		for(auto hex : ap.attack.attacker->getHexes())
 		{
-			auto unitsReachingAttacker = turn == 0 ? reachabilityMap.at(hex) : getOneTurnReachableUnits(turn, hex);
+			auto unitsReachingAttacker = turn == 0 ? reachabilityMap.at(hex.toInt()) : getOneTurnReachableUnits(turn, hex);
 			for(auto unit : unitsReachingAttacker)
 			{
 				if(unit->unitSide() != ap.attack.attacker->unitSide())
@@ -799,7 +800,7 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
 							if(!u->getPosition().isValid())
 								return false; // e.g. tower shooters
 
-							return vstd::contains_if(reachabilityMap.at(u->getPosition()), [&attacker](const battle::Unit * other) -> bool
+							return vstd::contains_if(reachabilityMap.at(u->getPosition().toInt()), [&attacker](const battle::Unit * other) -> bool
 								{
 									return attacker->unitId() == other->unitId();
 								});
@@ -886,7 +887,7 @@ bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap)
 {
 	for(auto pos : ap.attack.attacker->getSurroundingHexes())
 	{
-		for(auto u : reachabilityMap[pos])
+		for(auto u : reachabilityMap.at(pos.toInt()))
 		{
 			if(u->unitSide() != ap.attack.attacker->unitSide())
 			{
@@ -916,10 +917,12 @@ void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBa
 			}
 		}
 	}
-	for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); ++hex)
+
+	tbb::parallel_for(tbb::blocked_range<size_t>(0, reachabilityMap.size()), [&](const tbb::blocked_range<size_t> & r)
 	{
-		reachabilityMap[hex] = getOneTurnReachableUnits(0, hex);
-	}
+		for(auto i = r.begin(); i != r.end(); i++)
+			reachabilityMap[i] = getOneTurnReachableUnits(0, BattleHex(i));
+	});
 }
 
 std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const
@@ -1029,7 +1032,7 @@ bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(HypotheticBattle & hb
 					}
 				}
 
-				if(!reachable && std::count(reachabilityMap[hex].begin(), reachabilityMap[hex].end(), unit) > 1)
+				if(!reachable && std::count(reachabilityMap[hex.toInt()].begin(), reachabilityMap[hex.toInt()].end(), unit) > 1)
 				{
 					blockingScore += ratio * (enemyUnit ? BLOCKING_OWN_ATTACK_PENALTY : BLOCKING_OWN_MOVE_PENALTY);
 				}

+ 1 - 1
AI/BattleAI/BattleExchangeVariant.h

@@ -129,7 +129,7 @@ private:
 	std::shared_ptr<CBattleInfoCallback> cb;
 	std::shared_ptr<Environment> env;
 	std::map<uint32_t, ReachabilityInfo> reachabilityCache;
-	std::map<BattleHex, std::vector<const battle::Unit *>> reachabilityMap;
+	std::array<std::vector<const battle::Unit *>, GameConstants::BFIELD_SIZE> reachabilityMap;
 	std::vector<battle::Units> turnOrder;
 	float negativeEffectMultiplier;
 	int simulationTurnsCount;