Browse Source

Some changes to make the battle AI smarter

- the AI will now consider attacking multiple units
- the preferred strategy now is to minimize collateral damage rather than to maximize damage to enemy units alone
- attacks that block enemy shooters will be prioritized over other attacks in cases when shooters have weaker melee attacks
Victor Luchits 6 years ago
parent
commit
be10694b73

+ 187 - 96
AI/BattleAI/AttackPossibility.cpp

@@ -1,97 +1,188 @@
-/*
- * AttackPossibility.cpp, part of VCMI engine
- *
- * Authors: listed in file AUTHORS in main folder
- *
- * License: GNU General Public License v2.0 or later
- * Full text of license available in license.txt file, in main folder
- *
- */
-#include "StdInc.h"
-#include "AttackPossibility.h"
-
-AttackPossibility::AttackPossibility(BattleHex tile_, const BattleAttackInfo & attack_)
-	: tile(tile_),
-	attack(attack_)
-{
-}
-
-
-int64_t AttackPossibility::damageDiff() const
-{
-	//TODO: use target priority from HypotheticBattle
-	const auto dealtDmgValue = damageDealt;
-	const auto receivedDmgValue = damageReceived;
-
-	int64_t diff = 0;
-
-	//friendly fire or not
-	if(attack.attacker->unitSide() == attack.defender->unitSide())
-		diff = -dealtDmgValue - receivedDmgValue;
-	else
-		diff = dealtDmgValue - receivedDmgValue;
-
-	//mind control
-	auto actualSide = getCbc()->playerToSide(getCbc()->battleGetOwner(attack.attacker));
-	if(actualSide && actualSide.get() != attack.attacker->unitSide())
-		diff = -diff;
-	return diff;
-}
-
-int64_t AttackPossibility::attackValue() const
-{
-	return damageDiff() + tacticImpact;
-}
-
-AttackPossibility AttackPossibility::evaluate(const BattleAttackInfo & attackInfo, BattleHex hex)
-{
-	const std::string cachingStringBlocksRetaliation = "type_BLOCKS_RETALIATION";
-	static const auto selectorBlocksRetaliation = Selector::type(Bonus::BLOCKS_RETALIATION);
-
-	const bool counterAttacksBlocked = attackInfo.attacker->hasBonus(selectorBlocksRetaliation, cachingStringBlocksRetaliation);
-
-	AttackPossibility ap(hex, attackInfo);
-
-	ap.attackerState = attackInfo.attacker->acquireState();
-
-	const int totalAttacks = ap.attackerState->getTotalAttacks(attackInfo.shooting);
-
-	if(!attackInfo.shooting)
-		ap.attackerState->setPosition(hex);
-
-	auto defenderState = attackInfo.defender->acquireState();
-	ap.affectedUnits.push_back(defenderState);
-
-	for(int i = 0; i < totalAttacks; i++)
+/*
+ * AttackPossibility.cpp, part of VCMI engine
+ *
+ * Authors: listed in file AUTHORS in main folder
+ *
+ * License: GNU General Public License v2.0 or later
+ * Full text of license available in license.txt file, in main folder
+ *
+ */
+#include "StdInc.h"
+#include "AttackPossibility.h"
+#include "../../lib/CStack.h"//todo: remove
+
+AttackPossibility::AttackPossibility(BattleHex from, BattleHex dest, const BattleAttackInfo & attack)
+	: from(from), dest(dest), attack(attack)
+{
+}
+
+int64_t AttackPossibility::damageDiff() const
+{
+	return damageDealt - damageReceived - collateralDamage + shootersBlockedDmg;
+}
+
+int64_t AttackPossibility::attackValue() const
+{
+	return damageDiff();
+}
+
+int64_t AttackPossibility::evaluateBlockedShootersDmg(const BattleAttackInfo & attackInfo, BattleHex hex, const HypotheticBattle * state)
+{
+	int64_t res = 0;
+
+	if(attackInfo.shooting)
+		return 0;
+
+	auto attacker = attackInfo.attacker;
+	auto hexes = attacker->getSurroundingHexes(hex);
+	for(BattleHex tile : hexes)
+	{
+		auto st = state->battleGetUnitByPos(tile, true);
+		if(!st || !state->battleMatchOwner(st, attacker))
+			continue;
+		if(!state->battleCanShoot(st))
+			continue;
+
+		BattleAttackInfo rangeAttackInfo(st, attacker, true);
+		rangeAttackInfo.defenderPos = hex;
+
+		BattleAttackInfo meleeAttackInfo(st, attacker, false);
+		meleeAttackInfo.defenderPos = hex;
+
+		auto rangeDmg = getCbc()->battleEstimateDamage(rangeAttackInfo);
+		auto meleeDmg = getCbc()->battleEstimateDamage(meleeAttackInfo);
+
+		int64_t gain = (rangeDmg.first + rangeDmg.second - meleeDmg.first - meleeDmg.second) / 2 + 1;
+		res += gain;
+	}
+
+	return res;
+}
+
+AttackPossibility AttackPossibility::evaluate(const BattleAttackInfo & attackInfo, BattleHex hex, const HypotheticBattle * state)
+{
+	auto attacker = attackInfo.attacker;
+	auto defender = attackInfo.defender;
+	const std::string cachingStringBlocksRetaliation = "type_BLOCKS_RETALIATION";
+	static const auto selectorBlocksRetaliation = Selector::type(Bonus::BLOCKS_RETALIATION);
+	const bool counterAttacksBlocked = attacker->hasBonus(selectorBlocksRetaliation, cachingStringBlocksRetaliation);
+	const bool mindControlled = [&](const battle::Unit *attacker) -> bool
 	{
-		TDmgRange retaliation(0,0);
-		auto attackDmg = getCbc()->battleEstimateDamage(ap.attack, &retaliation);
-
-		vstd::amin(attackDmg.first, defenderState->getAvailableHealth());
-		vstd::amin(attackDmg.second, defenderState->getAvailableHealth());
-
-		vstd::amin(retaliation.first, ap.attackerState->getAvailableHealth());
-		vstd::amin(retaliation.second, ap.attackerState->getAvailableHealth());
-
-		ap.damageDealt += (attackDmg.first + attackDmg.second) / 2;
-
-		ap.attackerState->afterAttack(attackInfo.shooting, false);
-
-		//FIXME: use ranged retaliation
-		if(!attackInfo.shooting && defenderState->ableToRetaliate() && !counterAttacksBlocked)
-		{
-			ap.damageReceived += (retaliation.first + retaliation.second) / 2;
-			defenderState->afterAttack(attackInfo.shooting, true);
-		}
-
-		ap.attackerState->damage(ap.damageReceived);
-		defenderState->damage(ap.damageDealt);
-
-		if(!ap.attackerState->alive() || !defenderState->alive())
-			break;
-	}
-
-	//TODO other damage related to attack (eg. fire shield and other abilities)
-
-	return ap;
-}
+		auto actualSide = getCbc()->playerToSide(getCbc()->battleGetOwner(attacker));
+		if (actualSide && actualSide.get() != attacker->unitSide())
+			return true;
+		return false;
+	} (attacker);
+
+	AttackPossibility bestAp(hex, BattleHex::INVALID, attackInfo);
+
+	std::vector<BattleHex> defenderHex;
+	if(attackInfo.shooting) {
+		defenderHex = defender->getHexes();
+	} else {
+		defenderHex = CStack::meleeAttackHexes(attacker, defender, hex);
+	}
+
+	for(BattleHex defHex : defenderHex) {
+		if(defHex == hex) {
+			// should be impossible but check anyway
+			continue;
+		}
+
+		AttackPossibility ap(hex, defHex, attackInfo);
+		ap.attackerState = attacker->acquireState();
+		ap.shootersBlockedDmg = bestAp.shootersBlockedDmg;
+
+		const int totalAttacks = ap.attackerState->getTotalAttacks(attackInfo.shooting);
+
+		if (!attackInfo.shooting)
+			ap.attackerState->setPosition(hex);
+
+		std::vector<const battle::Unit*> units;
+
+		if (attackInfo.shooting)
+			units = state->getAttackedBattleUnits(attacker, defHex, true, BattleHex::INVALID);
+		else
+			units = state->getAttackedBattleUnits(attacker, defHex, false, hex);
+
+		// ensure the defender is also affected
+		bool addDefender = true;
+		for(auto unit : units) {
+			if (unit->unitId() == defender->unitId()) {
+				addDefender = false;
+				break;
+			}
+		}
+		if(addDefender) {
+			units.push_back(defender);
+		}
+
+		for(auto u : units) {
+			if(!ap.attackerState->alive()) {
+				break;
+			}
+
+			assert(u->unitId() != attacker->unitId());
+
+			auto defenderState = u->acquireState();
+			ap.affectedUnits.push_back(defenderState);
+
+			for(int i = 0; i < totalAttacks; i++) {
+				si64 damageDealt, damageReceived;
+
+				TDmgRange retaliation(0, 0);
+				auto attackDmg = getCbc()->battleEstimateDamage(ap.attack, &retaliation);
+
+				vstd::amin(attackDmg.first, defenderState->getAvailableHealth());
+				vstd::amin(attackDmg.second, defenderState->getAvailableHealth());
+
+				vstd::amin(retaliation.first, ap.attackerState->getAvailableHealth());
+				vstd::amin(retaliation.second, ap.attackerState->getAvailableHealth());
+
+				damageDealt = (attackDmg.first + attackDmg.second) / 2;
+				ap.attackerState->afterAttack(attackInfo.shooting, false);
+
+				//FIXME: use ranged retaliation
+				damageReceived = 0;
+				if (!attackInfo.shooting && defenderState->ableToRetaliate() && !counterAttacksBlocked)
+				{
+					damageReceived = (retaliation.first + retaliation.second) / 2;
+					defenderState->afterAttack(attackInfo.shooting, true);
+				}
+
+				bool isEnemy = state->battleMatchOwner(attacker, u) && !mindControlled;
+				if(isEnemy)
+					ap.damageDealt += damageDealt;
+				else // friendly fire
+					ap.collateralDamage += damageDealt;
+
+				if(u->unitId() == defender->unitId() || 
+					(!attackInfo.shooting && CStack::isMeleeAttackPossible(u, attacker, hex))) { //FIXME: handle RANGED_RETALIATION ?
+					ap.damageReceived += damageReceived;
+				}
+
+				ap.attackerState->damage(damageReceived);
+				defenderState->damage(damageDealt);
+
+				if (!ap.attackerState->alive() || !defenderState->alive())
+					break;
+			}
+		}
+
+		if(!bestAp.dest.isValid() || ap.attackValue() > bestAp.attackValue()) {
+			bestAp = ap;
+		}
+	}
+
+	// check how much damage we gain from blocking enemy shooters on this hex
+	bestAp.shootersBlockedDmg = evaluateBlockedShootersDmg(attackInfo, hex, state);
+
+	logAi->debug("BattleAI best AP: %s -> %s at %d from %d, affects %d units: %d %d %d %s",
+		VLC->creh->creatures.at(attackInfo.attacker->acquireState()->creatureId())->identifier.c_str(),
+		VLC->creh->creatures.at(attackInfo.defender->acquireState()->creatureId())->identifier.c_str(),
+		(int)bestAp.dest, (int)bestAp.from, (int)bestAp.affectedUnits.size(),
+		(int)bestAp.damageDealt, (int)bestAp.damageReceived, (int)bestAp.collateralDamage, (int)bestAp.shootersBlockedDmg);
+
+	//TODO other damage related to attack (eg. fire shield and other abilities)
+	return bestAp;
+}

+ 9 - 4
AI/BattleAI/AttackPossibility.h

@@ -16,7 +16,8 @@
 class AttackPossibility
 {
 public:
-	BattleHex tile; //tile from which we attack
+	BattleHex from; //tile from which we attack
+	BattleHex dest; //tile which we attack
 	BattleAttackInfo attack;
 
 	std::shared_ptr<battle::CUnitState> attackerState;
@@ -25,12 +26,16 @@ public:
 
 	int64_t damageDealt = 0;
 	int64_t damageReceived = 0; //usually by counter-attack
-	int64_t tacticImpact = 0;
+	int64_t collateralDamage = 0; // friendly fire (usually by two-hex attacks)
+	int64_t shootersBlockedDmg = 0;
 
-	AttackPossibility(BattleHex tile_, const BattleAttackInfo & attack_);
+	AttackPossibility(BattleHex from, BattleHex dest, const BattleAttackInfo & attack_);
 
 	int64_t damageDiff() const;
 	int64_t attackValue() const;
 
-	static AttackPossibility evaluate(const BattleAttackInfo & attackInfo, BattleHex hex);
+	static AttackPossibility evaluate(const BattleAttackInfo & attackInfo, BattleHex hex, const HypotheticBattle * state);
+
+private:
+	static int64_t evaluateBlockedShootersDmg(const BattleAttackInfo & attackInfo, BattleHex hex, const HypotheticBattle * state);
 };

+ 14 - 3
AI/BattleAI/BattleAI.cpp

@@ -154,7 +154,8 @@ BattleAction CBattleAI::activeStack( const CStack * stack )
 		HypotheticBattle hb(getCbc());
 
 		PotentialTargets targets(stack, &hb);
-		if(targets.possibleAttacks.size())
+
+		if(!targets.possibleAttacks.empty())
 		{
 			AttackPossibility bestAttack = targets.bestAction();
 
@@ -163,8 +164,18 @@ BattleAction CBattleAI::activeStack( const CStack * stack )
 				return BattleAction::makeCreatureSpellcast(stack, bestSpellcast->dest, bestSpellcast->spell->id);
 			else if(bestAttack.attack.shooting)
 				return BattleAction::makeShotAttack(stack, bestAttack.attack.defender);
-			else
-				return BattleAction::makeMeleeAttack(stack, bestAttack.attack.defender->getPosition(), bestAttack.tile);
+			else {
+				auto &target = bestAttack;
+				logAi->debug("BattleAI: %s -> %s %d from, %d curpos %d dist %d speed %d: %d %d %d",
+					VLC->creh->creatures.at(target.attackerState->creatureId())->identifier.c_str(),
+					VLC->creh->creatures.at(target.affectedUnits[0]->creatureId())->identifier.c_str(),
+					(int)target.affectedUnits.size(), (int)target.from, (int)bestAttack.attack.attacker->getPosition().hex,
+					(int)bestAttack.attack.chargedFields, (int)bestAttack.attack.attacker->Speed(0, true),
+					(int)target.damageDealt, (int)target.damageReceived, (int)target.attackValue()
+				);
+
+				return BattleAction::makeMeleeAttack(stack,	bestAttack.attack.defender->getPosition(), bestAttack.from);
+			}
 		}
 		else if(bestSpellcast.is_initialized())
 		{

+ 32 - 8
AI/BattleAI/PotentialTargets.cpp

@@ -53,7 +53,7 @@ PotentialTargets::PotentialTargets(const battle::Unit * attacker, const Hypothet
 			if(hex.isValid() && !shooting)
 				bai.chargedFields = reachability.distances[hex];
 
-			return AttackPossibility::evaluate(bai, hex);
+			return AttackPossibility::evaluate(bai, hex, state);
 		};
 
 		if(forceTarget)
@@ -69,21 +69,45 @@ PotentialTargets::PotentialTargets(const battle::Unit * attacker, const Hypothet
 		}
 		else
 		{
-			for(BattleHex hex : avHexes)
-				if(CStack::isMeleeAttackPossible(attackerInfo, defender, hex))
-					possibleAttacks.push_back(GenerateAttackInfo(false, hex));
+			for(BattleHex hex : avHexes) {
+				if(!CStack::isMeleeAttackPossible(attackerInfo, defender, hex))
+					continue;
+
+				auto bai = GenerateAttackInfo(false, hex);
+				if(!bai.affectedUnits.empty())
+					possibleAttacks.push_back(bai);
+			}
 
 			if(!vstd::contains_if(possibleAttacks, [=](const AttackPossibility & pa) { return pa.attack.defender->unitId() == defender->unitId(); }))
 				unreachableEnemies.push_back(defender);
 		}
 	}
+
+	boost::sort(possibleAttacks, [](const AttackPossibility & lhs, const AttackPossibility & rhs) -> bool
+	{
+		if(lhs.collateralDamage < rhs.collateralDamage)
+			return false;
+		if(lhs.collateralDamage > rhs.collateralDamage)
+			return true;
+		return (lhs.damageDealt + lhs.shootersBlockedDmg + lhs.damageReceived > rhs.damageDealt + rhs.shootersBlockedDmg + rhs.damageReceived);
+	});
+
+	if (!possibleAttacks.empty())
+	{
+		auto &bestAp = possibleAttacks[0];
+
+		logGlobal->info("Battle AI best: %s -> %s at %d from %d, affects %d units: %d %d %d %s",
+			VLC->creh->creatures.at(bestAp.attackerState->creatureId())->identifier.c_str(),
+			VLC->creh->creatures.at(state->battleGetUnitByPos(bestAp.dest)->creatureId())->identifier.c_str(),
+			(int)bestAp.dest, (int)bestAp.from, (int)bestAp.affectedUnits.size(),
+			(int)bestAp.damageDealt, (int)bestAp.damageReceived, (int)bestAp.collateralDamage, (int)bestAp.shootersBlockedDmg);
+	}
 }
 
-int PotentialTargets::bestActionValue() const
+int64_t PotentialTargets::bestActionValue() const
 {
 	if(possibleAttacks.empty())
 		return 0;
-
 	return bestAction().attackValue();
 }
 
@@ -91,6 +115,6 @@ AttackPossibility PotentialTargets::bestAction() const
 {
 	if(possibleAttacks.empty())
 		throw std::runtime_error("No best action, since we don't have any actions");
-
-	return *vstd::maxElementByFun(possibleAttacks, [](const AttackPossibility &ap) { return ap.attackValue(); } );
+	return possibleAttacks[0];
+	//return *vstd::maxElementByFun(possibleAttacks, [](const AttackPossibility &ap) { return ap.attackValue(); } );
 }

+ 1 - 1
AI/BattleAI/PotentialTargets.h

@@ -20,5 +20,5 @@ public:
 	PotentialTargets(const battle::Unit * attacker, const HypotheticBattle * state);
 
 	AttackPossibility bestAction() const;
-	int bestActionValue() const;
+	int64_t bestActionValue() const;
 };

+ 42 - 11
lib/CStack.cpp

@@ -252,22 +252,53 @@ void CStack::prepareAttacked(BattleStackAttacked & bsa, vstd::RNG & rand, std::s
 	bsa.newState.operation = UnitChanges::EOperation::RESET_STATE;
 }
 
-bool CStack::isMeleeAttackPossible(const battle::Unit * attacker, const battle::Unit * defender, BattleHex attackerPos, BattleHex defenderPos)
+std::vector<BattleHex> CStack::meleeAttackHexes(const battle::Unit * attacker, const battle::Unit * defender, BattleHex attackerPos, BattleHex defenderPos)
 {
-	if(!attackerPos.isValid())
+	int mask = 0;
+	std::vector<BattleHex> res;
+
+	if (!attackerPos.isValid())
 		attackerPos = attacker->getPosition();
-	if(!defenderPos.isValid())
+	if (!defenderPos.isValid())
 		defenderPos = defender->getPosition();
 
-	return
-		(BattleHex::mutualPosition(attackerPos, defenderPos) >= 0)//front <=> front
-		|| (attacker->doubleWide()//back <=> front
-			&& BattleHex::mutualPosition(attackerPos + (attacker->unitSide() == BattleSide::ATTACKER ? -1 : 1), defenderPos) >= 0)
-		|| (defender->doubleWide()//front <=> back
-			&& BattleHex::mutualPosition(attackerPos, defenderPos + (defender->unitSide() == BattleSide::ATTACKER ? -1 : 1)) >= 0)
-		|| (defender->doubleWide() && attacker->doubleWide()//back <=> back
-			&& BattleHex::mutualPosition(attackerPos + (attacker->unitSide() == BattleSide::ATTACKER ? -1 : 1), defenderPos + (defender->unitSide() == BattleSide::ATTACKER ? -1 : 1)) >= 0);
+	BattleHex otherAttackerPos = attackerPos + (attacker->unitSide() == BattleSide::ATTACKER ? -1 : 1);
+	BattleHex otherDefenderPos = defenderPos + (defender->unitSide() == BattleSide::ATTACKER ? -1 : 1);
 
+	if(BattleHex::mutualPosition(attackerPos, defenderPos) >= 0) { //front <=> front
+		if((mask & 1) == 0) {
+			mask |= 1;
+			res.push_back(defenderPos);
+		}
+	}
+	if (attacker->doubleWide() //back <=> front
+		&& BattleHex::mutualPosition(otherAttackerPos, defenderPos) >= 0) {
+		if((mask & 1) == 0) {
+			mask |= 1;
+			res.push_back(defenderPos);
+		}
+	}
+	if (defender->doubleWide()//front <=> back
+		&& BattleHex::mutualPosition(attackerPos, otherDefenderPos) >= 0) {
+		if((mask & 2) == 0) {
+			mask |= 2;
+			res.push_back(otherDefenderPos);
+		}
+	}
+	if (defender->doubleWide() && attacker->doubleWide()//back <=> back
+		&& BattleHex::mutualPosition(otherAttackerPos, otherDefenderPos) >= 0) {
+		if((mask & 2) == 0) {
+			mask |= 2;
+			res.push_back(otherDefenderPos);
+		}
+	}
+
+	return res;
+}
+
+bool CStack::isMeleeAttackPossible(const battle::Unit * attacker, const battle::Unit * defender, BattleHex attackerPos, BattleHex defenderPos)
+{
+	return !meleeAttackHexes(attacker, defender, attackerPos, defenderPos).empty();
 }
 
 std::string CStack::getName() const

+ 1 - 0
lib/CStack.h

@@ -55,6 +55,7 @@ public:
 	std::vector<si32> activeSpells() const; //returns vector of active spell IDs sorted by time of cast
 	const CGHeroInstance * getMyHero() const; //if stack belongs to hero (directly or was by him summoned) returns hero, nullptr otherwise
 
+	static std::vector<BattleHex> meleeAttackHexes(const battle::Unit * attacker, const battle::Unit * defender, BattleHex attackerPos = BattleHex::INVALID, BattleHex defenderPos = BattleHex::INVALID);
 	static bool isMeleeAttackPossible(const battle::Unit * attacker, const battle::Unit * defender, BattleHex attackerPos = BattleHex::INVALID, BattleHex defenderPos = BattleHex::INVALID);
 
 	BattleHex::EDir destShiftDir() const;

+ 3 - 0
lib/battle/BattleAttackInfo.cpp

@@ -19,6 +19,8 @@ BattleAttackInfo::BattleAttackInfo(const battle::Unit * Attacker, const battle::
 	chargedFields = 0;
 	additiveBonus = 0.0;
 	multBonus = 1.0;
+	attackerPos = BattleHex::INVALID;
+	defenderPos = BattleHex::INVALID;
 }
 
 BattleAttackInfo BattleAttackInfo::reverse() const
@@ -26,6 +28,7 @@ BattleAttackInfo BattleAttackInfo::reverse() const
 	BattleAttackInfo ret = *this;
 
 	std::swap(ret.attacker, ret.defender);
+	std::swap(ret.defenderPos, ret.attackerPos);
 
 	ret.shooting = false;
 	ret.chargedFields = 0;

+ 5 - 0
lib/battle/BattleAttackInfo.h

@@ -15,11 +15,16 @@ namespace battle
 	class CUnitState;
 }
 
+#include "BattleHex.h"
+
 struct DLL_LINKAGE BattleAttackInfo
 {
 	const battle::Unit * attacker;
 	const battle::Unit * defender;
 
+	BattleHex attackerPos;
+	BattleHex defenderPos;
+
 	bool shooting;
 	int chargedFields;
 

+ 94 - 47
lib/battle/CBattleInfoCallback.cpp

@@ -667,37 +667,45 @@ bool CBattleInfoCallback::battleCanAttack(const CStack * stack, const CStack * t
 	return target->alive();
 }
 
-bool CBattleInfoCallback::battleCanShoot(const battle::Unit * attacker, BattleHex dest) const
+bool CBattleInfoCallback::battleCanShoot(const battle::Unit * attacker) const
 {
 	RETURN_IF_NOT_BATTLE(false);
 
-	if(battleTacticDist()) //no shooting during tactics
+	if (battleTacticDist()) //no shooting during tactics
 		return false;
 
-	const battle::Unit * defender = battleGetUnitByPos(dest);
-
-	if(!attacker || !defender)
+	if (!attacker)
+		return false;
+	if (attacker->creatureIndex() == CreatureID::CATAPULT) //catapult cannot attack creatures
 		return false;
 
 	//forgetfulness
 	TBonusListPtr forgetfulList = attacker->getBonuses(Selector::type(Bonus::FORGETFULL));
-	if(!forgetfulList->empty())
+	if (!forgetfulList->empty())
 	{
 		int forgetful = forgetfulList->valOfBonuses(Selector::type(Bonus::FORGETFULL));
 
 		//advanced+ level
-		if(forgetful > 1)
+		if (forgetful > 1)
 			return false;
 	}
 
-	if(attacker->creatureIndex() == CreatureID::CATAPULT && defender) //catapult cannot attack creatures
+	return attacker->canShoot()	&& (!battleIsUnitBlocked(attacker)
+			|| attacker->hasBonusOfType(Bonus::FREE_SHOOTING));
+}
+
+bool CBattleInfoCallback::battleCanShoot(const battle::Unit * attacker, BattleHex dest) const
+{
+	RETURN_IF_NOT_BATTLE(false);
+
+	const battle::Unit * defender = battleGetUnitByPos(dest);
+	if(!attacker || !defender)
 		return false;
 
-	return attacker->canShoot()
-		&& battleMatchOwner(attacker, defender)
-		&& defender->alive()
-		&& (!battleIsUnitBlocked(attacker)
-		|| attacker->hasBonusOfType(Bonus::FREE_SHOOTING));
+	if(battleMatchOwner(attacker, defender) && defender->alive())
+		return battleCanShoot(attacker);
+
+	return false;
 }
 
 TDmgRange CBattleInfoCallback::calculateDmgRange(const BattleAttackInfo & info) const
@@ -897,8 +905,11 @@ TDmgRange CBattleInfoCallback::calculateDmgRange(const BattleAttackInfo & info)
 	if(info.shooting)
 	{
 		//wall / distance penalty + advanced air shield
-		const bool distPenalty = battleHasDistancePenalty(attackerBonuses, info.attacker->getPosition(), info.defender->getPosition());
-		const bool obstaclePenalty = battleHasWallPenalty(attackerBonuses, info.attacker->getPosition(), info.defender->getPosition());
+		BattleHex attackerPos = info.attackerPos.isValid() ? info.attackerPos : info.attacker->getPosition();
+		BattleHex defenderPos = info.defenderPos.isValid() ? info.defenderPos : info.defender->getPosition();
+
+		const bool distPenalty = battleHasDistancePenalty(attackerBonuses, attackerPos, defenderPos);
+		const bool obstaclePenalty = battleHasWallPenalty(attackerBonuses, attackerPos, defenderPos);
 
 		if(distPenalty || defenderBonuses->hasBonus(isAdvancedAirShield, cachingStrAdvAirShield))
 			multBonus *= 0.5;
@@ -1340,11 +1351,11 @@ ReachabilityInfo CBattleInfoCallback::getFlyingReachability(const ReachabilityIn
 	return ret;
 }
 
-AttackableTiles CBattleInfoCallback::getPotentiallyAttackableHexes (const CStack* attacker, BattleHex destinationTile, BattleHex attackerPos) const
+AttackableTiles CBattleInfoCallback::getPotentiallyAttackableHexes (const  battle::Unit* attacker, BattleHex destinationTile, BattleHex attackerPos) const
 {
 	//does not return hex attacked directly
 	//TODO: apply rotation to two-hex attackers
-	bool isAttacker = attacker->side == BattleSide::ATTACKER;
+	bool isAttacker = attacker->unitSide() == BattleSide::ATTACKER;
 
 	AttackableTiles at;
 	RETURN_IF_NOT_BATTLE(at);
@@ -1369,8 +1380,8 @@ AttackableTiles CBattleInfoCallback::getPotentiallyAttackableHexes (const CStack
 		{
 			if((BattleHex::mutualPosition(tile, destinationTile) > -1 && BattleHex::mutualPosition(tile, hex) > -1)) //adjacent both to attacker's head and attacked tile
 			{
-				const CStack * st = battleGetStackByPos(tile, true);
-				if(st && st->owner != attacker->owner) //only hostile stacks - does it work well with Berserk?
+				auto st = battleGetUnitByPos(tile, true);
+				if(st && battleMatchOwner(st, attacker)) //only hostile stacks - does it work well with Berserk?
 				{
 					at.hostileCreaturePositions.insert(tile);
 				}
@@ -1391,45 +1402,50 @@ AttackableTiles CBattleInfoCallback::getPotentiallyAttackableHexes (const CStack
 		for(BattleHex tile : hexes)
 		{
 			//friendly stacks can also be damaged by Dragon Breath
-			if(battleGetStackByPos(tile, true))
+			auto st = battleGetUnitByPos(tile, true);
+			if(st && st != attacker)
 			{
-				if(battleGetStackByPos(tile, true) != attacker)
-					at.friendlyCreaturePositions.insert(tile);
+				at.friendlyCreaturePositions.insert(tile);
 			}
 		}
 	}
-	else if(attacker->hasBonusOfType(Bonus::TWO_HEX_ATTACK_BREATH) && BattleHex::mutualPosition(destinationTile, hex) > -1) //only adjacent hexes are subject of dragon breath calculation
+	else if(attacker->hasBonusOfType(Bonus::TWO_HEX_ATTACK_BREATH))
 	{
-		std::vector<BattleHex> hexes; //only one, in fact
-		int pseudoVector = destinationTile.hex - hex;
-		switch(pseudoVector)
+		int pos = BattleHex::mutualPosition(destinationTile, hex);
+		if (pos > -1) //only adjacent hexes are subject of dragon breath calculation
 		{
-		case 1:
-		case -1:
-			BattleHex::checkAndPush(destinationTile.hex + pseudoVector, hexes);
-			break;
-		case WN: //17 //left-down or right-down
-		case -WN: //-17 //left-up or right-up
-		case WN + 1: //18 //right-down
-		case -WN + 1: //-16 //right-up
-			BattleHex::checkAndPush(destinationTile.hex + pseudoVector + (((hex / WN) % 2) ? 1 : -1), hexes);
-			break;
-		case WN - 1: //16 //left-down
-		case -WN - 1: //-18 //left-up
-			BattleHex::checkAndPush(destinationTile.hex + pseudoVector + (((hex / WN) % 2) ? 1 : 0), hexes);
-			break;
-		}
-		for(BattleHex tile : hexes)
-		{
-			//friendly stacks can also be damaged by Dragon Breath
-			if(battleGetStackByPos(tile, true))
-				at.friendlyCreaturePositions.insert(tile);
+			std::vector<BattleHex> hexes; //only one, in fact
+			int pseudoVector = destinationTile.hex - hex;
+			switch (pseudoVector)
+			{
+			case 1:
+			case -1:
+				BattleHex::checkAndPush(destinationTile.hex + pseudoVector, hexes);
+				break;
+			case WN: //17 //left-down or right-down
+			case -WN: //-17 //left-up or right-up
+			case WN + 1: //18 //right-down
+			case -WN + 1: //-16 //right-up
+				BattleHex::checkAndPush(destinationTile.hex + pseudoVector + (((hex / WN) % 2) ? 1 : -1), hexes);
+				break;
+			case WN - 1: //16 //left-down
+			case -WN - 1: //-18 //left-up
+				BattleHex::checkAndPush(destinationTile.hex + pseudoVector + (((hex / WN) % 2) ? 1 : 0), hexes);
+				break;
+			}
+			for (BattleHex tile : hexes)
+			{
+				//friendly stacks can also be damaged by Dragon Breath
+				auto st = battleGetUnitByPos(tile, true);
+				if (st != nullptr)
+					at.friendlyCreaturePositions.insert(tile);
+			}
 		}
 	}
 	return at;
 }
 
-AttackableTiles CBattleInfoCallback::getPotentiallyShootableHexes(const CStack * attacker, BattleHex destinationTile, BattleHex attackerPos) const
+AttackableTiles CBattleInfoCallback::getPotentiallyShootableHexes(const  battle::Unit * attacker, BattleHex destinationTile, BattleHex attackerPos) const
 {
 	//does not return hex attacked directly
 	AttackableTiles at;
@@ -1445,6 +1461,37 @@ AttackableTiles CBattleInfoCallback::getPotentiallyShootableHexes(const CStack *
 	return at;
 }
 
+std::vector<const battle::Unit*> CBattleInfoCallback::getAttackedBattleUnits(const battle::Unit* attacker, BattleHex destinationTile, bool rangedAttack, BattleHex attackerPos) const
+{
+	std::vector<const battle::Unit*> units;
+	RETURN_IF_NOT_BATTLE(units);
+
+	AttackableTiles at;
+
+	if (rangedAttack)
+		at = getPotentiallyShootableHexes(attacker, destinationTile, attackerPos);
+	else
+		at = getPotentiallyAttackableHexes(attacker, destinationTile, attackerPos);
+
+	units = battleGetUnitsIf([=](const battle::Unit * unit)
+	{
+		if (unit->isGhost() || !unit->alive()) {
+			return false;
+		}
+		for (BattleHex hex : battle::Unit::getHexes(unit->getPosition(), unit->doubleWide(), unit->unitSide())) {
+			if (vstd::contains(at.hostileCreaturePositions, hex)) {
+				return true;
+			}
+			if (vstd::contains(at.friendlyCreaturePositions, hex)) {
+				return true;
+			}
+		}
+		return false;
+	});
+
+	return units;
+}
+
 std::set<const CStack*> CBattleInfoCallback::getAttackedCreatures(const CStack* attacker, BattleHex destinationTile, bool rangedAttack, BattleHex attackerPos) const
 {
 	std::set<const CStack*> attackedCres;

+ 4 - 2
lib/battle/CBattleInfoCallback.h

@@ -89,6 +89,7 @@ public:
 
 	bool battleCanAttack(const CStack * stack, const CStack * target, BattleHex dest) const; //determines if stack with given ID can attack target at the selected destination
 	bool battleCanShoot(const battle::Unit * attacker, BattleHex dest) const; //determines if stack with given ID shoot at the selected destination
+	bool battleCanShoot(const battle::Unit * attacker) const; //determines if stack with given ID shoot in principle
 	bool battleIsUnitBlocked(const battle::Unit * unit) const; //returns true if there is neighboring enemy stack
 	std::set<const battle::Unit *> battleAdjacentUnits(const battle::Unit * unit) const;
 
@@ -123,8 +124,9 @@ public:
 	bool isInTacticRange(BattleHex dest) const;
 	si8 battleGetTacticDist() const; //returns tactic distance for calling player or 0 if this player is not in tactic phase (for ALL_KNOWING actual distance for tactic side)
 
-	AttackableTiles getPotentiallyAttackableHexes(const CStack* attacker, BattleHex destinationTile, BattleHex attackerPos) const; //TODO: apply rotation to two-hex attacker
-	AttackableTiles getPotentiallyShootableHexes(const CStack* attacker, BattleHex destinationTile, BattleHex attackerPos) const;
+	AttackableTiles getPotentiallyAttackableHexes(const  battle::Unit* attacker, BattleHex destinationTile, BattleHex attackerPos) const; //TODO: apply rotation to two-hex attacker
+	AttackableTiles getPotentiallyShootableHexes(const  battle::Unit* attacker, BattleHex destinationTile, BattleHex attackerPos) const;
+	std::vector<const battle::Unit *> getAttackedBattleUnits(const battle::Unit* attacker, BattleHex destinationTile, bool rangedAttack, BattleHex attackerPos = BattleHex::INVALID) const; //calculates range of multi-hex attacks
 	std::set<const CStack*> getAttackedCreatures(const CStack* attacker, BattleHex destinationTile, bool rangedAttack, BattleHex attackerPos = BattleHex::INVALID) const; //calculates range of multi-hex attacks
 	bool isToReverse(BattleHex hexFrom, BattleHex hexTo, bool curDir /*if true, creature is in attacker's direction*/, bool toDoubleWide, bool toDir) const; //determines if creature should be reversed (it stands on hexFrom and should 'see' hexTo)
 	bool isToReverseHlp(BattleHex hexFrom, BattleHex hexTo, bool curDir) const; //helper for isToReverse