Browse Source

Try to implement lazy evaluation for reachability map

Ivan Savenko 10 months ago
parent
commit
797b62fd46
2 changed files with 65 additions and 27 deletions
  1. 46 24
      AI/BattleAI/BattleExchangeVariant.cpp
  2. 19 3
      AI/BattleAI/BattleExchangeVariant.h

+ 46 - 24
AI/BattleAI/BattleExchangeVariant.cpp

@@ -519,14 +519,14 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
 	
 	for(auto hex : hexes)
 	{
-		vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap.at(hex.toInt()) : getOneTurnReachableUnits(turn, hex));
+		vstd::concatenate(allReachableUnits, getOneTurnReachableUnits(turn, hex));
 	}
 
 	if(!ap.attack.attacker->isTurret())
 	{
 		for(auto hex : ap.attack.attacker->getHexes())
 		{
-			auto unitsReachingAttacker = turn == 0 ? reachabilityMap.at(hex.toInt()) : getOneTurnReachableUnits(turn, hex);
+			auto unitsReachingAttacker = getOneTurnReachableUnits(turn, hex);
 			for(auto unit : unitsReachingAttacker)
 			{
 				if(unit->unitSide() != ap.attack.attacker->unitSide())
@@ -800,7 +800,9 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
 							if(!u->getPosition().isValid())
 								return false; // e.g. tower shooters
 
-							return vstd::contains_if(reachabilityMap.at(u->getPosition().toInt()), [&attacker](const battle::Unit * other) -> bool
+							const auto & reachableUnits = getOneTurnReachableUnits(0, u->getPosition());
+
+							return vstd::contains_if(reachableUnits, [&attacker](const battle::Unit * other) -> bool
 								{
 									return attacker->unitId() == other->unitId();
 								});
@@ -887,7 +889,7 @@ bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap)
 {
 	for(auto pos : ap.attack.attacker->getSurroundingHexes())
 	{
-		for(auto u : reachabilityMap.at(pos.toInt()))
+		for(auto u : getOneTurnReachableUnits(0, pos))
 		{
 			if(u->unitSide() != ap.attack.attacker->unitSide())
 			{
@@ -899,35 +901,48 @@ bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap)
 	return false;
 }
 
-void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb)
+void ReachabilityMapCache::update(const std::vector<battle::Units> & turnOrder, std::shared_ptr<HypotheticBattle> hb)
 {
-	const int TURN_DEPTH = 2;
-
-	turnOrder.clear();
-
-	hb->battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH);
-
 	for(auto turn : turnOrder)
 	{
 		for(auto u : turn)
 		{
-			if(!vstd::contains(reachabilityCache, u->unitId()))
+			if(!vstd::contains(unitReachabilityMap, u->unitId()))
 			{
-				reachabilityCache[u->unitId()] = hb->getReachability(u);
+				unitReachabilityMap[u->unitId()] = hb->getReachability(u);
 			}
 		}
 	}
 
-	tbb::parallel_for(tbb::blocked_range<size_t>(0, reachabilityMap.size()), [&](const tbb::blocked_range<size_t> & r)
+	hexReachabilityPerTurn.clear();
+}
+
+void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb)
+{
+	const int TURN_DEPTH = 2;
+
+	turnOrder.clear();
+
+	hb->battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH);
+	reachabilityMap.update(turnOrder, hb);
+}
+
+const battle::Units & ReachabilityMapCache::getOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, BattleHex hex)
+{
+	auto & turnData = hexReachabilityPerTurn[turn];
+
+	if (!turnData.isValid[hex.toInt()])
 	{
-		for(auto i = r.begin(); i != r.end(); i++)
-			reachabilityMap[i] = getOneTurnReachableUnits(0, BattleHex(i));
-	});
+		turnData.hexes[hex.toInt()] = computeOneTurnReachableUnits(cb, env, turnOrder, turn, hex);
+		turnData.isValid.set(hex.toInt());
+	}
+
+	return turnData.hexes[hex.toInt()];
 }
 
-std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const
+battle::Units ReachabilityMapCache::computeOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, BattleHex hex)
 {
-	std::vector<const battle::Unit *> result;
+	battle::Units result;
 
 	for(int i = 0; i < turnOrder.size(); i++, turn++)
 	{
@@ -949,10 +964,10 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUn
 			auto unitSpeed = unit->getMovementRange(turn);
 			auto radius = unitSpeed * (turn + 1);
 
-			auto reachabilityIter = reachabilityCache.find(unit->unitId());
-			assert(reachabilityIter != reachabilityCache.end()); // missing updateReachabilityMap call?
+			auto reachabilityIter = unitReachabilityMap.find(unit->unitId());
+			assert(reachabilityIter != unitReachabilityMap.end()); // missing updateReachabilityMap call?
 
-			ReachabilityInfo unitReachability = reachabilityIter != reachabilityCache.end() ? reachabilityIter->second : turnBattle.getReachability(unit);
+			ReachabilityInfo unitReachability = reachabilityIter != unitReachabilityMap.end() ? reachabilityIter->second : turnBattle.getReachability(unit);
 
 			bool reachable = unitReachability.distances.at(hex.toInt()) <= radius;
 
@@ -981,6 +996,11 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUn
 	return result;
 }
 
+const battle::Units & BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const
+{
+	return reachabilityMap.getOneTurnReachableUnits(cb, env, turnOrder, turn, hex);
+}
+
 // avoid blocking path for stronger stack by weaker stack
 bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * activeUnit, BattleHex position)
 {
@@ -1032,9 +1052,11 @@ bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(HypotheticBattle & hb
 					}
 				}
 
-				if(!reachable && std::count(reachabilityMap[hex.toInt()].begin(), reachabilityMap[hex.toInt()].end(), unit) > 1)
+				if(!reachable)
 				{
-					blockingScore += ratio * (enemyUnit ? BLOCKING_OWN_ATTACK_PENALTY : BLOCKING_OWN_MOVE_PENALTY);
+					auto reachableUnits = getOneTurnReachableUnits(0, hex);
+					if (std::count(reachableUnits.begin(), reachableUnits.end(), unit) > 1)
+						blockingScore += ratio * (enemyUnit ? BLOCKING_OWN_ATTACK_PENALTY : BLOCKING_OWN_MOVE_PENALTY);
 				}
 			}
 		}

+ 19 - 3
AI/BattleAI/BattleExchangeVariant.h

@@ -123,13 +123,29 @@ struct ReachabilityData
 	std::set<uint32_t> enemyUnitsReachingAttacker;
 };
 
+class ReachabilityMapCache
+{
+	struct PerTurnData{
+		std::bitset<GameConstants::BFIELD_SIZE> isValid;
+		std::array<battle::Units, GameConstants::BFIELD_SIZE> hexes;
+	};
+
+	std::map<uint32_t, ReachabilityInfo> unitReachabilityMap; // unit ID -> reachability
+	std::map<uint32_t, PerTurnData> hexReachabilityPerTurn;
+
+	//const ReachabilityInfo & update();
+	battle::Units computeOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, BattleHex hex);
+public:
+	const battle::Units & getOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, BattleHex hex);
+	void update(const std::vector<battle::Units> & turnOrder, std::shared_ptr<HypotheticBattle> hb);
+};
+
 class BattleExchangeEvaluator
 {
 private:
 	std::shared_ptr<CBattleInfoCallback> cb;
 	std::shared_ptr<Environment> env;
-	std::map<uint32_t, ReachabilityInfo> reachabilityCache;
-	std::array<std::vector<const battle::Unit *>, GameConstants::BFIELD_SIZE> reachabilityMap;
+	mutable ReachabilityMapCache reachabilityMap;
 	std::vector<battle::Units> turnOrder;
 	float negativeEffectMultiplier;
 	int simulationTurnsCount;
@@ -169,7 +185,7 @@ public:
 		DamageCache & damageCache,
 		std::shared_ptr<HypotheticBattle> hb) const;
 
-	std::vector<const battle::Unit *> getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const;
+	const battle::Units & getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const;
 	void updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb);
 
 	ReachabilityData getExchangeUnits(