浏览代码

Merge pull request #4323 from vcmi/battle-ai-fixes

Battle ai fixes
Andrii Danylchenko 1 年之前
父节点
当前提交
8e79263b21

+ 86 - 33
AI/BattleAI/AttackPossibility.cpp

@@ -93,6 +93,8 @@ int64_t DamageCache::getOriginalDamage(const battle::Unit * attacker, const batt
 AttackPossibility::AttackPossibility(BattleHex from, BattleHex dest, const BattleAttackInfo & attack)
 	: from(from), dest(dest), attack(attack)
 {
+	this->attack.attackerPos = from;
+	this->attack.defenderPos = dest;
 }
 
 float AttackPossibility::damageDiff() const
@@ -261,63 +263,105 @@ AttackPossibility AttackPossibility::evaluate(
 		if (!attackInfo.shooting)
 			ap.attackerState->setPosition(hex);
 
-		std::vector<const battle::Unit*> units;
+		std::vector<const battle::Unit *> defenderUnits;
+		std::vector<const battle::Unit *> retaliatedUnits = {attacker};
+		std::vector<const battle::Unit *> affectedUnits;
 
 		if (attackInfo.shooting)
-			units = state->getAttackedBattleUnits(attacker, defHex, true, BattleHex::INVALID);
+			defenderUnits = state->getAttackedBattleUnits(attacker, defender, defHex, true, hex, defender->getPosition());
 		else
-			units = state->getAttackedBattleUnits(attacker, defHex, false, hex);
-
-		// ensure the defender is also affected
-		bool addDefender = true;
-		for(auto unit : units)
 		{
-			if (unit->unitId() == defender->unitId())
+			defenderUnits = state->getAttackedBattleUnits(attacker, defender, defHex, false, hex, defender->getPosition());
+			retaliatedUnits = state->getAttackedBattleUnits(defender, attacker, hex, false, defender->getPosition(), hex);
+
+			// attacker can not melle-attack itself but still can hit that place where it was before moving
+			vstd::erase_if(defenderUnits, [attacker](const battle::Unit * u) -> bool { return u->unitId() == attacker->unitId(); });
+
+			if(!vstd::contains_if(retaliatedUnits, [attacker](const battle::Unit * u) -> bool { return u->unitId() == attacker->unitId(); }))
 			{
-				addDefender = false;
-				break;
+				retaliatedUnits.push_back(attacker);
 			}
 		}
 
-		if(addDefender)
-			units.push_back(defender);
+		// ensure the defender is also affected
+		if(!vstd::contains_if(defenderUnits, [defender](const battle::Unit * u) -> bool { return u->unitId() == defender->unitId(); }))
+		{
+			defenderUnits.push_back(defender);
+		}
+
+		affectedUnits = defenderUnits;
+		vstd::concatenate(affectedUnits, retaliatedUnits);
+
+		logAi->trace("Attacked battle units count %d, %d->%d", affectedUnits.size(), hex.hex, defHex.hex);
 
-		for(auto u : units)
+		std::map<uint32_t, std::shared_ptr<battle::CUnitState>> defenderStates;
+
+		for(auto u : affectedUnits)
 		{
-			if(!ap.attackerState->alive())
-				break;
+			if(u->unitId() == attacker->unitId())
+				continue;
 
 			auto defenderState = u->acquireState();
+
 			ap.affectedUnits.push_back(defenderState);
+			defenderStates[u->unitId()] = defenderState;
+		}
+
+		for(int i = 0; i < totalAttacks; i++)
+		{
+			if(!ap.attackerState->alive() || !defenderStates[defender->unitId()]->alive())
+				break;
 
-			for(int i = 0; i < totalAttacks; i++)
+			for(auto u : defenderUnits)
 			{
+				auto defenderState = defenderStates.at(u->unitId());
+
 				int64_t damageDealt;
-				int64_t damageReceived;
 				float defenderDamageReduce;
 				float attackerDamageReduce;
 
 				DamageEstimation retaliation;
 				auto attackDmg = state->battleEstimateDamage(ap.attack, &retaliation);
 
-				vstd::amin(attackDmg.damage.min, defenderState->getAvailableHealth());
-				vstd::amin(attackDmg.damage.max, defenderState->getAvailableHealth());
-
-				vstd::amin(retaliation.damage.min, ap.attackerState->getAvailableHealth());
-				vstd::amin(retaliation.damage.max, ap.attackerState->getAvailableHealth());
-
 				damageDealt = averageDmg(attackDmg.damage);
-				defenderDamageReduce = calculateDamageReduce(attacker, defender, damageDealt, damageCache, state);
+				vstd::amin(damageDealt, defenderState->getAvailableHealth());
+
+				defenderDamageReduce = calculateDamageReduce(attacker, u, damageDealt, damageCache, state);
 				ap.attackerState->afterAttack(attackInfo.shooting, false);
 
 				//FIXME: use ranged retaliation
-				damageReceived = 0;
 				attackerDamageReduce = 0;
 
-				if (!attackInfo.shooting && defenderState->ableToRetaliate() && !counterAttacksBlocked)
+				if (!attackInfo.shooting && u->unitId() == defender->unitId() && defenderState->ableToRetaliate() && !counterAttacksBlocked)
 				{
-					damageReceived = averageDmg(retaliation.damage);
-					attackerDamageReduce = calculateDamageReduce(defender, attacker, damageReceived, damageCache, state);
+					for(auto retaliated : retaliatedUnits)
+					{
+						if(retaliated->unitId() == attacker->unitId())
+						{
+							int64_t damageReceived = averageDmg(retaliation.damage);
+
+							vstd::amin(damageReceived, ap.attackerState->getAvailableHealth());
+
+							attackerDamageReduce = calculateDamageReduce(defender, retaliated, damageReceived, damageCache, state);
+							ap.attackerState->damage(damageReceived);
+						}
+						else
+						{
+							auto retaliationCollateral = state->battleEstimateDamage(defender, retaliated, 0);
+							int64_t damageReceived = averageDmg(retaliationCollateral.damage);
+
+							vstd::amin(damageReceived, retaliated->getAvailableHealth());
+
+							if(defender->unitSide() == retaliated->unitSide())
+								defenderDamageReduce += calculateDamageReduce(defender, retaliated, damageReceived, damageCache, state);
+							else
+								ap.collateralDamageReduce += calculateDamageReduce(defender, retaliated, damageReceived, damageCache, state);
+
+							defenderStates.at(retaliated->unitId())->damage(damageReceived);
+						}
+						
+					}
+
 					defenderState->afterAttack(attackInfo.shooting, true);
 				}
 
@@ -331,21 +375,30 @@ AttackPossibility AttackPossibility::evaluate(
 				if(attackerSide == u->unitSide())
 					ap.collateralDamageReduce += defenderDamageReduce;
 
-				if(u->unitId() == defender->unitId() || 
-					(!attackInfo.shooting && CStack::isMeleeAttackPossible(u, attacker, hex)))
+				if(u->unitId() == defender->unitId()
+					|| (!attackInfo.shooting && CStack::isMeleeAttackPossible(u, attacker, hex)))
 				{
 					//FIXME: handle RANGED_RETALIATION ?
 					ap.attackerDamageReduce += attackerDamageReduce;
 				}
 
-				ap.attackerState->damage(damageReceived);
 				defenderState->damage(damageDealt);
 
-				if (!ap.attackerState->alive() || !defenderState->alive())
-					break;
+				if(u->unitId() == defender->unitId())
+				{
+					ap.defenderDead = !defenderState->alive();
+				}
 			}
 		}
 
+#if BATTLE_TRACE_LEVEL>=2
+		logAi->trace("BattleAI AP: %s -> %s at %d from %d, affects %d units: d:%lld a:%lld c:%lld s:%lld",
+			attackInfo.attacker->unitType()->getJsonKey(),
+			attackInfo.defender->unitType()->getJsonKey(),
+			(int)ap.dest, (int)ap.from, (int)ap.affectedUnits.size(),
+			ap.defenderDamageReduce, ap.attackerDamageReduce, ap.collateralDamageReduce, ap.shootersBlockedDmg);
+#endif
+
 		if(!bestAp.dest.isValid() || ap.attackValue() > bestAp.attackValue())
 			bestAp = ap;
 	}

+ 1 - 0
AI/BattleAI/AttackPossibility.h

@@ -49,6 +49,7 @@ public:
 	float attackerDamageReduce = 0; //usually by counter-attack
 	float collateralDamageReduce = 0; // friendly fire (usually by two-hex attacks)
 	int64_t shootersBlockedDmg = 0;
+	bool defenderDead = false;
 
 	AttackPossibility(BattleHex from, BattleHex dest, const BattleAttackInfo & attack_);
 

+ 1 - 1
AI/BattleAI/BattleEvaluator.cpp

@@ -189,7 +189,7 @@ BattleAction BattleEvaluator::selectStackAction(const CStack * stack)
 					else
 					{
 						activeActionMade = true;
-						return BattleAction::makeMeleeAttack(stack, bestAttack.attack.defender->getPosition(), bestAttack.from);
+						return BattleAction::makeMeleeAttack(stack, bestAttack.attack.defenderPos, bestAttack.from);
 					}
 				}
 			}

+ 51 - 59
AI/BattleAI/BattleExchangeVariant.cpp

@@ -30,100 +30,89 @@ float BattleExchangeVariant::trackAttack(
 {
 	auto attacker = hb->getForUpdate(ap.attack.attacker->unitId());
 
-	const std::string cachingStringBlocksRetaliation = "type_BLOCKS_RETALIATION";
-	static const auto selectorBlocksRetaliation = Selector::type()(BonusType::BLOCKS_RETALIATION);
-	const bool counterAttacksBlocked = attacker->hasBonus(selectorBlocksRetaliation, cachingStringBlocksRetaliation);
-
-	float attackValue = 0;
+	float attackValue = ap.attackValue();
 	auto affectedUnits = ap.affectedUnits;
 
+	dpsScore.ourDamageReduce += ap.attackerDamageReduce + ap.collateralDamageReduce;
+	dpsScore.enemyDamageReduce += ap.defenderDamageReduce + ap.shootersBlockedDmg;
+	attackerValue[attacker->unitId()].value = attackValue;
+
 	affectedUnits.push_back(ap.attackerState);
 
 	for(auto affectedUnit : affectedUnits)
 	{
 		auto unitToUpdate = hb->getForUpdate(affectedUnit->unitId());
+		auto damageDealt = unitToUpdate->getTotalHealth() - affectedUnit->getTotalHealth();
+
+		if(damageDealt > 0)
+		{
+			unitToUpdate->damage(damageDealt);
+		}
 
 		if(unitToUpdate->unitSide() == attacker->unitSide())
 		{
 			if(unitToUpdate->unitId() == attacker->unitId())
 			{
-				auto defender = hb->getForUpdate(ap.attack.defender->unitId());
-
-				if(!defender->alive() || counterAttacksBlocked || ap.attack.shooting || !defender->ableToRetaliate())
-					continue;
-
-				auto retaliationDamage = damageCache.getDamage(defender.get(), unitToUpdate.get(), hb);
-				auto attackerDamageReduce = AttackPossibility::calculateDamageReduce(defender.get(), unitToUpdate.get(), retaliationDamage, damageCache, hb);
-
-				attackValue -= attackerDamageReduce;
-				dpsScore.ourDamageReduce += attackerDamageReduce;
-				attackerValue[unitToUpdate->unitId()].isRetaliated = true;
-
-				unitToUpdate->damage(retaliationDamage);
-				defender->afterAttack(false, true);
+				unitToUpdate->afterAttack(ap.attack.shooting, false);
 
 #if BATTLE_TRACE_LEVEL>=1
 				logAi->trace(
-					"%s -> %s, ap retaliation, %s, dps: %2f, score: %2f",
-					defender->getDescription(),
-					unitToUpdate->getDescription(),
+					"%s -> %s, ap retaliation, %s, dps: %lld",
+					ap.attack.defender->getDescription(),
+					ap.attack.attacker->getDescription(),
 					ap.attack.shooting ? "shot" : "mellee",
-					retaliationDamage,
-					attackerDamageReduce);
+					damageDealt);
 #endif
 			}
 			else
 			{
-				auto collateralDamage = damageCache.getDamage(attacker.get(), unitToUpdate.get(), hb);
-				auto collateralDamageReduce = AttackPossibility::calculateDamageReduce(attacker.get(), unitToUpdate.get(), collateralDamage, damageCache, hb);
-
-				attackValue -= collateralDamageReduce;
-				dpsScore.ourDamageReduce += collateralDamageReduce;
-
-				unitToUpdate->damage(collateralDamage);
-
 #if BATTLE_TRACE_LEVEL>=1
 				logAi->trace(
-					"%s -> %s, ap collateral, %s, dps: %2f, score: %2f",
-					attacker->getDescription(),
+					"%s, ap collateral, dps: %lld",
 					unitToUpdate->getDescription(),
-					ap.attack.shooting ? "shot" : "mellee",
-					collateralDamage,
-					collateralDamageReduce);
+					damageDealt);
 #endif
 			}
 		}
 		else
 		{
-			int64_t attackDamage = damageCache.getDamage(attacker.get(), unitToUpdate.get(), hb);
-			float defenderDamageReduce = AttackPossibility::calculateDamageReduce(attacker.get(), unitToUpdate.get(), attackDamage, damageCache, hb);
-
-			attackValue += defenderDamageReduce;
-			dpsScore.enemyDamageReduce += defenderDamageReduce;
-			attackerValue[attacker->unitId()].value += defenderDamageReduce;
-
-			unitToUpdate->damage(attackDamage);
+			if(unitToUpdate->unitId() == ap.attack.defender->unitId())
+			{
+				if(unitToUpdate->ableToRetaliate() && !affectedUnit->ableToRetaliate())
+				{
+					unitToUpdate->afterAttack(ap.attack.shooting, true);
+				}
 
 #if BATTLE_TRACE_LEVEL>=1
-			logAi->trace(
-				"%s -> %s, ap attack, %s, dps: %2f, score: %2f",
-				attacker->getDescription(),
-				unitToUpdate->getDescription(),
-				ap.attack.shooting ? "shot" : "mellee",
-				attackDamage,
-				defenderDamageReduce);
+				logAi->trace(
+					"%s -> %s, ap attack, %s, dps: %lld",
+					attacker->getDescription(),
+					ap.attack.defender->getDescription(),
+					ap.attack.shooting ? "shot" : "mellee",
+					damageDealt);
 #endif
+			}
+			else
+			{
+#if BATTLE_TRACE_LEVEL>=1
+				logAi->trace(
+					"%s, ap enemy collateral, dps: %lld",
+					unitToUpdate->getDescription(),
+					damageDealt);
+#endif
+			}
 		}
 	}
 
 #if BATTLE_TRACE_LEVEL >= 1
-	logAi->trace("ap shooters blocking: %lld", ap.shootersBlockedDmg);
+	logAi->trace(
+		"ap score: our: %2f, enemy: %2f, collateral: %2f, blocked: %2f",
+		ap.attackerDamageReduce,
+		ap.defenderDamageReduce,
+		ap.collateralDamageReduce,
+		ap.shootersBlockedDmg);
 #endif
 
-	attackValue += ap.shootersBlockedDmg;
-	dpsScore.enemyDamageReduce += ap.shootersBlockedDmg;
-	attacker->afterAttack(ap.attack.shooting, false);
-
 	return attackValue;
 }
 
@@ -230,6 +219,7 @@ EvaluationResult BattleExchangeEvaluator::findBestTarget(
 
 		auto hbWaited = std::make_shared<HypotheticBattle>(env.get(), hb);
 
+		hbWaited->resetActiveUnit();
 		hbWaited->getForUpdate(activeStack->unitId())->waiting = true;
 		hbWaited->getForUpdate(activeStack->unitId())->waitedThisTurn = true;
 
@@ -259,6 +249,7 @@ EvaluationResult BattleExchangeEvaluator::findBestTarget(
 	updateReachabilityMap(hb);
 
 	if(result.bestAttack.attack.shooting
+		&& !result.bestAttack.defenderDead
 		&& !activeStack->waited()
 		&& hb->battleHasShootingPenalty(activeStack, result.bestAttack.dest))
 	{
@@ -269,8 +260,9 @@ EvaluationResult BattleExchangeEvaluator::findBestTarget(
 	for(auto & ap : targets.possibleAttacks)
 	{
 		float score = evaluateExchange(ap, 0, targets, damageCache, hb);
+		bool sameScoreButWaited = vstd::isAlmostEqual(score, result.score) && result.wait;
 
-		if(score > result.score || (vstd::isAlmostEqual(score, result.score) && result.wait))
+		if(score > result.score || sameScoreButWaited)
 		{
 			result.score = score;
 			result.bestAttack = ap;
@@ -739,7 +731,7 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUn
 {
 	std::vector<const battle::Unit *> result;
 
-	for(int i = 0; i < turnOrder.size(); i++, turn++)
+	for(int i = 0; i < turnOrder.size(); i++)
 	{
 		auto & turnQueue = turnOrder[i];
 		HypotheticBattle turnBattle(env.get(), cb);

+ 1 - 1
AI/BattleAI/BattleExchangeVariant.h

@@ -148,7 +148,7 @@ public:
 		std::shared_ptr<CBattleInfoCallback> cb,
 		std::shared_ptr<Environment> env,
 		float strengthRatio): cb(cb), env(env) {
-		negativeEffectMultiplier = strengthRatio >= 1 ? 1 : strengthRatio;
+		negativeEffectMultiplier = strengthRatio >= 1 ? 1 : strengthRatio * strengthRatio;
 	}
 
 	EvaluationResult findBestTarget(

+ 5 - 0
AI/BattleAI/StackWithBonuses.h

@@ -164,6 +164,11 @@ public:
 
 	int64_t getTreeVersion() const;
 
+	void resetActiveUnit()
+	{
+		activeUnitId = -1;
+	}
+
 #if SCRIPTING_ENABLED
 	scripting::Pool * getContextPool() const override;
 #endif

+ 12 - 12
client/render/CAnimation.cpp

@@ -161,13 +161,13 @@ void CAnimation::verticalFlip()
 
 void CAnimation::horizontalFlip(size_t frame, size_t group)
 {
-	try
+	auto i1 = images.find(group);
+	if(i1 != images.end())
 	{
-		images.at(group).at(frame) = nullptr;
-	}
-	catch (const std::out_of_range &)
-	{
-		// ignore - image not loaded
+		auto i2 = i1->second.find(frame);
+
+		if(i2 != i1->second.end())
+			i2->second = nullptr;
 	}
 
 	auto locator = getImageLocator(frame, group);
@@ -177,13 +177,13 @@ void CAnimation::horizontalFlip(size_t frame, size_t group)
 
 void CAnimation::verticalFlip(size_t frame, size_t group)
 {
-	try
+	auto i1 = images.find(group);
+	if(i1 != images.end())
 	{
-		images.at(group).at(frame) = nullptr;
-	}
-	catch (const std::out_of_range &)
-	{
-		// ignore - image not loaded
+		auto i2 = i1->second.find(frame);
+
+		if(i2 != i1->second.end())
+			i2->second = nullptr;
 	}
 
 	auto locator = getImageLocator(frame, group);

+ 70 - 20
lib/battle/CBattleInfoCallback.cpp

@@ -1248,19 +1248,40 @@ ReachabilityInfo CBattleInfoCallback::getFlyingReachability(const ReachabilityIn
 	return ret;
 }
 
-AttackableTiles CBattleInfoCallback::getPotentiallyAttackableHexes(const battle::Unit* attacker, BattleHex destinationTile, BattleHex attackerPos) const
+AttackableTiles CBattleInfoCallback::getPotentiallyAttackableHexes(
+	const battle::Unit * attacker,
+	BattleHex destinationTile,
+	BattleHex attackerPos) const
+{
+	const auto * defender = battleGetUnitByPos(destinationTile, true);
+
+	if(!defender)
+		return AttackableTiles(); // can't attack thin air
+
+	return getPotentiallyAttackableHexes(
+		attacker,
+		defender,
+		destinationTile,
+		attackerPos,
+		defender->getPosition());
+}
+
+AttackableTiles CBattleInfoCallback::getPotentiallyAttackableHexes(
+	const battle::Unit* attacker,
+	const battle::Unit * defender,
+	BattleHex destinationTile,
+	BattleHex attackerPos,
+	BattleHex defenderPos) const
 {
 	//does not return hex attacked directly
 	AttackableTiles at;
 	RETURN_IF_NOT_BATTLE(at);
 
 	BattleHex attackOriginHex = (attackerPos != BattleHex::INVALID) ? attackerPos : attacker->getPosition(); //real or hypothetical (cursor) position
-
-	const auto * defender = battleGetUnitByPos(destinationTile, true);
-	if (!defender)
-		return at; // can't attack thin air
-
-	bool reverse = isToReverse(attacker, defender);
+	
+	defenderPos = (defenderPos != BattleHex::INVALID) ? defenderPos : defender->getPosition(); //real or hypothetical (cursor) position
+	
+	bool reverse = isToReverse(attacker, defender, attackerPos, defenderPos);
 	if(reverse && attacker->doubleWide())
 	{
 		attackOriginHex = attacker->occupiedHex(attackOriginHex); //the other hex stack stands on
@@ -1304,19 +1325,26 @@ AttackableTiles CBattleInfoCallback::getPotentiallyAttackableHexes(const battle:
 	else if(attacker->hasBonusOfType(BonusType::TWO_HEX_ATTACK_BREATH))
 	{
 		auto direction = BattleHex::mutualPosition(attackOriginHex, destinationTile);
+		
+		if(direction == BattleHex::NONE
+			&& defender->doubleWide()
+			&& attacker->doubleWide()
+			&& defenderPos == destinationTile)
+		{
+			direction = BattleHex::mutualPosition(attackOriginHex, defender->occupiedHex(defenderPos));
+		}
+
 		if(direction != BattleHex::NONE) //only adjacent hexes are subject of dragon breath calculation
 		{
 			BattleHex nextHex = destinationTile.cloneInDirection(direction, false);
 
 			if ( defender->doubleWide() )
 			{
-				auto secondHex = destinationTile == defender->getPosition() ?
-					defender->occupiedHex():
-					defender->getPosition();
+				auto secondHex = destinationTile == defenderPos ? defender->occupiedHex(defenderPos) : defenderPos;
 
 				// if targeted double-wide creature is attacked from above or below ( -> second hex is also adjacent to attack origin)
 				// then dragon breath should target tile on the opposite side of targeted creature
-				if (BattleHex::mutualPosition(attackOriginHex, secondHex) != BattleHex::NONE)
+				if(BattleHex::mutualPosition(attackOriginHex, secondHex) != BattleHex::NONE)
 					nextHex = secondHex.cloneInDirection(direction, false);
 			}
 
@@ -1348,17 +1376,29 @@ AttackableTiles CBattleInfoCallback::getPotentiallyShootableHexes(const battle::
 	return at;
 }
 
-std::vector<const battle::Unit*> CBattleInfoCallback::getAttackedBattleUnits(const battle::Unit* attacker, BattleHex destinationTile, bool rangedAttack, BattleHex attackerPos) const
+std::vector<const battle::Unit*> CBattleInfoCallback::getAttackedBattleUnits(
+	const battle::Unit * attacker,
+	const  battle::Unit * defender,
+	BattleHex destinationTile,
+	bool rangedAttack,
+	BattleHex attackerPos,
+	BattleHex defenderPos) const
 {
 	std::vector<const battle::Unit*> units;
 	RETURN_IF_NOT_BATTLE(units);
 
+	if(attackerPos == BattleHex::INVALID)
+		attackerPos = attacker->getPosition();
+
+	if(defenderPos == BattleHex::INVALID)
+		defenderPos = defender->getPosition();
+
 	AttackableTiles at;
 
 	if (rangedAttack)
 		at = getPotentiallyShootableHexes(attacker, destinationTile, attackerPos);
 	else
-		at = getPotentiallyAttackableHexes(attacker, destinationTile, attackerPos);
+		at = getPotentiallyAttackableHexes(attacker, defender, destinationTile, attackerPos, defenderPos);
 
 	units = battleGetUnitsIf([=](const battle::Unit * unit)
 	{
@@ -1384,7 +1424,7 @@ std::set<const CStack*> CBattleInfoCallback::getAttackedCreatures(const CStack*
 	RETURN_IF_NOT_BATTLE(attackedCres);
 
 	AttackableTiles at;
-
+	
 	if(rangedAttack)
 		at = getPotentiallyShootableHexes(attacker, destinationTile, attackerPos);
 	else
@@ -1423,10 +1463,13 @@ static bool isHexInFront(BattleHex hex, BattleHex testHex, BattleSide::Type side
 }
 
 //TODO: this should apply also to mechanics and cursor interface
-bool CBattleInfoCallback::isToReverse(const battle::Unit * attacker, const battle::Unit * defender) const
+bool CBattleInfoCallback::isToReverse(const battle::Unit * attacker, const battle::Unit * defender, BattleHex attackerHex, BattleHex defenderHex) const
 {
-	BattleHex attackerHex = attacker->getPosition();
-	BattleHex defenderHex = defender->getPosition();
+	if(!defenderHex.isValid())
+		defenderHex = defender->getPosition();
+
+	if(!attackerHex.isValid())
+		attackerHex = attacker->getPosition();
 
 	if (attackerHex < 0 ) //turret
 		return false;
@@ -1434,15 +1477,22 @@ bool CBattleInfoCallback::isToReverse(const battle::Unit * attacker, const battl
 	if(isHexInFront(attackerHex, defenderHex, static_cast<BattleSide::Type>(attacker->unitSide())))
 		return false;
 
+	auto defenderOtherHex = defenderHex;
+	auto attackerOtherHex = defenderHex;
+
 	if (defender->doubleWide())
 	{
-		if(isHexInFront(attackerHex, defender->occupiedHex(), static_cast<BattleSide::Type>(attacker->unitSide())))
+		defenderOtherHex = battle::Unit::occupiedHex(defenderHex, true, defender->unitSide());
+
+		if(isHexInFront(attackerHex, defenderOtherHex, static_cast<BattleSide::Type>(attacker->unitSide())))
 			return false;
 	}
 
 	if (attacker->doubleWide())
 	{
-		if(isHexInFront(attacker->occupiedHex(), defenderHex, static_cast<BattleSide::Type>(attacker->unitSide())))
+		attackerOtherHex = battle::Unit::occupiedHex(attackerHex, true, attacker->unitSide());
+
+		if(isHexInFront(attackerOtherHex, defenderHex, static_cast<BattleSide::Type>(attacker->unitSide())))
 			return false;
 	}
 
@@ -1450,7 +1500,7 @@ bool CBattleInfoCallback::isToReverse(const battle::Unit * attacker, const battl
 	// but this is how H3 handles it which is important, e.g. for direction of dragon breath attacks
 	if (attacker->doubleWide() && defender->doubleWide())
 	{
-		if(isHexInFront(attacker->occupiedHex(), defender->occupiedHex(), static_cast<BattleSide::Type>(attacker->unitSide())))
+		if(isHexInFront(attackerOtherHex, defenderOtherHex, static_cast<BattleSide::Type>(attacker->unitSide())))
 			return false;
 	}
 	return true;

+ 22 - 3
lib/battle/CBattleInfoCallback.h

@@ -131,11 +131,30 @@ public:
 	bool isInTacticRange(BattleHex dest) const;
 	si8 battleGetTacticDist() const; //returns tactic distance for calling player or 0 if this player is not in tactic phase (for ALL_KNOWING actual distance for tactic side)
 
-	AttackableTiles getPotentiallyAttackableHexes(const  battle::Unit* attacker, BattleHex destinationTile, BattleHex attackerPos) const; //TODO: apply rotation to two-hex attacker
+	AttackableTiles getPotentiallyAttackableHexes(
+		const  battle::Unit* attacker,
+		const  battle::Unit* defender,
+		BattleHex destinationTile,
+		BattleHex attackerPos,
+		BattleHex defenderPos) const; //TODO: apply rotation to two-hex attacker
+
+	AttackableTiles getPotentiallyAttackableHexes(
+		const  battle::Unit * attacker,
+		BattleHex destinationTile,
+		BattleHex attackerPos) const;
+
 	AttackableTiles getPotentiallyShootableHexes(const  battle::Unit* attacker, BattleHex destinationTile, BattleHex attackerPos) const;
-	std::vector<const battle::Unit *> getAttackedBattleUnits(const battle::Unit* attacker, BattleHex destinationTile, bool rangedAttack, BattleHex attackerPos = BattleHex::INVALID) const; //calculates range of multi-hex attacks
+
+	std::vector<const battle::Unit *> getAttackedBattleUnits(
+		const battle::Unit* attacker,
+		const  battle::Unit * defender,
+		BattleHex destinationTile,
+		bool rangedAttack,
+		BattleHex attackerPos = BattleHex::INVALID,
+		BattleHex defenderPos = BattleHex::INVALID) const; //calculates range of multi-hex attacks
+	
 	std::set<const CStack*> getAttackedCreatures(const CStack* attacker, BattleHex destinationTile, bool rangedAttack, BattleHex attackerPos = BattleHex::INVALID) const; //calculates range of multi-hex attacks
-	bool isToReverse(const battle::Unit * attacker, const battle::Unit * defender) const; //determines if attacker standing at attackerHex should reverse in order to attack defender
+	bool isToReverse(const battle::Unit * attacker, const battle::Unit * defender, BattleHex attackerHex = BattleHex::INVALID, BattleHex defenderHex = BattleHex::INVALID) const; //determines if attacker standing at attackerHex should reverse in order to attack defender
 
 	ReachabilityInfo getReachability(const battle::Unit * unit) const;
 	ReachabilityInfo getReachability(const ReachabilityInfo::Parameters & params) const;

+ 205 - 0
test/battle/CBattleInfoCallbackTest.cpp

@@ -40,11 +40,32 @@ public:
 		bonusFake.addNewBonus(b);
 	}
 
+	void addCreatureAbility(BonusType bonusType)
+	{
+		addNewBonus(
+			std::make_shared<Bonus>(
+				BonusDuration::PERMANENT,
+				bonusType,
+				BonusSource::CREATURE_ABILITY,
+				0,
+				CreatureID(0)));
+	}
+
 	void makeAlive()
 	{
 		EXPECT_CALL(*this, alive()).WillRepeatedly(Return(true));
 	}
 
+	void setupPoisition(BattleHex pos)
+	{
+		EXPECT_CALL(*this, getPosition()).WillRepeatedly(Return(pos));
+	}
+
+	void makeDoubleWide()
+	{
+		EXPECT_CALL(*this, doubleWide()).WillRepeatedly(Return(true));
+	}
+
 	void makeWarMachine()
 	{
 		addNewBonus(std::make_shared<Bonus>(BonusDuration::PERMANENT, BonusType::SIEGE_WEAPON, BonusSource::CREATURE_ABILITY, 1, BonusSourceID()));
@@ -183,6 +204,190 @@ public:
 	}
 };
 
+class AttackableHexesTest : public CBattleInfoCallbackTest
+{
+public:
+	UnitFake & addRegularMelee(BattleHex hex, uint8_t side)
+	{
+		auto & unit = unitsFake.add(side);
+
+		unit.makeAlive();
+		unit.setDefaultState();
+		unit.setupPoisition(hex);
+		unit.redirectBonusesToFake();
+
+		return unit;
+	}
+
+	UnitFake & addDragon(BattleHex hex, uint8_t side)
+	{
+		auto & unit = addRegularMelee(hex, side);
+
+		unit.addCreatureAbility(BonusType::TWO_HEX_ATTACK_BREATH);
+		unit.makeDoubleWide();
+
+		return unit;
+	}
+
+	Units getAttackedUnits(UnitFake & attacker, UnitFake & defender, BattleHex defenderHex)
+	{
+		startBattle();
+		redirectUnitsToFake();
+
+		return subject.getAttackedBattleUnits(
+			&attacker, &defender,
+			defenderHex, false,
+			attacker.getPosition(),
+			defender.getPosition());
+	}
+};
+
+TEST_F(AttackableHexesTest, DragonRightRegular_RightHorithontalBreath)
+{
+	// X A D #
+	UnitFake & attacker = addDragon(35, 0);
+	UnitFake & defender = addRegularMelee(36, 1);
+	UnitFake & next = addRegularMelee(37, 1);
+
+	auto attacked = getAttackedUnits(attacker, defender, defender.getPosition());
+
+	EXPECT_TRUE(vstd::contains(attacked, &next));
+}
+
+TEST_F(AttackableHexesTest, DragonDragonBottomRightHead_BottomRightBreathFromHead)
+{
+	// X A
+	//    D X		target D
+	//     #
+	UnitFake & attacker = addDragon(35, 0);
+	UnitFake & defender = addDragon(attacker.getPosition().cloneInDirection(BattleHex::BOTTOM_RIGHT), 1);
+	UnitFake & next = addRegularMelee(defender.getPosition().cloneInDirection(BattleHex::BOTTOM_RIGHT), 1);
+	
+	auto attacked = getAttackedUnits(attacker, defender, defender.getPosition());
+
+	EXPECT_TRUE(vstd::contains(attacked, &next));
+}
+
+TEST_F(AttackableHexesTest, DragonDragonVerticalDownHead_VerticalDownBreathFromHead)
+{
+	// X A
+	//  D X		target D
+	//   #
+	UnitFake & attacker = addDragon(35, 0);
+	UnitFake & defender = addDragon(attacker.getPosition().cloneInDirection(BattleHex::BOTTOM_LEFT), 1);
+	UnitFake & next = addRegularMelee(defender.getPosition().cloneInDirection(BattleHex::BOTTOM_RIGHT), 1);
+
+	auto attacked = getAttackedUnits(attacker, defender, defender.getPosition());
+
+	EXPECT_TRUE(vstd::contains(attacked, &next));
+}
+
+TEST_F(AttackableHexesTest, DragonDragonVerticalDownHeadReverse_VerticalDownBreathFromHead)
+{
+	//  A X
+	// X D		target D
+	//  #
+	UnitFake & attacker = addDragon(36, 1);
+	UnitFake & defender = addDragon(attacker.getPosition().cloneInDirection(BattleHex::BOTTOM_RIGHT), 0);
+	UnitFake & next = addRegularMelee(defender.getPosition().cloneInDirection(BattleHex::BOTTOM_LEFT), 0);
+
+	auto attacked = getAttackedUnits(attacker, defender, defender.getPosition());
+
+	EXPECT_TRUE(vstd::contains(attacked, &next));
+}
+
+TEST_F(AttackableHexesTest, DragonDragonVerticalDownBack_VerticalDownBreath)
+{
+	//  X A
+	// D X		target X
+	//  #
+	UnitFake & attacker = addDragon(37, 0);
+	UnitFake & defender = addDragon(attacker.occupiedHex().cloneInDirection(BattleHex::BOTTOM_LEFT), 1);
+	UnitFake & next = addRegularMelee(defender.getPosition().cloneInDirection(BattleHex::BOTTOM_RIGHT), 1);
+
+	auto attacked = getAttackedUnits(attacker, defender, defender.occupiedHex());
+
+	EXPECT_TRUE(vstd::contains(attacked, &next));
+}
+
+TEST_F(AttackableHexesTest, DragonDragonHeadBottomRight_BottomRightBreathFromHead)
+{
+	//  X A
+	// D X		target D
+	//  #
+	UnitFake & attacker = addDragon(37, 0);
+	UnitFake & defender = addDragon(attacker.occupiedHex().cloneInDirection(BattleHex::BOTTOM_LEFT), 1);
+	UnitFake & next = addRegularMelee(defender.getPosition().cloneInDirection(BattleHex::BOTTOM_RIGHT), 1);
+
+	auto attacked = getAttackedUnits(attacker, defender, defender.getPosition());
+
+	EXPECT_TRUE(vstd::contains(attacked, &next));
+}
+
+TEST_F(AttackableHexesTest, DragonVerticalDownDragonBackReverse_VerticalDownBreath)
+{
+	// A X
+	//  X D		target X
+	//   #
+	UnitFake & attacker = addDragon(36, 1);
+	UnitFake & defender = addDragon(attacker.occupiedHex().cloneInDirection(BattleHex::BOTTOM_RIGHT), 0);
+	UnitFake & next = addRegularMelee(defender.getPosition().cloneInDirection(BattleHex::BOTTOM_LEFT), 0);
+
+	auto attacked = getAttackedUnits(attacker, defender, defender.occupiedHex());
+
+	EXPECT_TRUE(vstd::contains(attacked, &next));
+}
+
+TEST_F(AttackableHexesTest, DragonRightBottomDragonHeadReverse_RightBottomBreathFromHeadHex)
+{
+	// A X
+	//  X D		target D
+	UnitFake & attacker = addDragon(36, 1);
+	UnitFake & defender = addDragon(attacker.occupiedHex().cloneInDirection(BattleHex::BOTTOM_RIGHT), 0);
+	UnitFake & next = addRegularMelee(defender.getPosition().cloneInDirection(BattleHex::BOTTOM_LEFT), 0);
+
+	auto attacked = getAttackedUnits(attacker, defender, defender.getPosition());
+
+	EXPECT_TRUE(vstd::contains(attacked, &next));
+}
+
+TEST_F(AttackableHexesTest, DragonLeftBottomDragonBackToBack_LeftBottomBreathFromBackHex)
+{
+	//    X A
+	// D X		target X
+	//  #
+	UnitFake & attacker = addDragon(8, 0);
+	UnitFake & defender = addDragon(attacker.occupiedHex().cloneInDirection(BattleHex::BOTTOM_LEFT).cloneInDirection(BattleHex::LEFT), 1);
+	UnitFake & next = addRegularMelee(defender.getPosition().cloneInDirection(BattleHex::BOTTOM_RIGHT), 1);
+
+	auto attacked = getAttackedUnits(attacker, defender, defender.occupiedHex());
+
+	EXPECT_TRUE(vstd::contains(attacked, &next));
+}
+
+TEST_F(AttackableHexesTest, DefenderPositionOverride_BreathCountsHypoteticDefenderPosition)
+{
+	//  # N
+	// X D		target D
+	//  A X
+	UnitFake & attacker = addDragon(35, 1);
+	UnitFake & defender = addDragon(8, 0);
+	UnitFake & next = addDragon(2, 0);
+
+	startBattle();
+	redirectUnitsToFake();
+
+	auto attacked = subject.getAttackedBattleUnits(
+		&attacker,
+		&defender,
+		19,
+		false,
+		attacker.getPosition(),
+		19);
+
+	EXPECT_TRUE(vstd::contains(attacked, &next));
+}
+
 class BattleFinishedTest : public CBattleInfoCallbackTest
 {
 public: