Browse Source

BattleAI: bigger reachability map

Andrii Danylchenko 2 years ago
parent
commit
870fbd50e3

+ 2 - 6
AI/BattleAI/AttackPossibility.cpp

@@ -157,13 +157,9 @@ float AttackPossibility::calculateDamageReduce(
 	auto enemyDamageBeforeAttack = damageCache.getOriginalDamage(defender, attackerUnitForMeasurement, state);
 	auto enemiesKilled = damageDealt / maxHealth + (damageDealt % maxHealth >= defender->getFirstHPleft() ? 1 : 0);
 	auto damagePerEnemy = enemyDamageBeforeAttack / (double)defender->getCount();
-	
-	// lets use cached maxHealth here instead of getAvailableHealth
-	auto firstUnitHpLeft = (availableHealth - damageDealt) % maxHealth;
-	auto firstUnitHealthRatio = firstUnitHpLeft == 0 ? 1 : static_cast<float>(firstUnitHpLeft) / maxHealth;
-	auto firstUnitKillValue = (1 - firstUnitHealthRatio) * (1 - firstUnitHealthRatio);
+	auto lastUnitKillValue = (damageDealt % maxHealth) / (double)maxHealth;;
 
-	return damagePerEnemy * (enemiesKilled + firstUnitKillValue * HEALTH_BOUNTY);
+	return damagePerEnemy * (enemiesKilled + lastUnitKillValue * HEALTH_BOUNTY);
 }
 
 int64_t AttackPossibility::evaluateBlockedShootersDmg(

+ 2 - 2
AI/BattleAI/BattleEvaluator.cpp

@@ -149,7 +149,7 @@ BattleAction BattleEvaluator::selectStackAction(const CStack * stack)
 				bestAttack.attack.attacker->speed(0, true),
 				bestAttack.defenderDamageReduce,
 				bestAttack.attackerDamageReduce,
-				bestAttack.attackValue()
+				score
 			);
 
 			if (moveTarget.scorePerTurn <= score)
@@ -580,7 +580,7 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
 				}
 				else
 				{
-					ps.value = scoreEvaluator.evaluateExchange(*cachedAttack, *targets, innerCache, state);
+					ps.value = scoreEvaluator.evaluateExchange(*cachedAttack, 0, *targets, innerCache, state);
 				}
 
 				for(auto unit : allUnits)

+ 104 - 39
AI/BattleAI/BattleExchangeVariant.cpp

@@ -116,6 +116,10 @@ float BattleExchangeVariant::trackAttack(
 		}
 	}
 
+#if BATTLE_TRACE_LEVEL >= 1
+	logAi->trace("ap shooters blocking: %lld", ap.shootersBlockedDmg);
+#endif
+
 	attackValue += ap.shootersBlockedDmg;
 	dpsScore.enemyDamageReduce += ap.shootersBlockedDmg;
 	attacker->afterAttack(ap.attack.shooting, false);
@@ -233,13 +237,17 @@ EvaluationResult BattleExchangeEvaluator::findBestTarget(
 
 		for(auto & ap : targets.possibleAttacks)
 		{
-			float score = evaluateExchange(ap, targets, damageCache, hbWaited);
+			float score = evaluateExchange(ap, 0, targets, damageCache, hbWaited);
 
 			if(score > result.score)
 			{
 				result.score = score;
 				result.bestAttack = ap;
 				result.wait = true;
+
+#if BATTLE_TRACE_LEVEL >= 1
+				logAi->trace("New high score %2f", result.score);
+#endif
 			}
 		}
 	}
@@ -258,13 +266,17 @@ EvaluationResult BattleExchangeEvaluator::findBestTarget(
 
 	for(auto & ap : targets.possibleAttacks)
 	{
-		float score = evaluateExchange(ap, targets, damageCache, hb);
+		float score = evaluateExchange(ap, 0, targets, damageCache, hb);
 
-		if(score >= result.score)
+		if(score > result.score || score == result.score && result.wait)
 		{
 			result.score = score;
 			result.bestAttack = ap;
 			result.wait = false;
+
+#if BATTLE_TRACE_LEVEL >= 1
+			logAi->trace("New high score %2f", result.score);
+#endif
 		}
 	}
 
@@ -312,7 +324,10 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
 		auto hexes = closestStack->getSurroundingHexes();
 		auto enemySpeed = closestStack->speed();
 		auto speedRatio = speed / static_cast<float>(enemySpeed);
-		auto penalty = speedRatio > 1 ? 1 : speedRatio;
+		auto multiplier = speedRatio > 1 ? 1 : speedRatio;
+
+		if(enemy->canShoot())
+			multiplier *= 1.5f;
 
 		for(auto hex : hexes)
 		{
@@ -323,7 +338,7 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
 			attack.shootersBlockedDmg = 0; // we do not want to count on it, it is not for sure
 
 			auto score = calculateExchange(attack, turnsToRich, targets, damageCache, hb);
-			auto scorePerTurn = BattleScore(score.enemyDamageReduce * std::sqrt(penalty / turnsToRich), score.ourDamageReduce);
+			auto scorePerTurn = BattleScore(score.enemyDamageReduce * std::sqrt(multiplier / turnsToRich), score.ourDamageReduce);
 
 			if(result.scorePerTurn < scoreValue(scorePerTurn))
 			{
@@ -371,12 +386,13 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getAdjacentUnits(cons
 
 ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
 	const AttackPossibility & ap,
+	uint8_t turn,
 	PotentialTargets & targets,
 	std::shared_ptr<HypotheticBattle> hb)
 {
 	ReachabilityData result;
 
-	auto hexes = ap.attack.defender->getHexes();
+	auto hexes = ap.attack.defender->getSurroundingHexes();
 
 	if(!ap.attack.shooting) hexes.push_back(ap.from);
 
@@ -384,7 +400,7 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
 
 	for(auto hex : hexes)
 	{
-		vstd::concatenate(allReachableUnits, reachabilityMap[hex]);
+		vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap[hex] : getOneTurnReachableUnits(turn, hex));
 	}
 
 	vstd::removeDuplicates(allReachableUnits);
@@ -460,17 +476,33 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
 
 float BattleExchangeEvaluator::evaluateExchange(
 	const AttackPossibility & ap,
+	uint8_t turn,
 	PotentialTargets & targets,
 	DamageCache & damageCache,
 	std::shared_ptr<HypotheticBattle> hb)
 {
-	BattleScore score = calculateExchange(ap, targets, damageCache, hb);
+	if(ap.from.hex == 127)
+	{
+		logAi->trace("x");
+	}
+
+	BattleScore score = calculateExchange(ap, turn, targets, damageCache, hb);
+
+#if BATTLE_TRACE_LEVEL >= 1
+	logAi->trace(
+		"calculateExchange score +%2f -%2fx%2f = %2f",
+		score.enemyDamageReduce,
+		score.ourDamageReduce,
+		getNegativeEffectMultiplier(),
+		scoreValue(score));
+#endif
 
 	return scoreValue(score);
 }
 
 BattleScore BattleExchangeEvaluator::calculateExchange(
 	const AttackPossibility & ap,
+	uint8_t turn,
 	PotentialTargets & targets,
 	DamageCache & damageCache,
 	std::shared_ptr<HypotheticBattle> hb)
@@ -492,8 +524,6 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
 	if(hb->battleGetUnitByID(ap.attack.defender->unitId())->alive())
 		enemyStacks.push_back(ap.attack.defender);
 
-	vstd::amin(turn, reachabilityMapByTurns.size() - 1);
-
 	ReachabilityData exchangeUnits = getExchangeUnits(ap, turn, targets, hb);
 
 	if(exchangeUnits.units.empty())
@@ -511,10 +541,15 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
 
 		bool isOur = exchangeBattle->battleMatchOwner(ap.attack.attacker, unit, true);
 		auto & attackerQueue = isOur ? ourStacks : enemyStacks;
+		auto u = exchangeBattle->getForUpdate(unit->unitId());
 
-		if(exchangeBattle->getForUpdate(unit->unitId())->alive() && !vstd::contains(attackerQueue, unit))
+		if(u->alive() && !vstd::contains(attackerQueue, unit))
 		{
 			attackerQueue.push_back(unit);
+
+#if BATTLE_TRACE_LEVEL
+			logAi->trace("Exchanging: %s", u->getDescription());
+#endif
 		}
 	}
 
@@ -562,7 +597,7 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
 					true);
 
 #if BATTLE_TRACE_LEVEL>=1
-				logAi->trace("Best target selector %s->%s score = %2f", attacker->getDescription(), u->getDescription(), score);
+				logAi->trace("Best target selector %s->%s score = %2f", attacker->getDescription(), stackWithBonuses->getDescription(), score);
 #endif
 
 				return score;
@@ -645,6 +680,13 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
 
 	// avoid blocking path for stronger stack by weaker stack
 	// the method checks if all stacks can be placed around enemy
+	std::map<BattleHex, battle::Units> reachabilityMap;
+
+	auto hexes = ap.attack.defender->getSurroundingHexes();
+
+	for(auto hex : hexes)
+		reachabilityMap[hex] = getOneTurnReachableUnits(turn, hex);
+
 	v.adjustPositions(melleeAttackers, ap, reachabilityMap);
 
 #if BATTLE_TRACE_LEVEL>=1
@@ -736,18 +778,38 @@ bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap)
 	return false;
 }
 
-void BattleExchangeEvaluator::updateReachabilityMap(	std::shared_ptr<HypotheticBattle> hb)
+void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb)
 {
 	const int TURN_DEPTH = 2;
 
 	turnOrder.clear();
-	
+
 	hb->battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH);
-	reachabilityMap.clear();
 
-	for(int turn = 0; turn < turnOrder.size(); turn++)
+	for(auto turn : turnOrder)
 	{
-		auto & turnQueue = turnOrder[turn];
+		for(auto u : turn)
+		{
+			if(!vstd::contains(reachabilityCache, u->unitId()))
+			{
+				reachabilityCache[u->unitId()] = hb->getReachability(u);
+			}
+		}
+	}
+
+	for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
+	{
+		reachabilityMap[hex] = getOneTurnReachableUnits(0, hex);
+	}
+}
+
+std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex)
+{
+	std::vector<const battle::Unit *> result;
+
+	for(int i = 0; i < turnOrder.size(); i++, turn++)
+	{
+		auto & turnQueue = turnOrder[i];
 		HypotheticBattle turnBattle(env.get(), cb);
 
 		for(const battle::Unit * unit : turnQueue)
@@ -755,46 +817,49 @@ void BattleExchangeEvaluator::updateReachabilityMap(	std::shared_ptr<HypotheticB
 			if(unit->isTurret())
 				continue;
 
-			auto unitSpeed = unit->speed(turn);
-
 			if(turnBattle.battleCanShoot(unit))
 			{
-				for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
-				{
-					reachabilityMap[hex].push_back(unit);
-				}
+				result.push_back(unit);
 
 				continue;
 			}
 
-			auto unitReachability = turnBattle.getReachability(unit);
+			auto unitSpeed = unit->speed(turn);
+			auto radius = unitSpeed * (turn + 1);
 
-			for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
+			ReachabilityInfo unitReachability = vstd::getOrCompute(
+				reachabilityCache,
+				unit->unitId(),
+				[&](ReachabilityInfo & data)
+				{
+					data = turnBattle.getReachability(unit);
+				});
+
+			bool reachable = unitReachability.distances[hex] <= radius;
+
+			if(!reachable && unitReachability.accessibility[hex] == EAccessibility::ALIVE_STACK)
 			{
-				bool reachable = unitReachability.distances[hex] <= unitSpeed;
+				const battle::Unit * hexStack = cb->battleGetUnitByPos(hex);
 
-				if(!reachable && unitReachability.accessibility[hex] == EAccessibility::ALIVE_STACK)
+				if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
 				{
-					const battle::Unit * hexStack = cb->battleGetUnitByPos(hex);
-
-					if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
+					for(BattleHex neighbor : hex.neighbouringTiles())
 					{
-						for(BattleHex neighbor : hex.neighbouringTiles())
-						{
-							reachable = unitReachability.distances[neighbor] <= unitSpeed;
+						reachable = unitReachability.distances[neighbor] <= radius;
 
-							if(reachable) break;
-						}
+						if(reachable) break;
 					}
 				}
+			}
 
-				if(reachable)
-				{
-					reachabilityMap[hex].push_back(unit);
-				}
+			if(reachable)
+			{
+				result.push_back(unit);
 			}
 		}
 	}
+
+	return result;
 }
 
 // avoid blocking path for stronger stack by weaker stack

+ 5 - 1
AI/BattleAI/BattleExchangeVariant.h

@@ -132,6 +132,7 @@ class BattleExchangeEvaluator
 private:
 	std::shared_ptr<CBattleInfoCallback> cb;
 	std::shared_ptr<Environment> env;
+	std::map<uint32_t, ReachabilityInfo> reachabilityCache;
 	std::map<BattleHex, std::vector<const battle::Unit *>> reachabilityMap;
 	std::vector<battle::Units> turnOrder;
 	float negativeEffectMultiplier;
@@ -140,6 +141,7 @@ private:
 
 	BattleScore calculateExchange(
 		const AttackPossibility & ap,
+		uint8_t turn,
 		PotentialTargets & targets,
 		DamageCache & damageCache,
 		std::shared_ptr<HypotheticBattle> hb);
@@ -151,7 +153,7 @@ public:
 		std::shared_ptr<CBattleInfoCallback> cb,
 		std::shared_ptr<Environment> env,
 		float strengthRatio): cb(cb), env(env) {
-		negativeEffectMultiplier = std::sqrt(strengthRatio);
+		negativeEffectMultiplier = strengthRatio >= 1 ? 1 : strengthRatio;
 	}
 
 	EvaluationResult findBestTarget(
@@ -162,10 +164,12 @@ public:
 
 	float evaluateExchange(
 		const AttackPossibility & ap,
+		uint8_t turn,
 		PotentialTargets & targets,
 		DamageCache & damageCache,
 		std::shared_ptr<HypotheticBattle> hb);
 
+	std::vector<const battle::Unit *> getOneTurnReachableUnits(uint8_t turn, BattleHex hex);
 	void updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb);
 
 	ReachabilityData getExchangeUnits(