Browse Source

Merge pull request #1067 from vcmi/battle-ai-improvements

Battle ai improvements
Andrii Danylchenko 3 years ago
parent
commit
afc0cc1c15

+ 56 - 22
AI/BattleAI/AttackPossibility.cpp

@@ -13,6 +13,11 @@
                               // Eventually only IBattleInfoCallback and battle::Unit should be used, 
                               // CUnitState should be private and CStack should be removed completely
 
+uint64_t averageDmg(const TDmgRange & range)
+{
+	return (range.first + range.second) / 2;
+}
+
 AttackPossibility::AttackPossibility(BattleHex from, BattleHex dest, const BattleAttackInfo & attack)
 	: from(from), dest(dest), attack(attack)
 {
@@ -20,7 +25,7 @@ AttackPossibility::AttackPossibility(BattleHex from, BattleHex dest, const Battl
 
 int64_t AttackPossibility::damageDiff() const
 {
-	return damageDealt - damageReceived - collateralDamage + shootersBlockedDmg;
+	return defenderDamageReduce - attackerDamageReduce - collateralDamageReduce + shootersBlockedDmg;
 }
 
 int64_t AttackPossibility::attackValue() const
@@ -28,7 +33,31 @@ int64_t AttackPossibility::attackValue() const
 	return damageDiff();
 }
 
-int64_t AttackPossibility::evaluateBlockedShootersDmg(const BattleAttackInfo & attackInfo, BattleHex hex, const HypotheticBattle * state)
+/// <summary>
+/// How enemy damage will be reduced by this attack
+/// Half bounty for kill, half for making damage equal to enemy health
+/// Bounty - the killed creature average damage calculated against attacker
+/// </summary>
+int64_t AttackPossibility::calculateDamageReduce(
+	const battle::Unit * attacker,
+	const battle::Unit * defender,
+	uint64_t damageDealt,
+	const CBattleInfoCallback & cb)
+{
+	const float HEALTH_BOUNTY = 0.5;
+	const float KILL_BOUNTY = 1.0 - HEALTH_BOUNTY;
+
+	vstd::amin(damageDealt, defender->getAvailableHealth());
+
+	auto enemyDamageBeforeAttack = cb.battleEstimateDamage(BattleAttackInfo(defender, attacker, defender->canShoot()));
+	auto enemiesKilled = damageDealt / defender->MaxHealth() + (damageDealt % defender->MaxHealth() >= defender->getFirstHPleft() ? 1 : 0);
+	auto enemyDamage = averageDmg(enemyDamageBeforeAttack);
+	auto damagePerEnemy = enemyDamage / (double)defender->getCount();
+
+	return (int64_t)(damagePerEnemy * (enemiesKilled * KILL_BOUNTY + damageDealt * HEALTH_BOUNTY / (double)defender->MaxHealth()));
+}
+
+int64_t AttackPossibility::evaluateBlockedShootersDmg(const BattleAttackInfo & attackInfo, BattleHex hex, const HypotheticBattle & state)
 {
 	int64_t res = 0;
 
@@ -39,10 +68,10 @@ int64_t AttackPossibility::evaluateBlockedShootersDmg(const BattleAttackInfo & a
 	auto hexes = attacker->getSurroundingHexes(hex);
 	for(BattleHex tile : hexes)
 	{
-		auto st = state->battleGetUnitByPos(tile, true);
-		if(!st || !state->battleMatchOwner(st, attacker))
+		auto st = state.battleGetUnitByPos(tile, true);
+		if(!st || !state.battleMatchOwner(st, attacker))
 			continue;
-		if(!state->battleCanShoot(st))
+		if(!state.battleCanShoot(st))
 			continue;
 
 		BattleAttackInfo rangeAttackInfo(st, attacker, true);
@@ -51,23 +80,23 @@ int64_t AttackPossibility::evaluateBlockedShootersDmg(const BattleAttackInfo & a
 		BattleAttackInfo meleeAttackInfo(st, attacker, false);
 		meleeAttackInfo.defenderPos = hex;
 
-		auto rangeDmg = getCbc()->battleEstimateDamage(rangeAttackInfo);
-		auto meleeDmg = getCbc()->battleEstimateDamage(meleeAttackInfo);
+		auto rangeDmg = state.battleEstimateDamage(rangeAttackInfo);
+		auto meleeDmg = state.battleEstimateDamage(meleeAttackInfo);
 
-		int64_t gain = (rangeDmg.first + rangeDmg.second - meleeDmg.first - meleeDmg.second) / 2 + 1;
+		int64_t gain = averageDmg(rangeDmg) - averageDmg(meleeDmg) + 1;
 		res += gain;
 	}
 
 	return res;
 }
 
-AttackPossibility AttackPossibility::evaluate(const BattleAttackInfo & attackInfo, BattleHex hex, const HypotheticBattle * state)
+AttackPossibility AttackPossibility::evaluate(const BattleAttackInfo & attackInfo, BattleHex hex, const HypotheticBattle & state)
 {
 	auto attacker = attackInfo.attacker;
 	auto defender = attackInfo.defender;
 	const std::string cachingStringBlocksRetaliation = "type_BLOCKS_RETALIATION";
 	static const auto selectorBlocksRetaliation = Selector::type()(Bonus::BLOCKS_RETALIATION);
-	const auto attackerSide = getCbc()->playerToSide(getCbc()->battleGetOwner(attacker));
+	const auto attackerSide = state.playerToSide(state.battleGetOwner(attacker));
 	const bool counterAttacksBlocked = attacker->hasBonus(selectorBlocksRetaliation, cachingStringBlocksRetaliation);
 
 	AttackPossibility bestAp(hex, BattleHex::INVALID, attackInfo);
@@ -95,9 +124,9 @@ AttackPossibility AttackPossibility::evaluate(const BattleAttackInfo & attackInf
 		std::vector<const battle::Unit*> units;
 
 		if (attackInfo.shooting)
-			units = state->getAttackedBattleUnits(attacker, defHex, true, BattleHex::INVALID);
+			units = state.getAttackedBattleUnits(attacker, defHex, true, BattleHex::INVALID);
 		else
-			units = state->getAttackedBattleUnits(attacker, defHex, false, hex);
+			units = state.getAttackedBattleUnits(attacker, defHex, false, hex);
 
 		// ensure the defender is also affected
 		bool addDefender = true;
@@ -123,10 +152,11 @@ AttackPossibility AttackPossibility::evaluate(const BattleAttackInfo & attackInf
 
 			for(int i = 0; i < totalAttacks; i++)
 			{
-				si64 damageDealt, damageReceived;
+				int64_t damageDealt, damageReceived, defenderDamageReduce, attackerDamageReduce;
 
 				TDmgRange retaliation(0, 0);
-				auto attackDmg = getCbc()->battleEstimateDamage(ap.attack, &retaliation);
+				auto attackDmg = state.battleEstimateDamage(ap.attack, &retaliation);
+				TDmgRange defenderDamageBeforeAttack = state.battleEstimateDamage(BattleAttackInfo(u, attacker, u->canShoot()));
 
 				vstd::amin(attackDmg.first, defenderState->getAvailableHealth());
 				vstd::amin(attackDmg.second, defenderState->getAvailableHealth());
@@ -134,32 +164,36 @@ AttackPossibility AttackPossibility::evaluate(const BattleAttackInfo & attackInf
 				vstd::amin(retaliation.first, ap.attackerState->getAvailableHealth());
 				vstd::amin(retaliation.second, ap.attackerState->getAvailableHealth());
 
-				damageDealt = (attackDmg.first + attackDmg.second) / 2;
+				damageDealt = averageDmg(attackDmg);
+				defenderDamageReduce = calculateDamageReduce(attacker, defender, damageDealt, state);
 				ap.attackerState->afterAttack(attackInfo.shooting, false);
 
 				//FIXME: use ranged retaliation
 				damageReceived = 0;
+				attackerDamageReduce = 0;
+
 				if (!attackInfo.shooting && defenderState->ableToRetaliate() && !counterAttacksBlocked)
 				{
-					damageReceived = (retaliation.first + retaliation.second) / 2;
+					damageReceived = averageDmg(retaliation);
+					attackerDamageReduce = calculateDamageReduce(defender, attacker, damageReceived, state);
 					defenderState->afterAttack(attackInfo.shooting, true);
 				}
 
-				bool isEnemy = state->battleMatchOwner(attacker, u);
+				bool isEnemy = state.battleMatchOwner(attacker, u);
 
 				// this includes enemy units as well as attacker units under enemy's mind control
 				if(isEnemy)
-					ap.damageDealt += damageDealt;
+					ap.defenderDamageReduce += defenderDamageReduce;
 
 				// damaging attacker's units (even those under enemy's mind control) is considered friendly fire
 				if(attackerSide == u->unitSide())
-					ap.collateralDamage += damageDealt;
+					ap.collateralDamageReduce += defenderDamageReduce;
 
 				if(u->unitId() == defender->unitId() || 
 					(!attackInfo.shooting && CStack::isMeleeAttackPossible(u, attacker, hex)))
 				{
 					//FIXME: handle RANGED_RETALIATION ?
-					ap.damageReceived += damageReceived;
+					ap.attackerDamageReduce += attackerDamageReduce;
 				}
 
 				ap.attackerState->damage(damageReceived);
@@ -177,11 +211,11 @@ AttackPossibility AttackPossibility::evaluate(const BattleAttackInfo & attackInf
 	// check how much damage we gain from blocking enemy shooters on this hex
 	bestAp.shootersBlockedDmg = evaluateBlockedShootersDmg(attackInfo, hex, state);
 
-	logAi->debug("BattleAI best AP: %s -> %s at %d from %d, affects %d units: %lld %lld %lld %lld",
+	logAi->debug("BattleAI best AP: %s -> %s at %d from %d, affects %d units: d:%lld a:%lld c:%lld s:%lld",
 		attackInfo.attacker->unitType()->identifier,
 		attackInfo.defender->unitType()->identifier,
 		(int)bestAp.dest, (int)bestAp.from, (int)bestAp.affectedUnits.size(),
-		bestAp.damageDealt, bestAp.damageReceived, bestAp.collateralDamage, bestAp.shootersBlockedDmg);
+		bestAp.defenderDamageReduce, bestAp.attackerDamageReduce, bestAp.collateralDamageReduce, bestAp.shootersBlockedDmg);
 
 	//TODO other damage related to attack (eg. fire shield and other abilities)
 	return bestAp;

+ 17 - 5
AI/BattleAI/AttackPossibility.h

@@ -13,6 +13,12 @@
 #include "common.h"
 #include "StackWithBonuses.h"
 
+#define BATTLE_TRACE_LEVEL 0
+
+/// <summary>
+/// Evaluate attack value of one particular attack taking into account various effects like
+/// retaliation, 2-hex breath, collateral damage, shooters blocked damage
+/// </summary>
 class AttackPossibility
 {
 public:
@@ -24,9 +30,9 @@ public:
 
 	std::vector<std::shared_ptr<battle::CUnitState>> affectedUnits;
 
-	int64_t damageDealt = 0;
-	int64_t damageReceived = 0; //usually by counter-attack
-	int64_t collateralDamage = 0; // friendly fire (usually by two-hex attacks)
+	int64_t defenderDamageReduce = 0;
+	int64_t attackerDamageReduce = 0; //usually by counter-attack
+	int64_t collateralDamageReduce = 0; // friendly fire (usually by two-hex attacks)
 	int64_t shootersBlockedDmg = 0;
 
 	AttackPossibility(BattleHex from, BattleHex dest, const BattleAttackInfo & attack_);
@@ -34,8 +40,14 @@ public:
 	int64_t damageDiff() const;
 	int64_t attackValue() const;
 
-	static AttackPossibility evaluate(const BattleAttackInfo & attackInfo, BattleHex hex, const HypotheticBattle * state);
+	static AttackPossibility evaluate(const BattleAttackInfo & attackInfo, BattleHex hex, const HypotheticBattle & state);
+
+	static int64_t calculateDamageReduce(
+		const battle::Unit * attacker,
+		const battle::Unit * defender,
+		uint64_t damageDealt,
+		const CBattleInfoCallback & cb);
 
 private:
-	static int64_t evaluateBlockedShootersDmg(const BattleAttackInfo & attackInfo, BattleHex hex, const HypotheticBattle * state);
+	static int64_t evaluateBlockedShootersDmg(const BattleAttackInfo & attackInfo, BattleHex hex, const HypotheticBattle & state);
 };

+ 103 - 64
AI/BattleAI/BattleAI.cpp

@@ -9,6 +9,7 @@
  */
 #include "StdInc.h"
 #include "BattleAI.h"
+#include "BattleExchangeVariant.h"
 
 #include "StackWithBonuses.h"
 #include "EnemyInfo.h"
@@ -92,8 +93,11 @@ void CBattleAI::init(std::shared_ptr<Environment> ENV, std::shared_ptr<CBattleCa
 
 BattleAction CBattleAI::activeStack( const CStack * stack )
 {
-	LOG_TRACE_PARAMS(logAi, "stack: %s", stack->nodeName())	;
+	LOG_TRACE_PARAMS(logAi, "stack: %s", stack->nodeName());
+
+	BattleAction result = BattleAction::makeDefend(stack);
 	setCbc(cb); //TODO: make solid sure that AIs always use their callbacks (need to take care of event handlers too)
+
 	try
 	{
 		if(stack->type->idNumber == CreatureID::CATAPULT)
@@ -157,72 +161,86 @@ BattleAction CBattleAI::activeStack( const CStack * stack )
 		}
 
 		HypotheticBattle hb(env.get(), cb);
+		
+		PotentialTargets targets(stack, hb);
+		BattleExchangeEvaluator scoreEvaluator(cb, env);
+		auto moveTarget = scoreEvaluator.findMoveTowardsUnreachable(stack, targets, hb);
 
-		PotentialTargets targets(stack, &hb);
+		int64_t score = EvaluationResult::INEFFECTIVE_SCORE;
 
 		if(!targets.possibleAttacks.empty())
 		{
-			AttackPossibility bestAttack = targets.bestAction();
+#if BATTLE_TRACE_LEVEL>=1
+			logAi->trace("Evaluating attack for %s", stack->getDescription());
+#endif
+
+			auto evaluationResult = scoreEvaluator.findBestTarget(stack, targets, hb);
+			auto & bestAttack = evaluationResult.bestAttack;
 
 			//TODO: consider more complex spellcast evaluation, f.e. because "re-retaliation" during enemy move in same turn for melee attack etc.
 			if(bestSpellcast.is_initialized() && bestSpellcast->value > bestAttack.damageDiff())
-				return BattleAction::makeCreatureSpellcast(stack, bestSpellcast->dest, bestSpellcast->spell->id);
-			else if(bestAttack.attack.shooting)
 			{
-				auto &target = bestAttack;
-				logAi->debug("BattleAI: %s -> %s x %d, shot, from %d curpos %d dist %d speed %d: %lld %lld %lld",
-					target.attackerState->unitType()->identifier,
-					target.affectedUnits[0]->unitType()->identifier,
-					(int)target.affectedUnits.size(), (int)target.from, (int)bestAttack.attack.attacker->getPosition().hex,
-					bestAttack.attack.chargedFields, bestAttack.attack.attacker->Speed(0, true),
-					target.damageDealt, target.damageReceived, target.attackValue()
-				);
-
-				return BattleAction::makeShotAttack(stack, bestAttack.attack.defender);
+				// return because spellcast value is damage dealt and score is dps reduce
+				return BattleAction::makeCreatureSpellcast(stack, bestSpellcast->dest, bestSpellcast->spell->id);
 			}
-			else
+
+			if(evaluationResult.score > score)
 			{
-				auto &target = bestAttack;
-				logAi->debug("BattleAI: %s -> %s x %d, mellee, from %d curpos %d dist %d speed %d: %lld %lld %lld",
-					target.attackerState->unitType()->identifier,
-					target.affectedUnits[0]->unitType()->identifier,
-					(int)target.affectedUnits.size(), (int)target.from, (int)bestAttack.attack.attacker->getPosition().hex,
+				auto & target = bestAttack;
+				score = evaluationResult.score;
+				std::string action;
+
+				if(evaluationResult.wait)
+				{
+					result = BattleAction::makeWait(stack);
+					action = "wait";
+				}
+				else if(bestAttack.attack.shooting)
+				{
+
+					result = BattleAction::makeShotAttack(stack, bestAttack.attack.defender);
+					action = "shot";
+				}
+				else
+				{
+					result = BattleAction::makeMeleeAttack(stack, bestAttack.attack.defender->getPosition(), bestAttack.from);
+					action = "melee";
+				}
+
+				logAi->debug("BattleAI: %s -> %s x %d, %s, from %d curpos %d dist %d speed %d: +%lld -%lld = %lld",
+					bestAttack.attackerState->unitType()->identifier,
+					bestAttack.affectedUnits[0]->unitType()->identifier,
+					(int)bestAttack.affectedUnits[0]->getCount(), action, (int)bestAttack.from, (int)bestAttack.attack.attacker->getPosition().hex,
 					bestAttack.attack.chargedFields, bestAttack.attack.attacker->Speed(0, true),
-					target.damageDealt, target.damageReceived, target.attackValue()
+					bestAttack.defenderDamageReduce, bestAttack.attackerDamageReduce, bestAttack.attackValue()
 				);
-
-				return BattleAction::makeMeleeAttack(stack,	bestAttack.attack.defender->getPosition(), bestAttack.from);
-		}
+			}
 		}
 		else if(bestSpellcast.is_initialized())
 		{
 			return BattleAction::makeCreatureSpellcast(stack, bestSpellcast->dest, bestSpellcast->spell->id);
 		}
-		else
+
+			//ThreatMap threatsToUs(stack); // These lines may be usefull but they are't used in the code.
+		if(moveTarget.score > score)
 		{
+			score = moveTarget.score;
+
 			if(stack->waited())
 			{
-				//ThreatMap threatsToUs(stack); // These lines may be usefull but they are't used in the code.
-				auto dists = cb->getReachability(stack);
-				if(!targets.unreachableEnemies.empty())
-				{
-					auto closestEnemy = vstd::minElementByFun(targets.unreachableEnemies, [&](const battle::Unit * enemy) -> int
-					{
-						return dists.distToNearestNeighbour(stack, enemy);
-					});
-
-					if(dists.distToNearestNeighbour(stack, *closestEnemy) < GameConstants::BFIELD_SIZE)
-					{
-						return goTowardsNearest(stack, (*closestEnemy)->getAttackableHexes(stack));
-					}
-				}
+				result = goTowardsNearest(stack, moveTarget.positions);
 			}
 			else
 			{
-				return BattleAction::makeWait(stack);
+				result = BattleAction::makeWait(stack);
 			}
 		}
 
+		if(score > EvaluationResult::INEFFECTIVE_SCORE)
+		{
+			return result;
+		}
+
 		if(!stack->hasBonusOfType(Bonus::FLYING)
 			&& stack->unitSide() == BattleSide::ATTACKER
 			&& cb->battleGetSiegeLevel() >= CGTownInstance::CITADEL)
@@ -235,7 +253,7 @@ BattleAction CBattleAI::activeStack( const CStack * stack )
 					return BattleAction::makeMove(stack, stack->getPosition().cloneInDirection(BattleHex::RIGHT));
 				else
 					return goTowardsNearest(stack, brokenWallMoat);
-	}
+			}
 		}
 	}
 	catch(boost::thread_interrupted &)
@@ -247,7 +265,7 @@ BattleAction CBattleAI::activeStack( const CStack * stack )
 		logAi->error("Exception occurred in %s %s",__FUNCTION__, e.what());
 	}
 
-	return BattleAction::makeDefend(stack);
+	return result;
 }
 
 BattleAction CBattleAI::goTowardsNearest(const CStack * stack, std::vector<BattleHex> hexes) const
@@ -272,10 +290,10 @@ BattleAction CBattleAI::goTowardsNearest(const CStack * stack, std::vector<Battl
 
 		if(stack->coversPos(hex))
 		{
-		logAi->warn("Warning: already standing on neighbouring tile!");
-		//We shouldn't even be here...
-		return BattleAction::makeDefend(stack);
-	}
+			logAi->warn("Warning: already standing on neighbouring tile!");
+			//We shouldn't even be here...
+			return BattleAction::makeDefend(stack);
+		}
 	}
 
 	BattleHex bestNeighbor = hexes.front();
@@ -285,13 +303,34 @@ BattleAction CBattleAI::goTowardsNearest(const CStack * stack, std::vector<Battl
 		return BattleAction::makeDefend(stack);
 	}
 
+	BattleExchangeEvaluator scoreEvaluator(cb, env);
+	HypotheticBattle hb(env.get(), cb);
+
+	scoreEvaluator.updateReachabilityMap(hb);
+
 	if(stack->hasBonusOfType(Bonus::FLYING))
 	{
+		std::set<BattleHex> moatHexes;
+
+		if(hb.battleGetSiegeLevel() >= BuildingID::CITADEL)
+		{
+			auto townMoat = hb.getDefendedTown()->town->moatHexes;
+
+			moatHexes = std::set<BattleHex>(townMoat.begin(), townMoat.end());
+		}
 		// Flying stack doesn't go hex by hex, so we can't backtrack using predecessors.
 		// We just check all available hexes and pick the one closest to the target.
 		auto nearestAvailableHex = vstd::minElementByFun(avHexes, [&](BattleHex hex) -> int
 		{
-			return BattleHex::getDistance(bestNeighbor, hex);
+			const int MOAT_PENALTY = 100; // avoid landing on moat
+			const int BLOCKED_STACK_PENALTY = 100; // avoid landing on moat
+
+			auto distance = BattleHex::getDistance(bestNeighbor, hex);
+
+			if(vstd::contains(moatHexes, hex))
+				distance += MOAT_PENALTY;
+
+			return scoreEvaluator.checkPositionBlocksOurStacks(hb, stack, hex) ? BLOCKED_STACK_PENALTY + distance : distance;
 		});
 
 		return BattleAction::makeMove(stack, *nearestAvailableHex);
@@ -303,11 +342,11 @@ BattleAction CBattleAI::goTowardsNearest(const CStack * stack, std::vector<Battl
 		{
 			if(!currentDest.isValid())
 			{
-				logAi->error("CBattleAI::goTowards: internal error");
 				return BattleAction::makeDefend(stack);
 			}
 
-			if(vstd::contains(avHexes, currentDest))
+			if(vstd::contains(avHexes, currentDest)
+				&& !scoreEvaluator.checkPositionBlocksOurStacks(hb, stack, currentDest))
 				return BattleAction::makeMove(stack, currentDest);
 
 			currentDest = reachability.predecessors[currentDest];
@@ -407,7 +446,7 @@ void CBattleAI::attemptCastingSpell()
 
 	using ValueMap = PossibleSpellcast::ValueMap;
 
-	auto evaluateQueue = [&](ValueMap & values, const std::vector<battle::Units> & queue, HypotheticBattle * state, size_t minTurnSpan, bool * enemyHadTurnOut) -> bool
+	auto evaluateQueue = [&](ValueMap & values, const std::vector<battle::Units> & queue, HypotheticBattle & state, size_t minTurnSpan, bool * enemyHadTurnOut) -> bool
 	{
 		bool firstRound = true;
 		bool enemyHadTurn = false;
@@ -418,7 +457,7 @@ void CBattleAI::attemptCastingSpell()
 		for(auto & round : queue)
 		{
 			if(!firstRound)
-				state->nextRound(0);//todo: set actual value?
+				state.nextRound(0);//todo: set actual value?
 			for(auto unit : round)
 			{
 				if(!vstd::contains(values, unit->unitId()))
@@ -427,11 +466,11 @@ void CBattleAI::attemptCastingSpell()
 				if(!unit->alive())
 					continue;
 
-				if(state->battleGetOwner(unit) != playerID)
+				if(state.battleGetOwner(unit) != playerID)
 				{
 					enemyHadTurn = true;
 
-					if(!firstRound || state->battleCastSpells(unit->unitSide()) == 0)
+					if(!firstRound || state.battleCastSpells(unit->unitSide()) == 0)
 					{
 						//enemy could counter our spell at this point
 						//anyway, we do not know what enemy will do
@@ -445,7 +484,7 @@ void CBattleAI::attemptCastingSpell()
 					ourTurnSpan++;
 				}
 
-				state->nextTurn(unit->unitId());
+				state.nextTurn(unit->unitId());
 
 				PotentialTargets pt(unit, state);
 
@@ -453,22 +492,22 @@ void CBattleAI::attemptCastingSpell()
 				{
 					AttackPossibility ap = pt.bestAction();
 
-					auto swb = state->getForUpdate(unit->unitId());
+					auto swb = state.getForUpdate(unit->unitId());
 					*swb = *ap.attackerState;
 
-					if(ap.damageDealt > 0)
+					if(ap.defenderDamageReduce > 0)
 						swb->removeUnitBonus(Bonus::UntilAttack);
-					if(ap.damageReceived > 0)
+					if(ap.attackerDamageReduce > 0)
 						swb->removeUnitBonus(Bonus::UntilBeingAttacked);
 
 					for(auto affected : ap.affectedUnits)
 					{
-						swb = state->getForUpdate(affected->unitId());
+						swb = state.getForUpdate(affected->unitId());
 						*swb = *affected;
 
-						if(ap.damageDealt > 0)
+						if(ap.defenderDamageReduce > 0)
 							swb->removeUnitBonus(Bonus::UntilBeingAttacked);
-						if(ap.damageReceived > 0 && ap.attack.defender->unitId() == affected->unitId())
+						if(ap.attackerDamageReduce > 0 && ap.attack.defender->unitId() == affected->unitId())
 							swb->removeUnitBonus(Bonus::UntilAttack);
 					}
 				}
@@ -476,7 +515,7 @@ void CBattleAI::attemptCastingSpell()
 				auto bav = pt.bestActionValue();
 
 				//best action is from effective owner`s point if view, we need to convert to our point if view
-				if(state->battleGetOwner(unit) != playerID)
+				if(state.battleGetOwner(unit) != playerID)
 					bav = -bav;
 				values[unit->unitId()] += bav;
 			}
@@ -529,7 +568,7 @@ void CBattleAI::attemptCastingSpell()
 
 		HypotheticBattle state(env.get(), cb);
 
-		evaluateQueue(valueOfStack, turnOrder, &state, 0, &enemyHadTurn);
+		evaluateQueue(valueOfStack, turnOrder, state, 0, &enemyHadTurn);
 
 		if(!enemyHadTurn)
 		{
@@ -577,7 +616,7 @@ void CBattleAI::attemptCastingSpell()
 
 		state.battleGetTurnOrder(newTurnOrder, amount, 2);
 
-		const bool turnSpanOK = evaluateQueue(newValueOfStack, newTurnOrder, &state, minTurnSpan, nullptr);
+		const bool turnSpanOK = evaluateQueue(newValueOfStack, newTurnOrder, state, minTurnSpan, nullptr);
 
 		if(turnSpanOK || castNow)
 		{

+ 689 - 0
AI/BattleAI/BattleExchangeVariant.cpp

@@ -0,0 +1,689 @@
+/*
+ * BattleAI.cpp, part of VCMI engine
+ *
+ * Authors: listed in file AUTHORS in main folder
+ *
+ * License: GNU General Public License v2.0 or later
+ * Full text of license available in license.txt file, in main folder
+ *
+ */
+#include "StdInc.h"
+#include "BattleExchangeVariant.h"
+#include "../../lib/CStack.h"
+
+AttackerValue::AttackerValue()
+{
+	value = 0;
+	isRetalitated = false;
+}
+
+MoveTarget::MoveTarget()
+	: positions()
+{
+	score = EvaluationResult::INEFFECTIVE_SCORE;
+}
+
+int64_t BattleExchangeVariant::trackAttack(const AttackPossibility & ap, HypotheticBattle & state)
+{
+	auto affectedUnits = ap.affectedUnits;
+
+	affectedUnits.push_back(ap.attackerState);
+
+	for(auto affectedUnit : affectedUnits)
+	{
+		auto unitToUpdate = state.getForUpdate(affectedUnit->unitId());
+
+		unitToUpdate->health = affectedUnit->health;
+		unitToUpdate->shots = affectedUnit->shots;
+		unitToUpdate->counterAttacks = affectedUnit->counterAttacks;
+		unitToUpdate->movedThisRound = affectedUnit->movedThisRound;
+	}
+
+	auto attackValue = ap.attackValue();
+
+	dpsScore += attackValue;
+
+#if BATTLE_TRACE_LEVEL>=1
+	logAi->trace(
+		"%s -> %s, ap attack, %s, dps: %lld, score: %lld",
+		ap.attack.attacker->getDescription(),
+		ap.attack.defender->getDescription(),
+		ap.attack.shooting ? "shot" : "mellee",
+		ap.damageDealt,
+		attackValue);
+#endif
+
+	return attackValue;
+}
+
+int64_t BattleExchangeVariant::trackAttack(
+	std::shared_ptr<StackWithBonuses> attacker,
+	std::shared_ptr<StackWithBonuses> defender,
+	bool shooting,
+	bool isOurAttack,
+	const CBattleInfoCallback & cb,
+	bool evaluateOnly)
+{
+	const std::string cachingStringBlocksRetaliation = "type_BLOCKS_RETALIATION";
+	static const auto selectorBlocksRetaliation = Selector::type()(Bonus::BLOCKS_RETALIATION);
+	const bool counterAttacksBlocked = attacker->hasBonus(selectorBlocksRetaliation, cachingStringBlocksRetaliation);
+
+	TDmgRange retaliation;
+	BattleAttackInfo bai(attacker.get(), defender.get(), shooting);
+
+	if(shooting)
+	{
+		bai.attackerPos.setXY(8, 5);
+	}
+
+	auto attack = cb.battleEstimateDamage(bai, &retaliation);
+	int64_t attackDamage = (attack.first + attack.second) / 2;
+	int64_t defenderDamageReduce = AttackPossibility::calculateDamageReduce(attacker.get(), defender.get(), attackDamage, cb);
+	int64_t attackerDamageReduce = 0;
+
+	if(!evaluateOnly)
+	{
+#if BATTLE_TRACE_LEVEL>=1
+		logAi->trace(
+			"%s -> %s, normal attack, %s, dps: %lld, %lld",
+			attacker->getDescription(),
+			defender->getDescription(),
+			shooting ? "shot" : "mellee",
+			attackDamage,
+			defenderDamageReduce);
+#endif
+
+		if(isOurAttack)
+		{
+			dpsScore += defenderDamageReduce;
+			attackerValue[attacker->unitId()].value += defenderDamageReduce;
+		}
+		else
+			dpsScore -= defenderDamageReduce;
+
+		defender->damage(attackDamage);
+		attacker->afterAttack(shooting, false);
+	}
+
+	if(defender->alive() && defender->ableToRetaliate() && !counterAttacksBlocked && !shooting)
+	{
+		if(retaliation.second != 0)
+		{
+			auto retaliationDamage = (retaliation.first + retaliation.second) / 2;
+			attackerDamageReduce = AttackPossibility::calculateDamageReduce(defender.get(), attacker.get(), retaliationDamage, cb);
+
+			if(!evaluateOnly)
+			{
+#if BATTLE_TRACE_LEVEL>=1
+				logAi->trace(
+					"%s -> %s, retaliation, dps: %lld, %lld",
+					defender->getDescription(),
+					attacker->getDescription(),
+					retaliationDamage,
+					attackerDamageReduce);
+#endif
+
+				if(isOurAttack)
+				{
+					dpsScore -= attackerDamageReduce;
+					attackerValue[attacker->unitId()].isRetalitated = true;
+				}
+				else
+				{
+					dpsScore += attackerDamageReduce;
+					attackerValue[defender->unitId()].value += attackerDamageReduce;
+				}
+
+				attacker->damage(retaliationDamage);
+				defender->afterAttack(false, true);
+			}
+		}
+	}
+
+	auto score = defenderDamageReduce - attackerDamageReduce;
+
+#if BATTLE_TRACE_LEVEL>=1
+	if(!score)
+	{
+		logAi->trace("Attack has zero score d:%lld a:%lld", defenderDamageReduce, attackerDamageReduce);
+	}
+#endif
+
+	return score;
+}
+
+EvaluationResult BattleExchangeEvaluator::findBestTarget(const battle::Unit * activeStack, PotentialTargets & targets, HypotheticBattle & hb)
+{
+	EvaluationResult result(targets.bestAction());
+
+	updateReachabilityMap(hb);
+
+	for(auto & ap : targets.possibleAttacks)
+	{
+		int64_t score = calculateExchange(ap, targets, hb);
+
+		if(score > result.score)
+		{
+			result.score = score;
+			result.bestAttack = ap;
+		}
+	}
+
+	if(!activeStack->waited())
+	{
+#if BATTLE_TRACE_LEVEL>=1
+		logAi->trace("Evaluating waited attack for %s", activeStack->getDescription());
+#endif
+
+		hb.getForUpdate(activeStack->unitId())->waiting = true;
+		hb.getForUpdate(activeStack->unitId())->waitedThisTurn = true;
+
+		updateReachabilityMap(hb);
+
+		for(auto & ap : targets.possibleAttacks)
+		{
+			int64_t score = calculateExchange(ap, targets, hb);
+
+			if(score > result.score)
+			{
+				result.score = score;
+				result.bestAttack = ap;
+				result.wait = true;
+			}
+		}
+	}
+
+	return result;
+}
+
+MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(const battle::Unit * activeStack, PotentialTargets & targets, HypotheticBattle & hb)
+{
+	MoveTarget result;
+	BattleExchangeVariant ev;
+
+	if(targets.unreachableEnemies.empty())
+		return result;
+
+	updateReachabilityMap(hb);
+
+	auto dists = cb->getReachability(activeStack);
+	auto speed = activeStack->Speed();
+
+	for(const battle::Unit * enemy : targets.unreachableEnemies)
+	{
+		int64_t stackScore = EvaluationResult::INEFFECTIVE_SCORE;
+
+		std::vector<const battle::Unit *> adjacentStacks = getAdjacentUnits(enemy);
+		auto closestStack = *vstd::minElementByFun(adjacentStacks, [&](const battle::Unit * u) -> int64_t
+			{
+				return dists.distToNearestNeighbour(activeStack, u) * 100000 - activeStack->getTotalHealth();
+			});
+
+		auto distance = dists.distToNearestNeighbour(activeStack, closestStack);
+
+		if(distance >= GameConstants::BFIELD_SIZE)
+			continue;
+
+		if(distance <= speed)
+			continue;
+
+		auto turnsToRich = (distance - 1) / speed + 1;
+		auto hexes = closestStack->getSurroundingHexes();
+
+		for(auto hex : hexes)
+		{
+			auto bai = BattleAttackInfo(activeStack, closestStack, cb->battleCanShoot(activeStack));
+			auto attack = AttackPossibility::evaluate(bai, hex, hb);
+
+			attack.shootersBlockedDmg = 0; // we do not want to count on it, it is not for sure
+
+			auto score = calculateExchange(attack, targets, hb) / turnsToRich;
+
+			if(result.score < score)
+			{
+				result.score = score;
+				result.positions = closestStack->getAttackableHexes(activeStack);
+			}
+		}
+	}
+
+	return result;
+}
+
+std::vector<const battle::Unit *> BattleExchangeEvaluator::getAdjacentUnits(const battle::Unit * blockerUnit)
+{
+	std::queue<const battle::Unit *> queue;
+	std::vector<const battle::Unit *> checkedStacks;
+
+	queue.push(blockerUnit);
+
+	while(!queue.empty())
+	{
+		auto stack = queue.front();
+
+		queue.pop();
+		checkedStacks.push_back(stack);
+
+		auto hexes = stack->getSurroundingHexes();
+		for(auto hex : hexes)
+		{
+			auto neighbor = cb->battleGetStackByPos(hex);
+
+			if(neighbor && neighbor->unitSide() == stack->unitSide() && !vstd::contains(checkedStacks, neighbor))
+			{
+				queue.push(neighbor);
+				checkedStacks.push_back(neighbor);
+			}
+		}
+	}
+
+	return checkedStacks;
+}
+
+std::vector<const battle::Unit *> BattleExchangeEvaluator::getExchangeUnits(
+	const AttackPossibility & ap,
+	PotentialTargets & targets,
+	HypotheticBattle & hb)
+{
+	auto hexes = ap.attack.defender->getHexes();
+
+	if(!ap.attack.shooting) hexes.push_back(ap.from);
+
+	std::vector<const battle::Unit *> exchangeUnits;
+	std::vector<const battle::Unit *> allReachableUnits;
+
+	for(auto hex : hexes)
+	{
+		vstd::concatenate(allReachableUnits, reachabilityMap[hex]);
+	}
+
+	vstd::removeDuplicates(allReachableUnits);
+
+	auto copy = allReachableUnits;
+	for(auto unit : copy)
+	{
+		for(auto adjacentUnit : getAdjacentUnits(unit))
+		{
+			auto unitWithBonuses = hb.battleGetUnitByID(adjacentUnit->unitId());
+
+			if(vstd::contains(targets.unreachableEnemies, adjacentUnit)
+				&& !vstd::contains(allReachableUnits, unitWithBonuses))
+			{
+				allReachableUnits.push_back(unitWithBonuses);
+			}
+		}
+	}
+
+	vstd::removeDuplicates(allReachableUnits);
+
+	if(!vstd::contains(allReachableUnits, ap.attack.attacker))
+	{
+		allReachableUnits.push_back(ap.attack.attacker);
+	}
+
+	if(allReachableUnits.size() < 2)
+	{
+#if BATTLE_TRACE_LEVEL>=1
+		logAi->trace("Reachability map contains only %d stacks", allReachableUnits.size());
+#endif
+
+		return exchangeUnits;
+	}
+
+	for(int turn = 0; turn < turnOrder.size(); turn++)
+	{
+		for(auto unit : turnOrder[turn])
+		{
+			if(vstd::contains(allReachableUnits, unit))
+				exchangeUnits.push_back(unit);
+		}
+	}
+
+	return exchangeUnits;
+}
+
+int64_t BattleExchangeEvaluator::calculateExchange(
+	const AttackPossibility & ap,
+	PotentialTargets & targets,
+	HypotheticBattle & hb)
+{
+#if BATTLE_TRACE_LEVEL>=1
+	logAi->trace("Battle exchange at %lld", ap.attack.shooting ? ap.dest : ap.from);
+#endif
+
+	std::vector<const battle::Unit *> ourStacks;
+	std::vector<const battle::Unit *> enemyStacks;
+
+	enemyStacks.push_back(ap.attack.defender);
+
+	std::vector<const battle::Unit *> exchangeUnits = getExchangeUnits(ap, targets, hb);
+
+	if(exchangeUnits.empty())
+	{
+		return 0;
+	}
+
+	HypotheticBattle exchangeBattle(env.get(), cb);
+	BattleExchangeVariant v;
+	auto melleeAttackers = ourStacks;
+
+	vstd::removeDuplicates(melleeAttackers);
+	vstd::erase_if(melleeAttackers, [&](const battle::Unit * u) -> bool
+		{
+			return !cb->battleCanShoot(u);
+		});
+
+	for(auto unit : exchangeUnits)
+	{
+		bool isOur = cb->battleMatchOwner(ap.attack.attacker, unit, true);
+		auto & attackerQueue = isOur ? ourStacks : enemyStacks;
+
+		if(!vstd::contains(attackerQueue, unit))
+		{
+			attackerQueue.push_back(unit);
+		}
+	}
+
+	bool canUseAp = true;
+
+	for(auto activeUnit : exchangeUnits)
+	{
+		bool isOur = cb->battleMatchOwner(ap.attack.attacker, activeUnit, true);
+		battle::Units & attackerQueue = isOur ? ourStacks : enemyStacks;
+		battle::Units & oppositeQueue = isOur ? enemyStacks : ourStacks;
+
+		auto attacker = exchangeBattle.getForUpdate(activeUnit->unitId());
+
+		if(!attacker->alive())
+		{
+#if BATTLE_TRACE_LEVEL>=1
+			logAi->trace(	"Attacker is dead");
+#endif
+
+			continue;
+		}
+
+		auto targetUnit = ap.attack.defender;
+
+		if(!isOur || !exchangeBattle.getForUpdate(targetUnit->unitId())->alive())
+		{
+			auto estimateAttack = [&](const battle::Unit * u) -> int64_t
+			{
+				auto stackWithBonuses = exchangeBattle.getForUpdate(u->unitId());
+				auto score = v.trackAttack(
+					attacker,
+					stackWithBonuses,
+					exchangeBattle.battleCanShoot(stackWithBonuses.get()),
+					isOur,
+					*cb,
+					true);
+
+#if BATTLE_TRACE_LEVEL>=1
+				logAi->trace("Best target selector %s->%s score = %lld", attacker->getDescription(), u->getDescription(), score);
+#endif
+
+				return score;
+			};
+
+			if(!oppositeQueue.empty())
+			{
+				targetUnit = *vstd::maxElementByFun(oppositeQueue, estimateAttack);
+			}
+			else
+			{
+				auto reachable = exchangeBattle.battleGetUnitsIf([&](const battle::Unit * u) -> bool
+					{
+						if(!u->alive() || u->unitSide() == attacker->unitSide())
+							return false;
+
+						return vstd::contains_if(reachabilityMap[u->getPosition()], [&](const battle::Unit * other) -> bool
+							{
+								return attacker->unitId() == other->unitId();
+							});
+					});
+
+				if(!reachable.empty())
+				{
+					targetUnit = *vstd::maxElementByFun(reachable, estimateAttack);
+				}
+				else
+				{
+#if BATTLE_TRACE_LEVEL>=1
+					logAi->trace("Battle queue is empty and no reachable enemy.");
+#endif
+
+					continue;
+				}
+			}
+		}
+
+		auto defender = exchangeBattle.getForUpdate(targetUnit->unitId());
+		auto shooting = cb->battleCanShoot(attacker.get());
+		const int totalAttacks = attacker->getTotalAttacks(shooting);
+
+		if(canUseAp && activeUnit == ap.attack.attacker && targetUnit == ap.attack.defender)
+		{
+			v.trackAttack(ap, exchangeBattle);
+		}
+		else
+		{
+			for(int i = 0; i < totalAttacks; i++)
+			{
+				v.trackAttack(attacker, defender, shooting, isOur, exchangeBattle);
+
+				if(!attacker->alive() || !defender->alive())
+					break;
+			}
+		}
+
+		canUseAp = false;
+
+		vstd::erase_if(attackerQueue, [&](const battle::Unit * u) -> bool
+			{
+				return !exchangeBattle.getForUpdate(u->unitId())->alive();
+			});
+
+		vstd::erase_if(oppositeQueue, [&](const battle::Unit * u) -> bool
+			{
+				return !exchangeBattle.getForUpdate(u->unitId())->alive();
+			});
+	}
+
+	// avoid blocking path for stronger stack by weaker stack
+	// the method checks if all stacks can be placed around enemy
+	v.adjustPositions(melleeAttackers, ap, reachabilityMap);
+
+#if BATTLE_TRACE_LEVEL>=1
+	logAi->trace("Exchange score: %lld", v.getScore());
+#endif
+
+	return v.getScore();
+}
+
+void BattleExchangeVariant::adjustPositions(
+	std::vector<const battle::Unit*> attackers,
+	const AttackPossibility & ap,
+	std::map<BattleHex, battle::Units> & reachabilityMap)
+{
+	auto hexes = ap.attack.defender->getSurroundingHexes();
+
+	boost::sort(attackers, [&](const battle::Unit * u1, const battle::Unit * u2) -> bool
+		{
+			if(attackerValue[u1->unitId()].isRetalitated && !attackerValue[u2->unitId()].isRetalitated)
+				return true;
+
+			if(attackerValue[u2->unitId()].isRetalitated && !attackerValue[u1->unitId()].isRetalitated)
+				return false;
+
+			return attackerValue[u1->unitId()].value > attackerValue[u2->unitId()].value;
+		});
+
+	if(!ap.attack.shooting)
+	{
+		vstd::erase_if_present(hexes, ap.from);
+		vstd::erase_if_present(hexes, ap.attack.attacker->occupiedHex(ap.attack.attackerPos));
+	}
+
+	int64_t notRealizedDamage = 0;
+
+	for(auto unit : attackers)
+	{
+		if(unit->unitId() == ap.attack.attacker->unitId())
+			continue;
+
+		if(!vstd::contains_if(hexes, [&](BattleHex h) -> bool
+			{
+				return vstd::contains(reachabilityMap[h], unit);
+			}))
+		{
+			notRealizedDamage += attackerValue[unit->unitId()].value;
+			continue;
+		}
+
+		auto desiredPosition = vstd::minElementByFun(hexes, [&](BattleHex h) -> int64_t
+			{
+				auto score = vstd::contains(reachabilityMap[h], unit)
+					? reachabilityMap[h].size()
+					: 0;
+
+				if(unit->doubleWide())
+				{
+					auto backHex = unit->occupiedHex(h);
+
+					if(vstd::contains(hexes, backHex))
+						score += reachabilityMap[backHex].size();
+				}
+
+				return score;
+			});
+
+		hexes.erase(desiredPosition);
+	}
+
+	if(notRealizedDamage > ap.attackValue() && notRealizedDamage > attackerValue[ap.attack.attacker->unitId()].value)
+	{
+		dpsScore = EvaluationResult::INEFFECTIVE_SCORE;
+	}
+}
+
+void BattleExchangeEvaluator::updateReachabilityMap(HypotheticBattle & hb)
+{
+	const int TURN_DEPTH = 2;
+
+	turnOrder.clear();
+	
+	hb.battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH);
+	reachabilityMap.clear();
+
+	for(int turn = 0; turn < turnOrder.size(); turn++)
+	{
+		auto & turnQueue = turnOrder[turn];
+		HypotheticBattle turnBattle(env.get(), cb);
+
+		for(const battle::Unit * unit : turnQueue)
+		{
+			if(turnBattle.battleCanShoot(unit))
+			{
+				for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
+				{
+					reachabilityMap[hex].push_back(unit);
+				}
+
+				continue;
+			}
+
+			auto unitReachability = turnBattle.getReachability(unit);
+
+			for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
+			{
+				bool reachable = unitReachability.distances[hex] <= unit->Speed(turn);
+
+				if(!reachable && unitReachability.accessibility[hex] == EAccessibility::ALIVE_STACK)
+				{
+					const battle::Unit * hexStack = cb->battleGetUnitByPos(hex);
+
+					if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
+					{
+						for(BattleHex neighbor : hex.neighbouringTiles())
+						{
+							reachable = unitReachability.distances[neighbor] <= unit->Speed(turn);
+
+							if(reachable) break;
+						}
+					}
+				}
+
+				if(reachable)
+				{
+					reachabilityMap[hex].push_back(unit);
+				}
+			}
+		}
+	}
+}
+
+// avoid blocking path for stronger stack by weaker stack
+bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * activeUnit, BattleHex position)
+{
+	const int BLOCKING_THRESHOLD = 70;
+	const int BLOCKING_OWN_ATTACK_PENALTY = 100;
+	const int BLOCKING_OWN_MOVE_PENALTY = 1;
+
+	float blockingScore = 0;
+
+	auto activeUnitDamage = activeUnit->getMinDamage(hb.battleCanShoot(activeUnit)) * activeUnit->getCount();
+
+	for(int turn = 0; turn < turnOrder.size(); turn++)
+	{
+		auto & turnQueue = turnOrder[turn];
+		HypotheticBattle turnBattle(env.get(), cb);
+
+		auto unitToUpdate = turnBattle.getForUpdate(activeUnit->unitId());
+		unitToUpdate->setPosition(position);
+
+		for(const battle::Unit * unit : turnQueue)
+		{
+			if(unit->unitId() == unitToUpdate->unitId() || cb->battleMatchOwner(unit, activeUnit, false))
+				continue;
+
+			auto blockedUnitDamage = unit->getMinDamage(hb.battleCanShoot(unit)) * unit->getCount();
+			auto ratio = blockedUnitDamage / (blockedUnitDamage + activeUnitDamage);
+
+			auto unitReachability = turnBattle.getReachability(unit);
+
+			for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
+			{
+				bool enemyUnit = false;
+				bool reachable = unitReachability.distances[hex] <= unit->Speed(turn);
+
+				if(!reachable && unitReachability.accessibility[hex] == EAccessibility::ALIVE_STACK)
+				{
+					const battle::Unit * hexStack = turnBattle.battleGetUnitByPos(hex);
+
+					if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
+					{
+						enemyUnit = true;
+
+						for(BattleHex neighbor : hex.neighbouringTiles())
+						{
+							reachable = unitReachability.distances[neighbor] <= unit->Speed(turn);
+
+							if(reachable) break;
+						}
+					}
+				}
+
+				if(!reachable && vstd::contains(reachabilityMap[hex], unit))
+				{
+					blockingScore += ratio * (enemyUnit ? BLOCKING_OWN_ATTACK_PENALTY : BLOCKING_OWN_MOVE_PENALTY);
+				}
+			}
+		}
+	}
+
+#if BATTLE_TRACE_LEVEL>=1
+	logAi->trace("Position %d, blocking score %f", position.hex, blockingScore);
+#endif
+
+	return blockingScore > BLOCKING_THRESHOLD;
+}

+ 107 - 0
AI/BattleAI/BattleExchangeVariant.h

@@ -0,0 +1,107 @@
+/*
+ * BattleExchangeVariant.h, part of VCMI engine
+ *
+ * Authors: listed in file AUTHORS in main folder
+ *
+ * License: GNU General Public License v2.0 or later
+ * Full text of license available in license.txt file, in main folder
+ *
+ */
+#pragma once
+
+#include "../../lib/AI_Base.h"
+#include "../../lib/battle/ReachabilityInfo.h"
+#include "PotentialTargets.h"
+#include "StackWithBonuses.h"
+
+struct AttackerValue
+{
+	int64_t value;
+	bool isRetalitated;
+	BattleHex position;
+
+	AttackerValue();
+};
+
+struct MoveTarget
+{
+	int64_t score;
+	std::vector<BattleHex> positions;
+
+	MoveTarget();
+};
+
+struct EvaluationResult
+{
+	static const int64_t INEFFECTIVE_SCORE = -1000000;
+
+	AttackPossibility bestAttack;
+	MoveTarget bestMove;
+	bool wait;
+	int64_t score;
+	bool defend;
+
+	EvaluationResult(const AttackPossibility & ap)
+		:wait(false), score(0), bestAttack(ap), defend(false)
+	{
+	}
+};
+
+/// <summary>
+/// The class represents evaluation of attack value
+/// of exchanges between all stacks which can access particular hex
+/// starting from initial attack represented by AttackPossibility and further according turn order.
+/// Negative score value means we get more demage than deal
+/// </summary>
+class BattleExchangeVariant
+{
+public:
+	BattleExchangeVariant()
+		:dpsScore(0), attackerValue()
+	{
+	}
+
+	int64_t trackAttack(const AttackPossibility & ap, HypotheticBattle & state);
+
+	int64_t trackAttack(
+		std::shared_ptr<StackWithBonuses> attacker,
+		std::shared_ptr<StackWithBonuses> defender,
+		bool shooting,
+		bool isOurAttack,
+		const CBattleInfoCallback & cb,
+		bool evaluateOnly = false);
+
+	int64_t getScore() const { return dpsScore; }
+
+	void adjustPositions(
+		std::vector<const battle::Unit *> attackers,
+		const AttackPossibility & ap,
+		std::map<BattleHex, battle::Units> & reachabilityMap);
+
+private:
+	int64_t dpsScore;
+	std::map<uint32_t, AttackerValue> attackerValue;
+};
+
+class BattleExchangeEvaluator
+{
+private:
+	std::shared_ptr<CBattleInfoCallback> cb;
+	std::shared_ptr<Environment> env;
+	std::map<BattleHex, std::vector<const battle::Unit *>> reachabilityMap;
+	std::vector<battle::Units> turnOrder;
+
+public:
+	BattleExchangeEvaluator(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env)
+		:cb(cb), reachabilityMap(), env(env), turnOrder()
+	{
+	}
+
+	EvaluationResult findBestTarget(const battle::Unit * activeStack, PotentialTargets & targets, HypotheticBattle & hb);
+	int64_t calculateExchange(const AttackPossibility & ap, PotentialTargets & targets, HypotheticBattle & hb);
+	void updateReachabilityMap(HypotheticBattle & hb);
+	std::vector<const battle::Unit *> getExchangeUnits(const AttackPossibility & ap, PotentialTargets & targets, HypotheticBattle & hb);
+	bool checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * unit, BattleHex position);
+	MoveTarget findMoveTowardsUnreachable(const battle::Unit * activeStack, PotentialTargets & targets, HypotheticBattle & hb);
+	std::vector<const battle::Unit *> getAdjacentUnits(const battle::Unit * unit);
+};

+ 2 - 0
AI/BattleAI/CMakeLists.txt

@@ -10,6 +10,7 @@ set(battleAI_SRCS
 		PotentialTargets.cpp
 		StackWithBonuses.cpp
 		ThreatMap.cpp
+		BattleExchangeVariant.cpp
 )
 
 set(battleAI_HEADERS
@@ -23,6 +24,7 @@ set(battleAI_HEADERS
 		PossibleSpellcast.h
 		StackWithBonuses.h
 		ThreatMap.h
+		BattleExchangeVariant.h
 )
 
 assign_source_group(${battleAI_SRCS} ${battleAI_HEADERS})

+ 17 - 22
AI/BattleAI/PotentialTargets.cpp

@@ -11,13 +11,11 @@
 #include "PotentialTargets.h"
 #include "../../lib/CStack.h"//todo: remove
 
-PotentialTargets::PotentialTargets(const battle::Unit * attacker, const HypotheticBattle * state)
+PotentialTargets::PotentialTargets(const battle::Unit * attacker, const HypotheticBattle & state)
 {
-	auto attIter = state->stackStates.find(attacker->unitId());
-	const battle::Unit * attackerInfo = (attIter == state->stackStates.end()) ? attacker : attIter->second.get();
-
-	auto reachability = state->getReachability(attackerInfo);
-	auto avHexes = state->battleGetAvailableHexes(reachability, attackerInfo);
+	auto attackerInfo = state.battleGetUnitByID(attacker->unitId());
+	auto reachability = state.getReachability(attackerInfo);
+	auto avHexes = state.battleGetAvailableHexes(reachability, attackerInfo);
 
 	//FIXME: this should part of battleGetAvailableHexes
 	bool forceTarget = false;
@@ -27,7 +25,7 @@ PotentialTargets::PotentialTargets(const battle::Unit * attacker, const Hypothet
 	if(attackerInfo->hasBonusOfType(Bonus::ATTACKS_NEAREST_CREATURE))
 	{
 		forceTarget = true;
-		auto nearest = state->getNearestStack(attackerInfo);
+		auto nearest = state.getNearestStack(attackerInfo);
 
 		if(nearest.first != nullptr)
 		{
@@ -36,14 +34,14 @@ PotentialTargets::PotentialTargets(const battle::Unit * attacker, const Hypothet
 		}
 	}
 
-	auto aliveUnits = state->battleGetUnitsIf([=](const battle::Unit * unit)
+	auto aliveUnits = state.battleGetUnitsIf([=](const battle::Unit * unit)
 	{
 		return unit->isValidTarget() && unit->unitId() != attackerInfo->unitId();
 	});
 
 	for(auto defender : aliveUnits)
 	{
-		if(!forceTarget && !state->battleMatchOwner(attackerInfo, defender))
+		if(!forceTarget && !state.battleMatchOwner(attackerInfo, defender))
 			continue;
 
 		auto GenerateAttackInfo = [&](bool shooting, BattleHex hex) -> AttackPossibility
@@ -63,7 +61,7 @@ PotentialTargets::PotentialTargets(const battle::Unit * attacker, const Hypothet
 			else
 				unreachableEnemies.push_back(defender);
 		}
-		else if(state->battleCanShoot(attackerInfo, defender->getPosition()))
+		else if(state.battleCanShoot(attackerInfo, defender->getPosition()))
 		{
 			possibleAttacks.push_back(GenerateAttackInfo(true, BattleHex::INVALID));
 		}
@@ -86,22 +84,18 @@ PotentialTargets::PotentialTargets(const battle::Unit * attacker, const Hypothet
 
 	boost::sort(possibleAttacks, [](const AttackPossibility & lhs, const AttackPossibility & rhs) -> bool
 	{
-		if(lhs.collateralDamage > rhs.collateralDamage)
-			return false;
-		if(lhs.collateralDamage < rhs.collateralDamage)
-			return true;
-		return (lhs.damageDealt + lhs.shootersBlockedDmg - lhs.damageReceived > rhs.damageDealt + rhs.shootersBlockedDmg - rhs.damageReceived);
+		return lhs.damageDiff() > rhs.damageDiff();
 	});
 
 	if (!possibleAttacks.empty())
 	{
-		auto &bestAp = possibleAttacks[0];
+		auto & bestAp = possibleAttacks[0];
 
-		logGlobal->info("Battle AI best: %s -> %s at %d from %d, affects %d units: %lld %lld %lld %lld",
+		logGlobal->info("Battle AI best: %s -> %s at %d from %d, affects %d units: d:%lld a:%lld c:%lld s:%lld",
 			bestAp.attack.attacker->unitType()->identifier,
-			state->battleGetUnitByPos(bestAp.dest)->unitType()->identifier,
+			state.battleGetUnitByPos(bestAp.dest)->unitType()->identifier,
 			(int)bestAp.dest, (int)bestAp.from, (int)bestAp.affectedUnits.size(),
-			bestAp.damageDealt, bestAp.damageReceived, bestAp.collateralDamage, bestAp.shootersBlockedDmg);
+			bestAp.defenderDamageReduce, bestAp.attackerDamageReduce, bestAp.collateralDamageReduce, bestAp.shootersBlockedDmg);
 	}
 }
 
@@ -109,13 +103,14 @@ int64_t PotentialTargets::bestActionValue() const
 {
 	if(possibleAttacks.empty())
 		return 0;
+
 	return bestAction().attackValue();
 }
 
-AttackPossibility PotentialTargets::bestAction() const
+const AttackPossibility & PotentialTargets::bestAction() const
 {
 	if(possibleAttacks.empty())
 		throw std::runtime_error("No best action, since we don't have any actions");
-	return possibleAttacks[0];
-	//return *vstd::maxElementByFun(possibleAttacks, [](const AttackPossibility &ap) { return ap.attackValue(); } );
+
+	return possibleAttacks.front();
 }

+ 2 - 2
AI/BattleAI/PotentialTargets.h

@@ -17,8 +17,8 @@ public:
 	std::vector<const battle::Unit *> unreachableEnemies;
 
 	PotentialTargets(){};
-	PotentialTargets(const battle::Unit * attacker, const HypotheticBattle * state);
+	PotentialTargets(const battle::Unit * attacker, const HypotheticBattle & state);
 
-	AttackPossibility bestAction() const;
+	const AttackPossibility & bestAction() const;
 	int64_t bestActionValue() const;
 };

+ 15 - 1
AI/BattleAI/StackWithBonuses.cpp

@@ -199,6 +199,21 @@ void StackWithBonuses::removeUnitBonus(const CSelector & selector)
 	vstd::erase_if(bonusesToUpdate, [&](const Bonus & b){return selector(&b);});
 }
 
+std::string StackWithBonuses::getDescription() const
+{
+	std::ostringstream oss;
+	oss << unitOwner().getStr();
+	oss << " battle stack [" << unitId() << "]: " << getCount() << " of ";
+	if(type)
+		oss << type->namePl;
+	else
+		oss << "[UNDEFINED TYPE]";
+
+	oss << " from slot " << slot;
+
+	return oss.str();
+}
+
 void StackWithBonuses::spendMana(ServerCallback * server, const int spellCost) const
 {
 	//TODO: evaluate cast use
@@ -284,7 +299,6 @@ int32_t HypotheticBattle::getActiveStackID() const
 void HypotheticBattle::nextRound(int32_t roundNr)
 {
 	//TODO:HypotheticBattle::nextRound
-
 	for(auto unit : battleAliveUnits())
 	{
 		auto forUpdate = getForUpdate(unit->unitId());

+ 1 - 0
AI/BattleAI/StackWithBonuses.h

@@ -85,6 +85,7 @@ public:
 	void removeUnitBonus(const CSelector & selector);
 
 	void spendMana(ServerCallback * server, const int spellCost) const override;
+	std::string getDescription() const override;
 
 private:
 	const IBonusBearer * origBearer;

+ 6 - 1
AI/Nullkiller/AIGateway.cpp

@@ -27,7 +27,9 @@
 namespace NKAI
 {
 
+// our to enemy strength ratio constants
 const float SAFE_ATTACK_CONSTANT = 1.2;
+const float RETREAT_THRESHOLD = 0.3;
 
 //one thread may be turn of AI and another will be handling a side effect for AI2
 boost::thread_specific_ptr<CCallback> cb;
@@ -202,6 +204,9 @@ void AIGateway::gameOver(PlayerColor player, const EVictoryLossCheckResult & vic
 			logAi->debug("AIGateway: Player %d (%s) lost. It's me. What a disappointment! :(", player, player.getStr());
 		}
 
+		// some whitespace to flush stream
+		logAi->debug(std::string(200, ' '));
+
 		finish();
 	}
 }
@@ -498,7 +503,7 @@ boost::optional<BattleAction> AIGateway::makeSurrenderRetreatDecision(
 	double fightRatio = battleState.getOurStrength() / (double)battleState.getEnemyStrength();
 
 	// if we have no towns - things are already bad, so retreat is not an option.
-	if(cb->getTownsInfo().size() && fightRatio < 0.3 && battleState.canFlee)
+	if(cb->getTownsInfo().size() && fightRatio < RETREAT_THRESHOLD && battleState.canFlee)
 	{
 		return BattleAction::makeRetreat(battleState.ourSide);
 	}

+ 5 - 4
AI/Nullkiller/Pathfinding/AINodeStorage.cpp

@@ -180,13 +180,14 @@ std::vector<CGPathNode *> AINodeStorage::getInitialNodes()
 	for(auto actorPtr : actors)
 	{
 		ChainActor * actor = actorPtr.get();
-		AIPathNode * initialNode =
-			getOrCreateNode(actor->initialPosition, actor->layer, actor)
-			.get();
 
-		if(!initialNode)
+		auto allocated = getOrCreateNode(actor->initialPosition, actor->layer, actor);
+
+		if(!allocated)
 			continue;
 
+		AIPathNode * initialNode = allocated.get();
+
 		initialNode->inPQ = false;
 		initialNode->pq = nullptr;
 		initialNode->turns = actor->initialTurn;

+ 6 - 1
lib/battle/ReachabilityInfo.cpp

@@ -70,7 +70,12 @@ int ReachabilityInfo::distToNearestNeighbour(
 	const battle::Unit * defender,
 	BattleHex * chosenHex) const
 {
-	auto attackableHexes = defender->getAttackableHexes(attacker);
+	auto attackableHexes = defender->getHexes();
+
+	if(attacker->doubleWide())
+	{
+		vstd::concatenate(attackableHexes, battle::Unit::getHexes(defender->occupiedHex(), true, attacker->unitSide()));
+	}
 
 	return distToNearestNeighbour(attackableHexes, chosenHex);
 }