Ver código fonte

Merge pull request #3098 from vcmi/battle-ai-movement-fix

Battle ai movement fix
Andrii Danylchenko 2 anos atrás
pai
commit
7531400f53

+ 30 - 5
AI/BattleAI/AttackPossibility.cpp

@@ -114,6 +114,23 @@ float AttackPossibility::attackValue() const
 	return damageDiff();
 }
 
+float hpFunction(uint64_t unitHealthStart, uint64_t unitHealthEnd, uint64_t maxHealth)
+{
+	float ratioStart = static_cast<float>(unitHealthStart) / maxHealth;
+	float ratioEnd = static_cast<float>(unitHealthEnd) / maxHealth;
+	float base = 0.666666f;
+
+	// reduce from max to 0 must be 1. 
+	// 10 hp from end costs bit more than 10 hp from start because our goal is to kill unit, not just hurt it
+	// ********** 2 * base - ratioStart *********
+	// *                                                              *
+	// *        height = ratioStart - ratioEnd         *
+	// *                                                                  *
+	// ******************** 2 * base - ratioEnd ******
+	// S = (a + b) * h / 2
+	return (base * (4 - ratioStart - ratioEnd)) * (ratioStart - ratioEnd) / 2 ;
+}
+
 /// <summary>
 /// How enemy damage will be reduced by this attack
 /// Half bounty for kill, half for making damage equal to enemy health
@@ -127,6 +144,7 @@ float AttackPossibility::calculateDamageReduce(
 	std::shared_ptr<CBattleInfoCallback> state)
 {
 	const float HEALTH_BOUNTY = 0.5;
+	const float KILL_BOUNTY = 0.5;
 
 	// FIXME: provide distance info for Jousting bonus
 	auto attackerUnitForMeasurement = attacker;
@@ -157,13 +175,20 @@ float AttackPossibility::calculateDamageReduce(
 	auto enemyDamageBeforeAttack = damageCache.getOriginalDamage(defender, attackerUnitForMeasurement, state);
 	auto enemiesKilled = damageDealt / maxHealth + (damageDealt % maxHealth >= defender->getFirstHPleft() ? 1 : 0);
 	auto damagePerEnemy = enemyDamageBeforeAttack / (double)defender->getCount();
+	auto exceedingDamage = (damageDealt % maxHealth);
+	float hpValue = (damageDealt / maxHealth);
 	
-	// lets use cached maxHealth here instead of getAvailableHealth
-	auto firstUnitHpLeft = (availableHealth - damageDealt) % maxHealth;
-	auto firstUnitHealthRatio = firstUnitHpLeft == 0 ? 1 : static_cast<float>(firstUnitHpLeft) / maxHealth;
-	auto firstUnitKillValue = (1 - firstUnitHealthRatio) * (1 - firstUnitHealthRatio);
+	if(defender->getFirstHPleft() >= exceedingDamage)
+	{
+		hpValue += hpFunction(defender->getFirstHPleft(), defender->getFirstHPleft() - exceedingDamage, maxHealth);
+	}
+	else
+	{
+		hpValue += hpFunction(defender->getFirstHPleft(), 0, maxHealth);
+		hpValue += hpFunction(maxHealth, maxHealth + defender->getFirstHPleft() - exceedingDamage, maxHealth);
+	}
 
-	return damagePerEnemy * (enemiesKilled + firstUnitKillValue * HEALTH_BOUNTY);
+	return damagePerEnemy * (enemiesKilled * KILL_BOUNTY + hpValue * HEALTH_BOUNTY);
 }
 
 int64_t AttackPossibility::evaluateBlockedShootersDmg(

+ 10 - 2
AI/BattleAI/BattleEvaluator.cpp

@@ -149,7 +149,7 @@ BattleAction BattleEvaluator::selectStackAction(const CStack * stack)
 				bestAttack.attack.attacker->speed(0, true),
 				bestAttack.defenderDamageReduce,
 				bestAttack.attackerDamageReduce,
-				bestAttack.attackValue()
+				score
 			);
 
 			if (moveTarget.scorePerTurn <= score)
@@ -190,6 +190,14 @@ BattleAction BattleEvaluator::selectStackAction(const CStack * stack)
 
 		if(stack->waited())
 		{
+			logAi->debug(
+				"Moving %s towards hex %s[%d], score: %2f/%2f",
+				stack->getDescription(),
+				moveTarget.cachedAttack->attack.defender->getDescription(),
+				moveTarget.cachedAttack->attack.defender->getPosition().hex,
+				moveTarget.score,
+				moveTarget.scorePerTurn);
+
 			return goTowardsNearest(stack, moveTarget.positions);
 		}
 		else
@@ -572,7 +580,7 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
 				}
 				else
 				{
-					ps.value = scoreEvaluator.calculateExchange(*cachedAttack, *targets, innerCache, state);
+					ps.value = scoreEvaluator.evaluateExchange(*cachedAttack, 0, *targets, innerCache, state);
 				}
 
 				for(auto unit : allUnits)

+ 213 - 78
AI/BattleAI/BattleExchangeVariant.cpp

@@ -18,10 +18,8 @@ AttackerValue::AttackerValue()
 }
 
 MoveTarget::MoveTarget()
-	: positions(), cachedAttack()
+	: positions(), cachedAttack(), score(EvaluationResult::INEFFECTIVE_SCORE), scorePerTurn(EvaluationResult::INEFFECTIVE_SCORE)
 {
-	score = EvaluationResult::INEFFECTIVE_SCORE;
-	scorePerTurn = EvaluationResult::INEFFECTIVE_SCORE;
 	turnsToRich = 1;
 }
 
@@ -58,7 +56,7 @@ float BattleExchangeVariant::trackAttack(
 				auto attackerDamageReduce = AttackPossibility::calculateDamageReduce(defender.get(), unitToUpdate.get(), retaliationDamage, damageCache, hb);
 
 				attackValue -= attackerDamageReduce;
-				dpsScore -= attackerDamageReduce * negativeEffectMultiplier;
+				dpsScore.ourDamageReduce += attackerDamageReduce;
 				attackerValue[unitToUpdate->unitId()].isRetalitated = true;
 
 				unitToUpdate->damage(retaliationDamage);
@@ -80,7 +78,7 @@ float BattleExchangeVariant::trackAttack(
 				auto collateralDamageReduce = AttackPossibility::calculateDamageReduce(attacker.get(), unitToUpdate.get(), collateralDamage, damageCache, hb);
 
 				attackValue -= collateralDamageReduce;
-				dpsScore -= collateralDamageReduce * negativeEffectMultiplier;
+				dpsScore.ourDamageReduce += collateralDamageReduce;
 
 				unitToUpdate->damage(collateralDamage);
 
@@ -101,7 +99,7 @@ float BattleExchangeVariant::trackAttack(
 			float defenderDamageReduce = AttackPossibility::calculateDamageReduce(attacker.get(), unitToUpdate.get(), attackDamage, damageCache, hb);
 
 			attackValue += defenderDamageReduce;
-			dpsScore += defenderDamageReduce * positiveEffectMultiplier;
+			dpsScore.enemyDamageReduce += defenderDamageReduce;
 			attackerValue[attacker->unitId()].value += defenderDamageReduce;
 
 			unitToUpdate->damage(attackDamage);
@@ -118,8 +116,12 @@ float BattleExchangeVariant::trackAttack(
 		}
 	}
 
+#if BATTLE_TRACE_LEVEL >= 1
+	logAi->trace("ap shooters blocking: %lld", ap.shootersBlockedDmg);
+#endif
+
 	attackValue += ap.shootersBlockedDmg;
-	dpsScore += ap.shootersBlockedDmg * positiveEffectMultiplier;
+	dpsScore.enemyDamageReduce += ap.shootersBlockedDmg;
 	attacker->afterAttack(ap.attack.shooting, false);
 
 	return attackValue;
@@ -156,11 +158,11 @@ float BattleExchangeVariant::trackAttack(
 
 		if(isOurAttack)
 		{
-			dpsScore += defenderDamageReduce * positiveEffectMultiplier;
+			dpsScore.enemyDamageReduce += defenderDamageReduce;
 			attackerValue[attacker->unitId()].value += defenderDamageReduce;
 		}
 		else
-			dpsScore -= defenderDamageReduce * negativeEffectMultiplier;
+			dpsScore.ourDamageReduce += defenderDamageReduce;
 
 		defender->damage(attackDamage);
 		attacker->afterAttack(shooting, false);
@@ -182,12 +184,12 @@ float BattleExchangeVariant::trackAttack(
 
 		if(isOurAttack)
 		{
-			dpsScore -= attackerDamageReduce * negativeEffectMultiplier;
+			dpsScore.ourDamageReduce += attackerDamageReduce;
 			attackerValue[attacker->unitId()].isRetalitated = true;
 		}
 		else
 		{
-			dpsScore += attackerDamageReduce * positiveEffectMultiplier;
+			dpsScore.enemyDamageReduce += attackerDamageReduce;
 			attackerValue[defender->unitId()].value += attackerDamageReduce;
 		}
 
@@ -200,13 +202,18 @@ float BattleExchangeVariant::trackAttack(
 #if BATTLE_TRACE_LEVEL>=1
 	if(!score)
 	{
-		logAi->trace("Attack has zero score d:%2f a:%2f", defenderDamageReduce, attackerDamageReduce);
+		logAi->trace("Attack has zero score def:%2f att:%2f", defenderDamageReduce, attackerDamageReduce);
 	}
 #endif
 
 	return score;
 }
 
+float BattleExchangeEvaluator::scoreValue(const BattleScore & score) const
+{
+	return score.enemyDamageReduce * getPositiveEffectMultiplier() - score.ourDamageReduce * getNegativeEffectMultiplier();
+}
+
 EvaluationResult BattleExchangeEvaluator::findBestTarget(
 	const battle::Unit * activeStack,
 	PotentialTargets & targets,
@@ -215,7 +222,7 @@ EvaluationResult BattleExchangeEvaluator::findBestTarget(
 {
 	EvaluationResult result(targets.bestAction());
 
-	if(!activeStack->waited())
+	if(!activeStack->waited() && !activeStack->acquireState()->hadMorale)
 	{
 #if BATTLE_TRACE_LEVEL>=1
 		logAi->trace("Evaluating waited attack for %s", activeStack->getDescription());
@@ -230,13 +237,17 @@ EvaluationResult BattleExchangeEvaluator::findBestTarget(
 
 		for(auto & ap : targets.possibleAttacks)
 		{
-			float score = calculateExchange(ap, targets, damageCache, hbWaited);
+			float score = evaluateExchange(ap, 0, targets, damageCache, hbWaited);
 
 			if(score > result.score)
 			{
 				result.score = score;
 				result.bestAttack = ap;
 				result.wait = true;
+
+#if BATTLE_TRACE_LEVEL >= 1
+				logAi->trace("New high score %2f", result.score);
+#endif
 			}
 		}
 	}
@@ -247,15 +258,25 @@ EvaluationResult BattleExchangeEvaluator::findBestTarget(
 
 	updateReachabilityMap(hb);
 
+	if(result.bestAttack.attack.shooting && hb->battleHasShootingPenalty(activeStack, result.bestAttack.dest))
+	{
+		if(!canBeHitThisTurn(result.bestAttack))
+			return result; // lets wait
+	}
+
 	for(auto & ap : targets.possibleAttacks)
 	{
-		float score = calculateExchange(ap, targets, damageCache, hb);
+		float score = evaluateExchange(ap, 0, targets, damageCache, hb);
 
-		if(score >= result.score)
+		if(score > result.score || (score == result.score && result.wait))
 		{
 			result.score = score;
 			result.bestAttack = ap;
 			result.wait = false;
+
+#if BATTLE_TRACE_LEVEL >= 1
+			logAi->trace("New high score %2f", result.score);
+#endif
 		}
 	}
 
@@ -269,7 +290,7 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
 	std::shared_ptr<HypotheticBattle> hb)
 {
 	MoveTarget result;
-	BattleExchangeVariant ev(getPositiveEffectMultiplier(), getNegativeEffectMultiplier());
+	BattleExchangeVariant ev;
 
 	if(targets.unreachableEnemies.empty())
 		return result;
@@ -301,6 +322,12 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
 
 		auto turnsToRich = (distance - 1) / speed + 1;
 		auto hexes = closestStack->getSurroundingHexes();
+		auto enemySpeed = closestStack->speed();
+		auto speedRatio = speed / static_cast<float>(enemySpeed);
+		auto multiplier = speedRatio > 1 ? 1 : speedRatio;
+
+		if(enemy->canShoot())
+			multiplier *= 1.5f;
 
 		for(auto hex : hexes)
 		{
@@ -310,13 +337,13 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
 
 			attack.shootersBlockedDmg = 0; // we do not want to count on it, it is not for sure
 
-			auto score = calculateExchange(attack, targets, damageCache, hb);
-			auto scorePerTurn = score / turnsToRich;
+			auto score = calculateExchange(attack, turnsToRich, targets, damageCache, hb);
+			auto scorePerTurn = BattleScore(score.enemyDamageReduce * std::sqrt(multiplier / turnsToRich), score.ourDamageReduce);
 
-			if(result.scorePerTurn < scorePerTurn)
+			if(result.scorePerTurn < scoreValue(scorePerTurn))
 			{
-				result.scorePerTurn = scorePerTurn;
-				result.score = score;
+				result.scorePerTurn = scoreValue(scorePerTurn);
+				result.score = scoreValue(score);
 				result.positions = closestStack->getAttackableHexes(activeStack);
 				result.cachedAttack = attack;
 				result.turnsToRich = turnsToRich;
@@ -357,21 +384,23 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getAdjacentUnits(cons
 	return checkedStacks;
 }
 
-std::vector<const battle::Unit *> BattleExchangeEvaluator::getExchangeUnits(
+ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
 	const AttackPossibility & ap,
+	uint8_t turn,
 	PotentialTargets & targets,
 	std::shared_ptr<HypotheticBattle> hb)
 {
-	auto hexes = ap.attack.defender->getHexes();
+	ReachabilityData result;
+
+	auto hexes = ap.attack.defender->getSurroundingHexes();
 
 	if(!ap.attack.shooting) hexes.push_back(ap.from);
 
-	std::vector<const battle::Unit *> exchangeUnits;
 	std::vector<const battle::Unit *> allReachableUnits;
 
 	for(auto hex : hexes)
 	{
-		vstd::concatenate(allReachableUnits, reachabilityMap[hex]);
+		vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap[hex] : getOneTurnReachableUnits(turn, hex));
 	}
 
 	vstd::removeDuplicates(allReachableUnits);
@@ -404,7 +433,28 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getExchangeUnits(
 		logAi->trace("Reachability map contains only %d stacks", allReachableUnits.size());
 #endif
 
-		return exchangeUnits;
+		return result;
+	}
+
+	for(auto unit : allReachableUnits)
+	{
+		auto accessible = !unit->canShoot();
+
+		if(!accessible)
+		{
+			for(auto hex : unit->getSurroundingHexes())
+			{
+				if(ap.attack.defender->coversPos(hex))
+				{
+					accessible = true;
+				}
+			}
+		}
+
+		if(accessible)
+			result.melleeAccessible.push_back(unit);
+		else
+			result.shooters.push_back(unit);
 	}
 
 	for(int turn = 0; turn < turnOrder.size(); turn++)
@@ -412,20 +462,47 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getExchangeUnits(
 		for(auto unit : turnOrder[turn])
 		{
 			if(vstd::contains(allReachableUnits, unit))
-				exchangeUnits.push_back(unit);
+				result.units.push_back(unit);
 		}
 	}
 
-	vstd::erase_if(exchangeUnits, [&](const battle::Unit * u) -> bool
+	vstd::erase_if(result.units, [&](const battle::Unit * u) -> bool
 		{
 			return !hb->battleGetUnitByID(u->unitId())->alive();
 		});
 
-	return exchangeUnits;
+	return result;
 }
 
-float BattleExchangeEvaluator::calculateExchange(
+float BattleExchangeEvaluator::evaluateExchange(
 	const AttackPossibility & ap,
+	uint8_t turn,
+	PotentialTargets & targets,
+	DamageCache & damageCache,
+	std::shared_ptr<HypotheticBattle> hb)
+{
+	if(ap.from.hex == 127)
+	{
+		logAi->trace("x");
+	}
+
+	BattleScore score = calculateExchange(ap, turn, targets, damageCache, hb);
+
+#if BATTLE_TRACE_LEVEL >= 1
+	logAi->trace(
+		"calculateExchange score +%2f -%2fx%2f = %2f",
+		score.enemyDamageReduce,
+		score.ourDamageReduce,
+		getNegativeEffectMultiplier(),
+		scoreValue(score));
+#endif
+
+	return scoreValue(score);
+}
+
+BattleScore BattleExchangeEvaluator::calculateExchange(
+	const AttackPossibility & ap,
+	uint8_t turn,
 	PotentialTargets & targets,
 	DamageCache & damageCache,
 	std::shared_ptr<HypotheticBattle> hb)
@@ -438,7 +515,7 @@ float BattleExchangeEvaluator::calculateExchange(
 		&& cb->battleGetGateState() == EGateState::BLOCKED
 		&& ap.attack.defender->coversPos(BattleHex::GATE_BRIDGE))
 	{
-		return EvaluationResult::INEFFECTIVE_SCORE;
+		return BattleScore(EvaluationResult::INEFFECTIVE_SCORE, 0);
 	}
 
 	std::vector<const battle::Unit *> ourStacks;
@@ -447,27 +524,32 @@ float BattleExchangeEvaluator::calculateExchange(
 	if(hb->battleGetUnitByID(ap.attack.defender->unitId())->alive())
 		enemyStacks.push_back(ap.attack.defender);
 
-	std::vector<const battle::Unit *> exchangeUnits = getExchangeUnits(ap, targets, hb);
+	ReachabilityData exchangeUnits = getExchangeUnits(ap, turn, targets, hb);
 
-	if(exchangeUnits.empty())
+	if(exchangeUnits.units.empty())
 	{
-		return 0;
+		return BattleScore();
 	}
 
 	auto exchangeBattle = std::make_shared<HypotheticBattle>(env.get(), hb);
-	BattleExchangeVariant v(getPositiveEffectMultiplier(), getNegativeEffectMultiplier());
+	BattleExchangeVariant v;
 
-	for(auto unit : exchangeUnits)
+	for(auto unit : exchangeUnits.units)
 	{
 		if(unit->isTurret())
 			continue;
 
 		bool isOur = exchangeBattle->battleMatchOwner(ap.attack.attacker, unit, true);
 		auto & attackerQueue = isOur ? ourStacks : enemyStacks;
+		auto u = exchangeBattle->getForUpdate(unit->unitId());
 
-		if(exchangeBattle->getForUpdate(unit->unitId())->alive() && !vstd::contains(attackerQueue, unit))
+		if(u->alive() && !vstd::contains(attackerQueue, unit))
 		{
 			attackerQueue.push_back(unit);
+
+#if BATTLE_TRACE_LEVEL
+			logAi->trace("Exchanging: %s", u->getDescription());
+#endif
 		}
 	}
 
@@ -476,12 +558,12 @@ float BattleExchangeEvaluator::calculateExchange(
 	vstd::removeDuplicates(melleeAttackers);
 	vstd::erase_if(melleeAttackers, [&](const battle::Unit * u) -> bool
 		{
-			return !cb->battleCanShoot(u);
+			return cb->battleCanShoot(u);
 		});
 
 	bool canUseAp = true;
 
-	for(auto activeUnit : exchangeUnits)
+	for(auto activeUnit : exchangeUnits.units)
 	{
 		bool isOur = exchangeBattle->battleMatchOwner(ap.attack.attacker, activeUnit, true);
 		battle::Units & attackerQueue = isOur ? ourStacks : enemyStacks;
@@ -515,15 +597,22 @@ float BattleExchangeEvaluator::calculateExchange(
 					true);
 
 #if BATTLE_TRACE_LEVEL>=1
-				logAi->trace("Best target selector %s->%s score = %2f", attacker->getDescription(), u->getDescription(), score);
+				logAi->trace("Best target selector %s->%s score = %2f", attacker->getDescription(), stackWithBonuses->getDescription(), score);
 #endif
 
 				return score;
 			};
 
-			if(!oppositeQueue.empty())
+			auto unitsInOppositeQueueExceptInaccessible = oppositeQueue;
+
+			vstd::erase_if(unitsInOppositeQueueExceptInaccessible, [&](const battle::Unit * u)->bool
+				{
+					return vstd::contains(exchangeUnits.shooters, u);
+				});
+
+			if(!unitsInOppositeQueueExceptInaccessible.empty())
 			{
-				targetUnit = *vstd::maxElementByFun(oppositeQueue, estimateAttack);
+				targetUnit = *vstd::maxElementByFun(unitsInOppositeQueueExceptInaccessible, estimateAttack);
 			}
 			else
 			{
@@ -591,10 +680,20 @@ float BattleExchangeEvaluator::calculateExchange(
 
 	// avoid blocking path for stronger stack by weaker stack
 	// the method checks if all stacks can be placed around enemy
-	v.adjustPositions(melleeAttackers, ap, reachabilityMap);
+	std::map<BattleHex, battle::Units> reachabilityMap;
+
+	auto hexes = ap.attack.defender->getSurroundingHexes();
+
+	for(auto hex : hexes)
+		reachabilityMap[hex] = getOneTurnReachableUnits(turn, hex);
+
+	if(!ap.attack.shooting)
+	{
+		v.adjustPositions(melleeAttackers, ap, reachabilityMap);
+	}
 
 #if BATTLE_TRACE_LEVEL>=1
-	logAi->trace("Exchange score: %2f", v.getScore());
+	logAi->trace("Exchange score: enemy: %2f, our -%2f", v.getScore().enemyDamageReduce, v.getScore().ourDamageReduce);
 #endif
 
 	return v.getScore();
@@ -618,11 +717,8 @@ void BattleExchangeVariant::adjustPositions(
 			return attackerValue[u1->unitId()].value > attackerValue[u2->unitId()].value;
 		});
 
-	if(!ap.attack.shooting)
-	{
-		vstd::erase_if_present(hexes, ap.from);
-		vstd::erase_if_present(hexes, ap.attack.attacker->occupiedHex(ap.attack.attackerPos));
-	}
+	vstd::erase_if_present(hexes, ap.from);
+	vstd::erase_if_present(hexes, ap.attack.attacker->occupiedHex(ap.attack.attackerPos));
 
 	float notRealizedDamage = 0;
 
@@ -662,22 +758,58 @@ void BattleExchangeVariant::adjustPositions(
 
 	if(notRealizedDamage > ap.attackValue() && notRealizedDamage > attackerValue[ap.attack.attacker->unitId()].value)
 	{
-		dpsScore = EvaluationResult::INEFFECTIVE_SCORE;
+		dpsScore = BattleScore(EvaluationResult::INEFFECTIVE_SCORE, 0);
 	}
 }
 
-void BattleExchangeEvaluator::updateReachabilityMap(	std::shared_ptr<HypotheticBattle> hb)
+bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap)
+{
+	for(auto pos : ap.attack.attacker->getSurroundingHexes())
+	{
+		for(auto u : reachabilityMap[pos])
+		{
+			if(u->unitSide() != ap.attack.attacker->unitSide())
+			{
+				return true;
+			}
+		}
+	}
+
+	return false;
+}
+
+void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb)
 {
 	const int TURN_DEPTH = 2;
 
 	turnOrder.clear();
-	
+
 	hb->battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH);
-	reachabilityMap.clear();
 
-	for(int turn = 0; turn < turnOrder.size(); turn++)
+	for(auto turn : turnOrder)
 	{
-		auto & turnQueue = turnOrder[turn];
+		for(auto u : turn)
+		{
+			if(!vstd::contains(reachabilityCache, u->unitId()))
+			{
+				reachabilityCache[u->unitId()] = hb->getReachability(u);
+			}
+		}
+	}
+
+	for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
+	{
+		reachabilityMap[hex] = getOneTurnReachableUnits(0, hex);
+	}
+}
+
+std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex)
+{
+	std::vector<const battle::Unit *> result;
+
+	for(int i = 0; i < turnOrder.size(); i++, turn++)
+	{
+		auto & turnQueue = turnOrder[i];
 		HypotheticBattle turnBattle(env.get(), cb);
 
 		for(const battle::Unit * unit : turnQueue)
@@ -685,46 +817,49 @@ void BattleExchangeEvaluator::updateReachabilityMap(	std::shared_ptr<HypotheticB
 			if(unit->isTurret())
 				continue;
 
-			auto unitSpeed = unit->speed(turn);
-
 			if(turnBattle.battleCanShoot(unit))
 			{
-				for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
-				{
-					reachabilityMap[hex].push_back(unit);
-				}
+				result.push_back(unit);
 
 				continue;
 			}
 
-			auto unitReachability = turnBattle.getReachability(unit);
+			auto unitSpeed = unit->speed(turn);
+			auto radius = unitSpeed * (turn + 1);
 
-			for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
+			ReachabilityInfo unitReachability = vstd::getOrCompute(
+				reachabilityCache,
+				unit->unitId(),
+				[&](ReachabilityInfo & data)
+				{
+					data = turnBattle.getReachability(unit);
+				});
+
+			bool reachable = unitReachability.distances[hex] <= radius;
+
+			if(!reachable && unitReachability.accessibility[hex] == EAccessibility::ALIVE_STACK)
 			{
-				bool reachable = unitReachability.distances[hex] <= unitSpeed;
+				const battle::Unit * hexStack = cb->battleGetUnitByPos(hex);
 
-				if(!reachable && unitReachability.accessibility[hex] == EAccessibility::ALIVE_STACK)
+				if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
 				{
-					const battle::Unit * hexStack = cb->battleGetUnitByPos(hex);
-
-					if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
+					for(BattleHex neighbor : hex.neighbouringTiles())
 					{
-						for(BattleHex neighbor : hex.neighbouringTiles())
-						{
-							reachable = unitReachability.distances[neighbor] <= unitSpeed;
+						reachable = unitReachability.distances[neighbor] <= radius;
 
-							if(reachable) break;
-						}
+						if(reachable) break;
 					}
 				}
+			}
 
-				if(reachable)
-				{
-					reachabilityMap[hex].push_back(unit);
-				}
+			if(reachable)
+			{
+				result.push_back(unit);
 			}
 		}
 	}
+
+	return result;
 }
 
 // avoid blocking path for stronger stack by weaker stack

+ 68 - 11
AI/BattleAI/BattleExchangeVariant.h

@@ -14,6 +14,34 @@
 #include "PotentialTargets.h"
 #include "StackWithBonuses.h"
 
+struct BattleScore
+{
+	float ourDamageReduce;
+	float enemyDamageReduce;
+
+	BattleScore(float enemyDamageReduce, float ourDamageReduce)
+		:enemyDamageReduce(enemyDamageReduce), ourDamageReduce(ourDamageReduce)
+	{
+	}
+
+	BattleScore() : BattleScore(0, 0) {}
+
+	float value()
+	{
+		return enemyDamageReduce - ourDamageReduce;
+	}
+
+	BattleScore  operator+(BattleScore & other)
+	{
+		BattleScore result = *this;
+
+		result.ourDamageReduce += other.ourDamageReduce;
+		result.enemyDamageReduce += other.enemyDamageReduce;
+
+		return result;
+	}
+};
+
 struct AttackerValue
 {
 	float value;
@@ -59,8 +87,8 @@ struct EvaluationResult
 class BattleExchangeVariant
 {
 public:
-	BattleExchangeVariant(float positiveEffectMultiplier, float negativeEffectMultiplier)
-		: dpsScore(0), positiveEffectMultiplier(positiveEffectMultiplier), negativeEffectMultiplier(negativeEffectMultiplier) {}
+	BattleExchangeVariant()
+		: dpsScore() {}
 
 	float trackAttack(
 		const AttackPossibility & ap,
@@ -76,7 +104,7 @@ public:
 		std::shared_ptr<HypotheticBattle> hb,
 		bool evaluateOnly = false);
 
-	float getScore() const { return dpsScore; }
+	const BattleScore & getScore() const { return dpsScore; }
 
 	void adjustPositions(
 		std::vector<const battle::Unit *> attackers,
@@ -84,27 +112,48 @@ public:
 		std::map<BattleHex, battle::Units> & reachabilityMap);
 
 private:
-	float positiveEffectMultiplier;
-	float negativeEffectMultiplier;
-	float dpsScore;
+	BattleScore dpsScore;
 	std::map<uint32_t, AttackerValue> attackerValue;
 };
 
+struct ReachabilityData
+{
+	std::vector<const battle::Unit *> units;
+
+	// shooters which are within mellee attack and mellee units
+	std::vector<const battle::Unit *> melleeAccessible;
+
+	// far shooters
+	std::vector<const battle::Unit *> shooters;
+};
+
 class BattleExchangeEvaluator
 {
 private:
 	std::shared_ptr<CBattleInfoCallback> cb;
 	std::shared_ptr<Environment> env;
+	std::map<uint32_t, ReachabilityInfo> reachabilityCache;
 	std::map<BattleHex, std::vector<const battle::Unit *>> reachabilityMap;
 	std::vector<battle::Units> turnOrder;
 	float negativeEffectMultiplier;
 
+	float scoreValue(const BattleScore & score) const;
+
+	BattleScore calculateExchange(
+		const AttackPossibility & ap,
+		uint8_t turn,
+		PotentialTargets & targets,
+		DamageCache & damageCache,
+		std::shared_ptr<HypotheticBattle> hb);
+
+	bool canBeHitThisTurn(const AttackPossibility & ap);
+
 public:
 	BattleExchangeEvaluator(
 		std::shared_ptr<CBattleInfoCallback> cb,
 		std::shared_ptr<Environment> env,
 		float strengthRatio): cb(cb), env(env) {
-		negativeEffectMultiplier = strengthRatio;
+		negativeEffectMultiplier = strengthRatio >= 1 ? 1 : strengthRatio;
 	}
 
 	EvaluationResult findBestTarget(
@@ -113,14 +162,22 @@ public:
 		DamageCache & damageCache,
 		std::shared_ptr<HypotheticBattle> hb);
 
-	float calculateExchange(
+	float evaluateExchange(
 		const AttackPossibility & ap,
+		uint8_t turn,
 		PotentialTargets & targets,
 		DamageCache & damageCache,
 		std::shared_ptr<HypotheticBattle> hb);
 
+	std::vector<const battle::Unit *> getOneTurnReachableUnits(uint8_t turn, BattleHex hex);
 	void updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb);
-	std::vector<const battle::Unit *> getExchangeUnits(const AttackPossibility & ap, PotentialTargets & targets, std::shared_ptr<HypotheticBattle> hb);
+
+	ReachabilityData getExchangeUnits(
+		const AttackPossibility & ap,
+		uint8_t turn,
+		PotentialTargets & targets,
+		std::shared_ptr<HypotheticBattle> hb);
+
 	bool checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * unit, BattleHex position);
 
 	MoveTarget findMoveTowardsUnreachable(
@@ -131,6 +188,6 @@ public:
 
 	std::vector<const battle::Unit *> getAdjacentUnits(const battle::Unit * unit);
 
-	float getPositiveEffectMultiplier() { return 1; }
-	float getNegativeEffectMultiplier() { return negativeEffectMultiplier; }
+	float getPositiveEffectMultiplier() const { return 1; }
+	float getNegativeEffectMultiplier() const { return negativeEffectMultiplier; }
 };

+ 5 - 4
AI/Nullkiller/Engine/Nullkiller.cpp

@@ -274,6 +274,7 @@ void Nullkiller::makeTurn()
 
 		bestTask = choseBestTask(bestTasks);
 
+		std::string taskDescription = bestTask->toString();
 		HeroPtr hero = bestTask->getHero();
 		HeroRole heroRole = HeroRole::MAIN;
 
@@ -292,7 +293,7 @@ void Nullkiller::makeTurn()
 
 			logAi->trace(
 				"Goal %s has low priority %f so decreasing  scan depth to gain performance.",
-				bestTask->toString(),
+				taskDescription,
 				bestTask->priority);
 		}
 
@@ -308,7 +309,7 @@ void Nullkiller::makeTurn()
 			{
 				logAi->trace(
 					"Goal %s has too low priority %f so increasing scan depth to full.",
-					bestTask->toString(),
+					taskDescription,
 					bestTask->priority);
 
 				scanDepth = ScanDepth::ALL_FULL;
@@ -316,7 +317,7 @@ void Nullkiller::makeTurn()
 				continue;
 			}
 
-			logAi->trace("Goal %s has too low priority. It is not worth doing it. Ending turn.", bestTask->toString());
+			logAi->trace("Goal %s has too low priority. It is not worth doing it. Ending turn.", taskDescription);
 
 			return;
 		}
@@ -325,7 +326,7 @@ void Nullkiller::makeTurn()
 
 		if(i == MAXPASS)
 		{
-			logAi->error("Goal %s exceeded maxpass. Terminating AI turn.", bestTask->toString());
+			logAi->error("Goal %s exceeded maxpass. Terminating AI turn.", taskDescription);
 		}
 	}
 }