Pārlūkot izejas kodu

Battle AI 2 turns attacks recalculation

Andrii Danylchenko 3 gadi atpakaļ
vecāks
revīzija
033a585e4b

+ 18 - 4
AI/BattleAI/AttackPossibility.cpp

@@ -123,10 +123,11 @@ AttackPossibility AttackPossibility::evaluate(const BattleAttackInfo & attackInf
 
 			for(int i = 0; i < totalAttacks; i++)
 			{
-				si64 damageDealt, damageReceived;
+				int64_t damageDealt, damageReceived, enemyDpsReduce, ourDpsReduce;
 
 				TDmgRange retaliation(0, 0);
 				auto attackDmg = getCbc()->battleEstimateDamage(ap.attack, &retaliation);
+				TDmgRange enemyDamageBeforeAttack = getCbc()->battleEstimateDamage(BattleAttackInfo(u, attacker, u->canShoot()));
 
 				vstd::amin(attackDmg.first, defenderState->getAvailableHealth());
 				vstd::amin(attackDmg.second, defenderState->getAvailableHealth());
@@ -137,29 +138,42 @@ AttackPossibility AttackPossibility::evaluate(const BattleAttackInfo & attackInf
 				damageDealt = (attackDmg.first + attackDmg.second) / 2;
 				ap.attackerState->afterAttack(attackInfo.shooting, false);
 
+				auto enemiesKilled = damageDealt / u->MaxHealth() + (damageDealt % u->MaxHealth() >= u->getFirstHPleft() ? 1 : 0);
+				auto enemyDps = (enemyDamageBeforeAttack.first + enemyDamageBeforeAttack.second) / 2;
+
+				enemyDpsReduce = enemiesKilled
+					? (int64_t)(enemyDps * enemiesKilled / (double)u->getCount())
+					: (int64_t)(enemyDps / (double)u->getCount() * damageDealt / u->getFirstHPleft());
+
 				//FIXME: use ranged retaliation
 				damageReceived = 0;
+				ourDpsReduce = 0;
+
 				if (!attackInfo.shooting && defenderState->ableToRetaliate() && !counterAttacksBlocked)
 				{
 					damageReceived = (retaliation.first + retaliation.second) / 2;
 					defenderState->afterAttack(attackInfo.shooting, true);
+
+					auto ourUnitsKilled = damageReceived / attacker->MaxHealth() + (damageReceived % attacker->MaxHealth() >= attacker->getFirstHPleft() ? 1 : 0);
+
+					ourDpsReduce = (int64_t)(damageDealt * ourUnitsKilled / (double)attacker->getCount());
 				}
 
 				bool isEnemy = state->battleMatchOwner(attacker, u);
 
 				// this includes enemy units as well as attacker units under enemy's mind control
 				if(isEnemy)
-					ap.damageDealt += damageDealt;
+					ap.damageDealt += enemyDpsReduce;
 
 				// damaging attacker's units (even those under enemy's mind control) is considered friendly fire
 				if(attackerSide == u->unitSide())
-					ap.collateralDamage += damageDealt;
+					ap.collateralDamage += enemyDpsReduce;
 
 				if(u->unitId() == defender->unitId() || 
 					(!attackInfo.shooting && CStack::isMeleeAttackPossible(u, attacker, hex)))
 				{
 					//FIXME: handle RANGED_RETALIATION ?
-					ap.damageReceived += damageReceived;
+					ap.damageReceived += ourDpsReduce;
 				}
 
 				ap.attackerState->damage(damageReceived);

+ 42 - 7
AI/BattleAI/BattleAI.cpp

@@ -9,6 +9,7 @@
  */
 #include "StdInc.h"
 #include "BattleAI.h"
+#include "BattleExchangeVariant.h"
 
 #include "StackWithBonuses.h"
 #include "EnemyInfo.h"
@@ -92,7 +93,7 @@ void CBattleAI::init(std::shared_ptr<Environment> ENV, std::shared_ptr<CBattleCa
 
 BattleAction CBattleAI::activeStack( const CStack * stack )
 {
-	LOG_TRACE_PARAMS(logAi, "stack: %s", stack->nodeName())	;
+	LOG_TRACE_PARAMS(logAi, "stack: %s", stack->nodeName());
 	setCbc(cb); //TODO: make solid sure that AIs always use their callbacks (need to take care of event handlers too)
 	try
 	{
@@ -157,17 +158,33 @@ BattleAction CBattleAI::activeStack( const CStack * stack )
 		}
 
 		HypotheticBattle hb(env.get(), cb);
-
+		int turn = 0;
+		
 		PotentialTargets targets(stack, &hb);
+		BattleExchangeEvaluator scoreEvaluator(cb, env);
 
 		if(!targets.possibleAttacks.empty())
 		{
-			AttackPossibility bestAttack = targets.bestAction();
+			logAi->trace("Evaluating attack for %s", stack->getDescription());
 
+			auto evaluationResult = scoreEvaluator.findBestTarget(stack, targets, hb);
+			auto & bestAttack = evaluationResult.bestAttack;
+			
 			//TODO: consider more complex spellcast evaluation, f.e. because "re-retaliation" during enemy move in same turn for melee attack etc.
 			if(bestSpellcast.is_initialized() && bestSpellcast->value > bestAttack.damageDiff())
 				return BattleAction::makeCreatureSpellcast(stack, bestSpellcast->dest, bestSpellcast->spell->id);
-			else if(bestAttack.attack.shooting)
+
+			if(evaluationResult.wait)
+			{
+				return BattleAction::makeWait(stack);
+			}
+
+			if(evaluationResult.score == EvaluationResult::INEFFECTIVE_SCORE)
+			{
+				return BattleAction::makeDefend(stack);
+			}
+			
+			if(bestAttack.attack.shooting)
 			{
 				auto &target = bestAttack;
 				logAi->debug("BattleAI: %s -> %s x %d, shot, from %d curpos %d dist %d speed %d: %lld %lld %lld",
@@ -285,13 +302,31 @@ BattleAction CBattleAI::goTowardsNearest(const CStack * stack, std::vector<Battl
 		return BattleAction::makeDefend(stack);
 	}
 
+	BattleExchangeEvaluator scoreEvaluator(cb, env);
+	HypotheticBattle hb(env.get(), cb);
+
+	scoreEvaluator.updateReachabilityMap(hb);
+
 	if(stack->hasBonusOfType(Bonus::FLYING))
 	{
+		std::set<BattleHex> moatHexes;
+
+		if(hb.battleGetSiegeLevel() >= BuildingID::CITADEL)
+		{
+			auto townMoat = hb.getDefendedTown()->town->moatHexes;
+
+			moatHexes = std::set<BattleHex>(townMoat.begin(), townMoat.end());
+		}
 		// Flying stack doesn't go hex by hex, so we can't backtrack using predecessors.
 		// We just check all available hexes and pick the one closest to the target.
 		auto nearestAvailableHex = vstd::minElementByFun(avHexes, [&](BattleHex hex) -> int
 		{
-			return BattleHex::getDistance(bestNeighbor, hex);
+			auto distance = BattleHex::getDistance(bestNeighbor, hex);
+
+			if(vstd::contains(moatHexes, hex))
+				distance += 100;
+
+			return scoreEvaluator.checkPositionBlocksOurStacks(hb, stack, hex) ? 100 + distance : distance;
 		});
 
 		return BattleAction::makeMove(stack, *nearestAvailableHex);
@@ -303,11 +338,11 @@ BattleAction CBattleAI::goTowardsNearest(const CStack * stack, std::vector<Battl
 		{
 			if(!currentDest.isValid())
 			{
-				logAi->error("CBattleAI::goTowards: internal error");
 				return BattleAction::makeDefend(stack);
 			}
 
-			if(vstd::contains(avHexes, currentDest))
+			if(vstd::contains(avHexes, currentDest)
+				&& !scoreEvaluator.checkPositionBlocksOurStacks(hb, stack, currentDest))
 				return BattleAction::makeMove(stack, currentDest);
 
 			currentDest = reachability.predecessors[currentDest];

+ 504 - 0
AI/BattleAI/BattleExchangeVariant.cpp

@@ -0,0 +1,504 @@
+/*
+ * BattleAI.cpp, part of VCMI engine
+ *
+ * Authors: listed in file AUTHORS in main folder
+ *
+ * License: GNU General Public License v2.0 or later
+ * Full text of license available in license.txt file, in main folder
+ *
+ */
+#include "StdInc.h"
+#include "BattleExchangeVariant.h"
+#include "../../lib/CStack.h"
+
+int64_t BattleExchangeVariant::trackAttack(const AttackPossibility & ap, HypotheticBattle * state)
+{
+	auto affectedUnits = ap.affectedUnits;
+
+	affectedUnits.push_back(ap.attackerState);
+
+	for(auto affectedUnit : affectedUnits)
+	{
+		auto unitToUpdate = state->getForUpdate(affectedUnit->unitId());
+
+		unitToUpdate->health = affectedUnit->health;
+		unitToUpdate->shots = affectedUnit->shots;
+		unitToUpdate->counterAttacks = affectedUnit->counterAttacks;
+		unitToUpdate->movedThisRound = affectedUnit->movedThisRound;
+	}
+
+	auto attackValue = ap.attackValue();
+
+	dpsScore += attackValue;
+
+	logAi->trace(
+		"%s -> %s, ap attack, %s, dps: %d, score: %d",
+		ap.attack.attacker->getDescription(),
+		ap.attack.defender->getDescription(),
+		ap.attack.shooting ? "shot" : "mellee",
+		ap.damageDealt,
+		attackValue);
+
+	return attackValue;
+}
+
+int64_t BattleExchangeVariant::trackAttack(
+	std::shared_ptr<StackWithBonuses> attacker,
+	std::shared_ptr<StackWithBonuses> defender,
+	bool shooting,
+	bool isOurAttack,
+	std::shared_ptr<CBattleInfoCallback> cb,
+	bool evaluateOnly)
+{
+	const std::string cachingStringBlocksRetaliation = "type_BLOCKS_RETALIATION";
+	static const auto selectorBlocksRetaliation = Selector::type()(Bonus::BLOCKS_RETALIATION);
+	const bool counterAttacksBlocked = attacker->hasBonus(selectorBlocksRetaliation, cachingStringBlocksRetaliation);
+
+	TDmgRange retalitation;
+	BattleAttackInfo bai(attacker.get(), defender.get(), shooting);
+	auto attack = cb->battleEstimateDamage(bai, &retalitation);
+	int64_t attackDamage = (attack.first + attack.second) / 2;
+	int64_t defenderDpsReduce = calculateDpsReduce(attacker.get(), defender.get(), attackDamage, cb);
+	int64_t attackerDpsReduce = 0;
+
+	if(!evaluateOnly)
+	{
+		logAi->trace(
+			"%s -> %s, normal attack, %s, dps: %d, %d",
+			attacker->getDescription(),
+			defender->getDescription(),
+			shooting ? "shot" : "mellee",
+			attackDamage,
+			defenderDpsReduce);
+
+		if(isOurAttack)
+		{
+			dpsScore += defenderDpsReduce;
+			attackerValue[attacker->unitId()].value += defenderDpsReduce;
+		}
+		else
+			dpsScore -= defenderDpsReduce;
+
+		defender->damage(attackDamage);
+		attacker->afterAttack(shooting, false);
+	}
+
+	if(defender->alive() && defender->ableToRetaliate() && !counterAttacksBlocked && !shooting)
+	{
+		if(retalitation.second != 0)
+		{
+			auto retalitationDamage = (retalitation.first + retalitation.second) / 2;
+			attackerDpsReduce = calculateDpsReduce(defender.get(), attacker.get(), retalitationDamage, cb);
+
+			if(!evaluateOnly)
+			{
+				logAi->trace(
+					"%s -> %s, retalitation, dps: %d, %d",
+					defender->getDescription(),
+					attacker->getDescription(),
+					retalitationDamage,
+					attackerDpsReduce);
+
+				if(isOurAttack)
+				{
+					dpsScore -= attackerDpsReduce;
+					attackerValue[attacker->unitId()].isRetalitated = true;
+				}
+				else
+				{
+					dpsScore += attackerDpsReduce;
+					attackerValue[defender->unitId()].value += attackerDpsReduce;
+				}
+
+				attacker->damage(retalitationDamage);
+				defender->afterAttack(false, true);
+			}
+		}
+	}
+
+	auto score = defenderDpsReduce - attackerDpsReduce;
+
+	if(!score)
+	{
+		logAi->trace("Zero %d %d", defenderDpsReduce, attackerDpsReduce);
+	}
+
+	return score;
+}
+
+int64_t BattleExchangeVariant::calculateDpsReduce(
+	const battle::Unit * attacker,
+	const battle::Unit * defender,
+	uint64_t damageDealt,
+	std::shared_ptr<CBattleInfoCallback> cb) const
+{
+	vstd::amin(damageDealt, defender->getAvailableHealth());
+
+	auto enemyDamageBeforeAttack = cb->battleEstimateDamage(BattleAttackInfo(defender, attacker, defender->canShoot()));
+	auto enemiesKilled = damageDealt / defender->MaxHealth() + (damageDealt % defender->MaxHealth() >= defender->getFirstHPleft() ? 1 : 0);
+	auto enemyDps = (enemyDamageBeforeAttack.first + enemyDamageBeforeAttack.second) / 2;
+
+	return (int64_t)(enemyDps * enemiesKilled / (double)defender->getCount()
+		+ enemyDps / (double)defender->getCount() * ((damageDealt - defender->getFirstHPleft()) % defender->MaxHealth()) / defender->MaxHealth());
+};
+
+EvaluationResult BattleExchangeEvaluator::findBestTarget(const battle::Unit * activeStack, PotentialTargets & targets, HypotheticBattle & hb)
+{
+	EvaluationResult result(targets.bestAction());
+
+	updateReachabilityMap(hb);
+
+	for(auto & ap : targets.possibleAttacks)
+	{
+		int64_t score = calculateExchange(ap);
+
+		if(score > result.score)
+		{
+			result.score = score;
+			result.bestAttack = ap;
+		}
+	}
+
+	if(!activeStack->waited())
+	{
+		logAi->trace("Evaluating waited attack for %s", activeStack->getDescription());
+
+		hb.getForUpdate(activeStack->unitId())->waiting = true;
+		hb.getForUpdate(activeStack->unitId())->waitedThisTurn = true;
+
+		updateReachabilityMap(hb);
+
+		for(auto & ap : targets.possibleAttacks)
+		{
+			int64_t score = calculateExchange(ap);
+
+			if(score > result.score)
+			{
+				result.score = score;
+				result.bestAttack = ap;
+				result.wait = true;
+			}
+		}
+	}
+
+	return result;
+}
+
+std::vector<const battle::Unit *> BattleExchangeEvaluator::getExchangeUnits(
+	const AttackPossibility & ap)
+{
+	auto hexes = ap.attack.defender->getHexes();
+
+	if(!ap.attack.shooting) hexes.push_back(ap.from);
+
+	std::vector<const battle::Unit *> exchangeUnits;
+	std::vector<const battle::Unit *> allReachableUnits;
+
+	for(auto hex : hexes)
+	{
+		vstd::concatenate(allReachableUnits, reachabilityMap[hex]);
+	}
+
+	vstd::removeDuplicates(allReachableUnits);
+
+	if(allReachableUnits.size() < 2)
+	{
+		logAi->trace("Reachability map contains only %d stacks", allReachableUnits.size());
+
+		return exchangeUnits;
+	}
+
+	for(int turn = 0; turn < turnOrder.size(); turn++)
+	{
+		for(auto unit : turnOrder[turn])
+		{
+			if(vstd::contains(allReachableUnits, unit))
+				exchangeUnits.push_back(unit);
+		}
+	}
+
+	return exchangeUnits;
+}
+
+int64_t BattleExchangeEvaluator::calculateExchange(const AttackPossibility & ap)
+{
+	logAi->trace("Battle exchange at %d", ap.attack.shooting ? ap.dest : ap.from);
+
+	std::vector<const battle::Unit *> ourStacks;
+	std::vector<const battle::Unit *> enemyStacks;
+
+	enemyStacks.push_back(ap.attack.defender);
+
+	std::vector<const battle::Unit *> exchangeUnits = getExchangeUnits(ap);
+
+	if(exchangeUnits.empty())
+	{
+		return 0;
+	}
+
+	HypotheticBattle exchangeBattle(env.get(), cb);
+	BattleExchangeVariant v;
+	auto melleeAttackers = ourStacks;
+
+	vstd::removeDuplicates(melleeAttackers);
+	vstd::erase_if(melleeAttackers, [&](const battle::Unit * u) -> bool
+		{
+			return !cb->battleCanShoot(u);
+		});
+
+	for(auto unit : exchangeUnits)
+	{
+		bool isOur = cb->battleMatchOwner(ap.attack.attacker, unit, true);
+		auto & attackerQueue = isOur ? ourStacks : enemyStacks;
+		auto & oppositeQueue = isOur ? enemyStacks : ourStacks;
+
+		if(!vstd::contains(attackerQueue, unit))
+		{
+			attackerQueue.push_back(unit);
+		}
+	}
+
+	bool canUseAp = true;
+
+	for(auto activeUnit : exchangeUnits)
+	{
+		bool isOur = cb->battleMatchOwner(ap.attack.attacker, activeUnit, true);
+		battle::Units & attackerQueue = isOur ? ourStacks : enemyStacks;
+		battle::Units & oppositeQueue = isOur ? enemyStacks : ourStacks;
+
+		auto attacker = exchangeBattle.getForUpdate(activeUnit->unitId());
+
+		if(!attacker->alive() || oppositeQueue.empty())
+		{
+			logAi->trace(
+				"Attacker [%s] dead(%d) or opposite queue empty(%d)",
+				attacker->getDescription(),
+				attacker->alive() ? 0 : 1,
+				oppositeQueue.size());
+
+			continue;
+		}
+
+		auto targetUnit = ap.attack.defender;
+
+		if(!isOur || !exchangeBattle.getForUpdate(targetUnit->unitId())->alive())
+		{
+			targetUnit = *vstd::maxElementByFun(oppositeQueue, [&](const battle::Unit * u) -> int64_t
+				{
+					auto stackWithBonuses = exchangeBattle.getForUpdate(u->unitId());
+					auto score = v.trackAttack(
+						attacker,
+						stackWithBonuses,
+						exchangeBattle.battleCanShoot(stackWithBonuses.get()),
+						isOur,
+						cb,
+						true);
+
+					logAi->trace("Best target selector %s->%s score = %d", attacker->getDescription(), u->getDescription(), score);
+
+					return score;
+				});
+		}
+
+		auto defender = exchangeBattle.getForUpdate(targetUnit->unitId());
+		auto shooting = cb->battleCanShoot(attacker.get());
+		const int totalAttacks = attacker->getTotalAttacks(shooting);
+
+		if(canUseAp && activeUnit == ap.attack.attacker && targetUnit == ap.attack.defender)
+		{
+			v.trackAttack(ap, &exchangeBattle);
+		}
+		else
+		{
+			for(int i = 0; i < totalAttacks; i++)
+			{
+				v.trackAttack(attacker, defender, shooting, isOur, cb);
+
+				if(!attacker->alive() || !defender->alive())
+					break;
+			}
+		}
+
+		canUseAp = false;
+
+		vstd::erase_if(attackerQueue, [&](const battle::Unit * u) -> bool
+			{
+				return !exchangeBattle.getForUpdate(u->unitId())->alive();
+			});
+
+		vstd::erase_if(oppositeQueue, [&](const battle::Unit * u) -> bool
+			{
+				return !exchangeBattle.getForUpdate(u->unitId())->alive();
+			});
+	}
+
+	v.adjustPositions(melleeAttackers, ap, reachabilityMap);
+
+	logAi->trace("Exchange score: %ld", v.getScore());
+
+	return v.getScore();
+}
+
+void BattleExchangeVariant::adjustPositions(
+	std::vector<const battle::Unit*> attackers,
+	const AttackPossibility & ap,
+	std::map<BattleHex, battle::Units> & reachabilityMap)
+{
+	auto hexes = ap.attack.defender->getSurroundingHexes();
+
+	boost::sort(attackers, [&](const battle::Unit * u1, const battle::Unit * u2) -> bool
+		{
+			if(attackerValue[u1->unitId()].isRetalitated && !attackerValue[u2->unitId()].isRetalitated)
+				return true;
+
+			if(attackerValue[u2->unitId()].isRetalitated && !attackerValue[u1->unitId()].isRetalitated)
+				return false;
+
+			return attackerValue[u1->unitId()].value > attackerValue[u2->unitId()].value;
+		});
+
+	if(!ap.attack.shooting)
+	{
+		vstd::erase_if_present(hexes, ap.from);
+		vstd::erase_if_present(hexes, ap.attack.attacker->occupiedHex(ap.attack.attackerPos));
+	}
+
+	int64_t notRealizedDps = 0;
+
+	for(auto unit : attackers)
+	{
+		if(unit->unitId() == ap.attack.attacker->unitId())
+			continue;
+
+		if(!vstd::contains_if(hexes, [&](BattleHex h) -> bool
+			{
+				return vstd::contains(reachabilityMap[h], unit);
+			}))
+		{
+			notRealizedDps += attackerValue[unit->unitId()].value;
+			continue;
+		}
+
+		auto desiredPosition = vstd::minElementByFun(hexes, [&](BattleHex h) -> int64_t
+			{
+				auto score = vstd::contains(reachabilityMap[h], unit)
+					? reachabilityMap[h].size()
+					: 1000;
+
+				if(unit->doubleWide())
+				{
+					auto backHex = unit->occupiedHex(h);
+
+					if(vstd::contains(hexes, backHex))
+						score += reachabilityMap[backHex].size();
+				}
+
+				return score;
+			});
+
+		hexes.erase(desiredPosition);
+	}
+
+	if(notRealizedDps > ap.attackValue() && notRealizedDps > attackerValue[ap.attack.attacker->unitId()].value)
+	{
+		dpsScore = EvaluationResult::INEFFECTIVE_SCORE;
+	}
+}
+
+void BattleExchangeEvaluator::updateReachabilityMap(HypotheticBattle & hb)
+{
+	turnOrder.clear();
+
+	hb.battleGetTurnOrder(turnOrder, 1000, 2);
+	reachabilityMap.clear();
+
+	for(int turn = 0; turn < turnOrder.size(); turn++)
+	{
+		auto & turnQueue = turnOrder[turn];
+		HypotheticBattle turnBattle(env.get(), cb);
+
+		for(const battle::Unit * unit : turnQueue)
+		{
+			auto unitReachability = turnBattle.getReachability(unit);
+
+			for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
+			{
+				bool reachable = unitReachability.distances[hex] <= unit->Speed(turn);
+
+				if(!reachable && unitReachability.accessibility[hex] == EAccessibility::ALIVE_STACK)
+				{
+					const battle::Unit * hexStack = cb->battleGetUnitByPos(hex);
+
+					if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
+					{
+						for(BattleHex neighbor : hex.neighbouringTiles())
+						{
+							reachable = unitReachability.distances[neighbor] <= unit->Speed(turn);
+
+							if(reachable) break;
+						}
+					}
+				}
+
+				if(reachable)
+				{
+					reachabilityMap[hex].push_back(unit);
+				}
+			}
+		}
+	}
+}
+
+bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * activeUnit, BattleHex position)
+{
+	int blockingScore = 0;
+
+	for(int turn = 0; turn < turnOrder.size(); turn++)
+	{
+		auto & turnQueue = turnOrder[turn];
+		HypotheticBattle turnBattle(env.get(), cb);
+
+		auto unitToUpdate = turnBattle.getForUpdate(activeUnit->unitId());
+		unitToUpdate->setPosition(position);
+
+		for(const battle::Unit * unit : turnQueue)
+		{
+			if(unit->unitId() == unitToUpdate->unitId() || cb->battleMatchOwner(unit, activeUnit, false))
+				continue;
+
+			auto unitReachability = turnBattle.getReachability(unit);
+
+			for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
+			{
+				bool enemyUnit = false;
+				bool reachable = unitReachability.distances[hex] <= unit->Speed(turn);
+
+				if(!reachable && unitReachability.accessibility[hex] == EAccessibility::ALIVE_STACK)
+				{
+					const battle::Unit * hexStack = turnBattle.battleGetUnitByPos(hex);
+
+					if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
+					{
+						enemyUnit = true;
+
+						for(BattleHex neighbor : hex.neighbouringTiles())
+						{
+							reachable = unitReachability.distances[neighbor] <= unit->Speed(turn);
+
+							if(reachable) break;
+						}
+					}
+				}
+
+				if(!reachable && vstd::contains(reachabilityMap[hex], unit))
+				{
+					blockingScore += enemyUnit ? 100 : 1;
+				}
+			}
+		}
+	}
+
+	logAi->trace("Position %d, blocking score %d", position.hex, blockingScore);
+
+	return blockingScore > 50;
+}

+ 100 - 0
AI/BattleAI/BattleExchangeVariant.h

@@ -0,0 +1,100 @@
+/*
+ * BattleExchangeVariant.h, part of VCMI engine
+ *
+ * Authors: listed in file AUTHORS in main folder
+ *
+ * License: GNU General Public License v2.0 or later
+ * Full text of license available in license.txt file, in main folder
+ *
+ */
+#pragma once
+
+#include "../../lib/AI_Base.h"
+#include "../../lib/battle/ReachabilityInfo.h"
+#include "PotentialTargets.h"
+#include "StackWithBonuses.h"
+
+struct AttackerValue
+{
+	int64_t value;
+	bool isRetalitated;
+	BattleHex position;
+
+	AttackerValue()
+	{
+		value = 0;
+		isRetalitated = false;
+	}
+};
+
+class BattleExchangeVariant
+{
+public:
+	BattleExchangeVariant()
+		:dpsScore(0), attackerValue()
+	{
+	}
+
+	int64_t trackAttack(const AttackPossibility & ap, HypotheticBattle * state);
+
+	int64_t trackAttack(
+		std::shared_ptr<StackWithBonuses> attacker,
+		std::shared_ptr<StackWithBonuses> defender,
+		bool shooting,
+		bool isOurAttack,
+		std::shared_ptr<CBattleInfoCallback> cb,
+		bool evaluateOnly = false);
+
+	int64_t getScore() const { return dpsScore; }
+
+	void adjustPositions(
+		std::vector<const battle::Unit *> attackers,
+		const AttackPossibility & ap,
+		std::map<BattleHex, battle::Units> & reachabilityMap);
+
+private:
+	int64_t dpsScore;
+	std::map<uint32_t, AttackerValue> attackerValue;
+
+	int64_t calculateDpsReduce(
+		const battle::Unit * attacker,
+		const battle::Unit * defender,
+		uint64_t damageDealt,
+		std::shared_ptr<CBattleInfoCallback> cb) const;
+};
+
+struct EvaluationResult
+{
+	static const int64_t INEFFECTIVE_SCORE = -1000000;
+
+	AttackPossibility bestAttack;
+	bool wait;
+	int64_t score;
+	bool defend;
+
+	EvaluationResult(AttackPossibility & ap)
+		:wait(false), score(0), bestAttack(ap), defend(false)
+	{
+	}
+};
+
+class BattleExchangeEvaluator
+{
+private:
+	std::shared_ptr<CBattleInfoCallback> cb;
+	std::shared_ptr<Environment> env;
+	std::map<BattleHex, std::vector<const battle::Unit *>> reachabilityMap;
+	std::vector<battle::Units> turnOrder;
+
+public:
+	BattleExchangeEvaluator(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env)
+		:cb(cb), reachabilityMap(), env(env), turnOrder()
+	{
+	}
+
+	EvaluationResult findBestTarget(const battle::Unit * activeStack, PotentialTargets & targets, HypotheticBattle & hb);
+	int64_t calculateExchange(const AttackPossibility & ap);
+	void updateReachabilityMap(HypotheticBattle & hb);
+	std::vector<const battle::Unit *> getExchangeUnits(const AttackPossibility & ap);
+	bool checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * unit, BattleHex position);
+};

+ 2 - 0
AI/BattleAI/CMakeLists.txt

@@ -10,6 +10,7 @@ set(battleAI_SRCS
 		PotentialTargets.cpp
 		StackWithBonuses.cpp
 		ThreatMap.cpp
+		BattleExchangeVariant.cpp
 )
 
 set(battleAI_HEADERS
@@ -23,6 +24,7 @@ set(battleAI_HEADERS
 		PossibleSpellcast.h
 		StackWithBonuses.h
 		ThreatMap.h
+		BattleExchangeVariant.h
 )
 
 assign_source_group(${battleAI_SRCS} ${battleAI_HEADERS})

+ 15 - 1
AI/BattleAI/StackWithBonuses.cpp

@@ -199,6 +199,21 @@ void StackWithBonuses::removeUnitBonus(const CSelector & selector)
 	vstd::erase_if(bonusesToUpdate, [&](const Bonus & b){return selector(&b);});
 }
 
+std::string StackWithBonuses::getDescription() const
+{
+	std::ostringstream oss;
+	oss << unitOwner().getStr();
+	oss << " battle stack [" << unitId() << "]: " << getCount() << " of ";
+	if(type)
+		oss << type->namePl;
+	else
+		oss << "[UNDEFINED TYPE]";
+
+	oss << " from slot " << slot;
+
+	return oss.str();
+}
+
 void StackWithBonuses::spendMana(ServerCallback * server, const int spellCost) const
 {
 	//TODO: evaluate cast use
@@ -284,7 +299,6 @@ int32_t HypotheticBattle::getActiveStackID() const
 void HypotheticBattle::nextRound(int32_t roundNr)
 {
 	//TODO:HypotheticBattle::nextRound
-
 	for(auto unit : battleAliveUnits())
 	{
 		auto forUpdate = getForUpdate(unit->unitId());

+ 1 - 0
AI/BattleAI/StackWithBonuses.h

@@ -85,6 +85,7 @@ public:
 	void removeUnitBonus(const CSelector & selector);
 
 	void spendMana(ServerCallback * server, const int spellCost) const override;
+	std::string getDescription() const override;
 
 private:
 	const IBonusBearer * origBearer;