Explorar el Código

Fixed potential thread races in Battle AI

Ivan Savenko hace 1 año
padre
commit
b7efa6c8cc

+ 2 - 5
AI/BattleAI/BattleEvaluator.cpp

@@ -602,10 +602,10 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
 					ps.value = scoreEvaluator.evaluateExchange(*cachedAttack, 0, *targets, innerCache, state);
 				}
 
-				for(auto unit : allUnits)
+				for(const auto & unit : allUnits)
 				{
 					auto newHealth = unit->getAvailableHealth();
-					auto oldHealth = healthOfStack[unit->unitId()];
+					auto oldHealth = vstd::find_or(healthOfStack, unit->unitId(), 0); // old health value may not exist for newly summoned units
 
 					if(oldHealth != newHealth)
 					{
@@ -732,6 +732,3 @@ void BattleEvaluator::print(const std::string & text) const
 {
 	logAi->trace("%s Battle AI[%p]: %s", playerID.toString(), this, text);
 }
-
-
-

+ 11 - 14
AI/BattleAI/BattleExchangeVariant.cpp

@@ -390,7 +390,7 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
 	const AttackPossibility & ap,
 	uint8_t turn,
 	PotentialTargets & targets,
-	std::shared_ptr<HypotheticBattle> hb)
+	std::shared_ptr<HypotheticBattle> hb) const
 {
 	ReachabilityData result;
 
@@ -402,7 +402,7 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
 
 	for(auto hex : hexes)
 	{
-		vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap[hex] : getOneTurnReachableUnits(turn, hex));
+		vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap.at(hex) : getOneTurnReachableUnits(turn, hex));
 	}
 
 	vstd::removeDuplicates(allReachableUnits);
@@ -481,7 +481,7 @@ float BattleExchangeEvaluator::evaluateExchange(
 	uint8_t turn,
 	PotentialTargets & targets,
 	DamageCache & damageCache,
-	std::shared_ptr<HypotheticBattle> hb)
+	std::shared_ptr<HypotheticBattle> hb) const
 {
 	BattleScore score = calculateExchange(ap, turn, targets, damageCache, hb);
 
@@ -502,7 +502,7 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
 	uint8_t turn,
 	PotentialTargets & targets,
 	DamageCache & damageCache,
-	std::shared_ptr<HypotheticBattle> hb)
+	std::shared_ptr<HypotheticBattle> hb) const
 {
 #if BATTLE_TRACE_LEVEL>=1
 	logAi->trace("Battle exchange at %d", ap.attack.shooting ? ap.dest.hex : ap.from.hex);
@@ -613,7 +613,7 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
 			}
 			else
 			{
-				auto reachable = exchangeBattle->battleGetUnitsIf([&](const battle::Unit * u) -> bool
+				auto reachable = exchangeBattle->battleGetUnitsIf([this, &exchangeBattle, &attacker](const battle::Unit * u) -> bool
 					{
 						if(u->unitSide() == attacker->unitSide())
 							return false;
@@ -621,7 +621,7 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
 						if(!exchangeBattle->getForUpdate(u->unitId())->alive())
 							return false;
 
-						return vstd::contains_if(reachabilityMap[u->getPosition()], [&](const battle::Unit * other) -> bool
+						return vstd::contains_if(reachabilityMap.at(u->getPosition()), [&attacker](const battle::Unit * other) -> bool
 							{
 								return attacker->unitId() == other->unitId();
 							});
@@ -732,7 +732,7 @@ void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBa
 	}
 }
 
-std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex)
+std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const
 {
 	std::vector<const battle::Unit *> result;
 
@@ -756,13 +756,10 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUn
 			auto unitSpeed = unit->getMovementRange(turn);
 			auto radius = unitSpeed * (turn + 1);
 
-			ReachabilityInfo unitReachability = vstd::getOrCompute(
-				reachabilityCache,
-				unit->unitId(),
-				[&](ReachabilityInfo & data)
-				{
-					data = turnBattle.getReachability(unit);
-				});
+			auto reachabilityIter = reachabilityCache.find(unit->unitId());
+			assert(reachabilityIter != reachabilityCache.end()); // missing updateReachabilityMap call?
+
+			ReachabilityInfo unitReachability = reachabilityIter != reachabilityCache.end() ? reachabilityIter->second : turnBattle.getReachability(unit);
 
 			bool reachable = unitReachability.distances[hex] <= radius;
 

+ 4 - 4
AI/BattleAI/BattleExchangeVariant.h

@@ -139,7 +139,7 @@ private:
 		uint8_t turn,
 		PotentialTargets & targets,
 		DamageCache & damageCache,
-		std::shared_ptr<HypotheticBattle> hb);
+		std::shared_ptr<HypotheticBattle> hb) const;
 
 	bool canBeHitThisTurn(const AttackPossibility & ap);
 
@@ -162,16 +162,16 @@ public:
 		uint8_t turn,
 		PotentialTargets & targets,
 		DamageCache & damageCache,
-		std::shared_ptr<HypotheticBattle> hb);
+		std::shared_ptr<HypotheticBattle> hb) const;
 
-	std::vector<const battle::Unit *> getOneTurnReachableUnits(uint8_t turn, BattleHex hex);
+	std::vector<const battle::Unit *> getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const;
 	void updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb);
 
 	ReachabilityData getExchangeUnits(
 		const AttackPossibility & ap,
 		uint8_t turn,
 		PotentialTargets & targets,
-		std::shared_ptr<HypotheticBattle> hb);
+		std::shared_ptr<HypotheticBattle> hb) const;
 
 	bool checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * unit, BattleHex position);
 

+ 9 - 14
Global.h

@@ -348,6 +348,15 @@ namespace vstd
 		return std::find(c.begin(),c.end(),i);
 	}
 
+	// returns existing value from map, or default value if key does not exists
+	template <typename Map>
+	const typename Map::mapped_type & find_or(const Map& m, const typename Map::key_type& key, const typename Map::mapped_type& defaultValue) {
+		auto it = m.find(key);
+		if (it == m.end())
+			return defaultValue;
+		return it->second;
+	}
+
 	//returns first key that maps to given value if present, returns success via found if provided
 	template <typename Key, typename T>
 	Key findKey(const std::map<Key, T> & map, const T & value, bool * found = nullptr)
@@ -684,20 +693,6 @@ namespace vstd
 		return false;
 	}
 
-	template<class M, class Key, class F>
-	typename M::mapped_type & getOrCompute(M & m, const Key & k, F f)
-	{
-		typedef typename M::mapped_type V;
-
-		std::pair<typename M::iterator, bool> r = m.insert(typename M::value_type(k, V()));
-		V & v = r.first->second;
-
-		if(r.second)
-			f(v);
-
-		return v;
-	}
-
 	//c++20 feature
 	template<typename Arithmetic, typename Floating>
 	Arithmetic lerp(const Arithmetic & a, const Arithmetic & b, const Floating & f)