浏览代码

BattleAI: positive/negative effect multiplier

Andrii Danylchenko 2 年之前
父节点
当前提交
dc88f14e0b

+ 18 - 5
AI/BattleAI/AttackPossibility.cpp

@@ -103,6 +103,12 @@ int64_t AttackPossibility::damageDiff() const
 	return defenderDamageReduce - attackerDamageReduce - collateralDamageReduce + shootersBlockedDmg;
 }
 
+int64_t AttackPossibility::damageDiff(float positiveEffectMultiplier, float negativeEffectMultiplier) const
+{
+	return positiveEffectMultiplier * (defenderDamageReduce + shootersBlockedDmg)
+		- negativeEffectMultiplier * (attackerDamageReduce + collateralDamageReduce);
+}
+
 int64_t AttackPossibility::attackValue() const
 {
 	return damageDiff();
@@ -121,9 +127,6 @@ int64_t AttackPossibility::calculateDamageReduce(
 	std::shared_ptr<CBattleInfoCallback> state)
 {
 	const float HEALTH_BOUNTY = 0.5;
-	const float KILL_BOUNTY = 1.0 - HEALTH_BOUNTY;
-
-	vstd::amin(damageDealt, defender->getAvailableHealth());
 
 	// FIXME: provide distance info for Jousting bonus
 	auto attackerUnitForMeasurement = attacker;
@@ -146,11 +149,21 @@ int64_t AttackPossibility::calculateDamageReduce(
 			attackerUnitForMeasurement = ourUnits.front();
 	}
 
+	auto maxHealth = defender->getMaxHealth();
+	auto availableHealth = defender->getFirstHPleft() + ((defender->getCount() - 1) * maxHealth);
+
+	vstd::amin(damageDealt, availableHealth);
+
 	auto enemyDamageBeforeAttack = damageCache.getOriginalDamage(defender, attackerUnitForMeasurement, state);
-	auto enemiesKilled = damageDealt / defender->getMaxHealth() + (damageDealt % defender->getMaxHealth() >= defender->getFirstHPleft() ? 1 : 0);
+	auto enemiesKilled = damageDealt / maxHealth + (damageDealt % maxHealth >= defender->getFirstHPleft() ? 1 : 0);
 	auto damagePerEnemy = enemyDamageBeforeAttack / (double)defender->getCount();
+	
+	// lets use cached maxHealth here instead of getAvailableHealth
+	auto firstUnitHpLeft = (availableHealth - damageDealt) % maxHealth;
+	auto firstUnitHealthRatio = firstUnitHpLeft == 0 ? 1 : static_cast<float>(firstUnitHpLeft) / maxHealth;
+	auto firstUnitKillValue = (1 - firstUnitHealthRatio) * (1 - firstUnitHealthRatio);
 
-	return (int64_t)(damagePerEnemy * (enemiesKilled * KILL_BOUNTY + damageDealt * HEALTH_BOUNTY / (double)defender->getMaxHealth()));
+	return (int64_t)(damagePerEnemy * (enemiesKilled + firstUnitKillValue * HEALTH_BOUNTY));
 }
 
 int64_t AttackPossibility::evaluateBlockedShootersDmg(

+ 1 - 0
AI/BattleAI/AttackPossibility.h

@@ -55,6 +55,7 @@ public:
 
 	int64_t damageDiff() const;
 	int64_t attackValue() const;
+	int64_t damageDiff(float positiveEffectMultiplier, float negativeEffectMultiplier) const;
 
 	static AttackPossibility evaluate(
 		const BattleAttackInfo & attackInfo,

+ 24 - 673
AI/BattleAI/BattleAI.cpp

@@ -9,6 +9,7 @@
  */
 #include "StdInc.h"
 #include "BattleAI.h"
+#include "BattleEvaluator.h"
 #include "BattleExchangeVariant.h"
 
 #include "StackWithBonuses.h"
@@ -29,90 +30,6 @@
 #define LOGL(text) print(text)
 #define LOGFL(text, formattingEl) print(boost::str(boost::format(text) % formattingEl))
 
-enum class SpellTypes
-{
-	ADVENTURE, BATTLE, OTHER
-};
-
-SpellTypes spellType(const CSpell * spell)
-{
-	if(!spell->isCombat() || spell->isCreatureAbility())
-		return SpellTypes::OTHER;
-
-	if(spell->isOffensive() || spell->hasEffects() || spell->hasBattleEffects())
-		return SpellTypes::BATTLE;
-
-	return SpellTypes::OTHER;
-}
-
-class BattleEvaluator
-{
-	std::unique_ptr<PotentialTargets> targets;
-	std::shared_ptr<HypotheticBattle> hb;
-	BattleExchangeEvaluator scoreEvaluator;
-	std::shared_ptr<CBattleCallback> cb;
-	std::shared_ptr<Environment> env;
-	bool activeActionMade = false;
-	std::optional<AttackPossibility> cachedAttack;
-	PlayerColor playerID;
-	int side;
-	int64_t cachedScore;
-	DamageCache damageCache;
-	
-public:
-	BattleAction selectStackAction(const CStack * stack);
-	void attemptCastingSpell(const CStack * stack);
-	std::optional<PossibleSpellcast> findBestCreatureSpell(const CStack * stack);
-	BattleAction goTowardsNearest(const CStack * stack, std::vector<BattleHex> hexes);
-	std::vector<BattleHex> getBrokenWallMoatHexes() const;
-	void evaluateCreatureSpellcast(const CStack * stack, PossibleSpellcast & ps); //for offensive damaging spells only
-	void print(const std::string & text) const;
-
-	BattleEvaluator(std::shared_ptr<Environment> env, std::shared_ptr<CBattleCallback> cb, const battle::Unit * activeStack, PlayerColor playerID, int side)
-		:scoreEvaluator(cb, env), cachedAttack(), playerID(playerID), side(side), env(env), cb(cb)
-	{
-		hb = std::make_shared<HypotheticBattle>(env.get(), cb);
-		damageCache.buildDamageCache(hb, side);
-
-		targets = std::make_unique<PotentialTargets>(activeStack, damageCache, hb);
-		cachedScore = EvaluationResult::INEFFECTIVE_SCORE;
-	}
-
-	BattleEvaluator(
-		std::shared_ptr<Environment> env,
-		std::shared_ptr<CBattleCallback> cb,
-		std::shared_ptr<HypotheticBattle> hb,
-		DamageCache & damageCache,
-		const battle::Unit * activeStack,
-		PlayerColor playerID,
-		int side)
-		:scoreEvaluator(cb, env), cachedAttack(), playerID(playerID), side(side), env(env), cb(cb), hb(hb), damageCache(damageCache)
-	{
-		targets = std::make_unique<PotentialTargets>(activeStack, damageCache, hb);
-		cachedScore = EvaluationResult::INEFFECTIVE_SCORE;
-	}
-};
-
-std::vector<BattleHex> BattleEvaluator::getBrokenWallMoatHexes() const
-{
-	std::vector<BattleHex> result;
-
-	for(EWallPart wallPart : { EWallPart::BOTTOM_WALL, EWallPart::BELOW_GATE, EWallPart::OVER_GATE, EWallPart::UPPER_WALL })
-	{
-		auto state = cb->battleGetWallState(wallPart);
-
-		if(state != EWallState::DESTROYED)
-			continue;
-
-		auto wallHex = cb->wallPartToBattleHex((EWallPart)wallPart);
-		auto moatHex = wallHex.cloneInDirection(BattleHex::LEFT);
-
-		result.push_back(moatHex);
-	}
-
-	return result;
-}
-
 CBattleAI::CBattleAI()
 	: side(-1),
 	wasWaitingForRealize(false),
@@ -159,161 +76,22 @@ BattleAction CBattleAI::useHealingTent(const CStack *stack)
 		return BattleAction::makeHeal(stack, woundHpToStack.rbegin()->second); //last element of the woundHpToStack is the most wounded stack
 }
 
-std::optional<PossibleSpellcast> BattleEvaluator::findBestCreatureSpell(const CStack *stack)
-{
-	//TODO: faerie dragon type spell should be selected by server
-	SpellID creatureSpellToCast = cb->battleGetRandomStackSpell(CRandomGenerator::getDefault(), stack, CBattleInfoCallback::RANDOM_AIMED);
-	if(stack->hasBonusOfType(BonusType::SPELLCASTER) && stack->canCast() && creatureSpellToCast != SpellID::NONE)
-	{
-		const CSpell * spell = creatureSpellToCast.toSpell();
-
-		if(spell->canBeCast(getCbc().get(), spells::Mode::CREATURE_ACTIVE, stack))
-		{
-			std::vector<PossibleSpellcast> possibleCasts;
-			spells::BattleCast temp(getCbc().get(), stack, spells::Mode::CREATURE_ACTIVE, spell);
-			for(auto & target : temp.findPotentialTargets())
-			{
-				PossibleSpellcast ps;
-				ps.dest = target;
-				ps.spell = spell;
-				evaluateCreatureSpellcast(stack, ps);
-				possibleCasts.push_back(ps);
-			}
-
-			std::sort(possibleCasts.begin(), possibleCasts.end(), [&](const PossibleSpellcast & lhs, const PossibleSpellcast & rhs) { return lhs.value > rhs.value; });
-			if(!possibleCasts.empty() && possibleCasts.front().value > 0)
-			{
-				return possibleCasts.front();
-			}
-		}
-	}
-	return std::nullopt;
-}
-
-BattleAction BattleEvaluator::selectStackAction(const CStack * stack)
-{
-	//evaluate casting spell for spellcasting stack
-	std::optional<PossibleSpellcast> bestSpellcast = findBestCreatureSpell(stack);
-
-	auto moveTarget = scoreEvaluator.findMoveTowardsUnreachable(stack, *targets, damageCache, hb);
-	auto score = EvaluationResult::INEFFECTIVE_SCORE;
-
-	if(targets->possibleAttacks.empty() && bestSpellcast.has_value())
-	{
-		activeActionMade = true;
-		return BattleAction::makeCreatureSpellcast(stack, bestSpellcast->dest, bestSpellcast->spell->id);
-	}
-
-	if(!targets->possibleAttacks.empty())
-	{
-#if BATTLE_TRACE_LEVEL>=1
-		logAi->trace("Evaluating attack for %s", stack->getDescription());
-#endif
-
-		auto evaluationResult = scoreEvaluator.findBestTarget(stack, *targets, damageCache, hb);
-		auto & bestAttack = evaluationResult.bestAttack;
-
-		cachedAttack = bestAttack;
-		cachedScore = evaluationResult.score;
-
-		//TODO: consider more complex spellcast evaluation, f.e. because "re-retaliation" during enemy move in same turn for melee attack etc.
-		if(bestSpellcast.has_value() && bestSpellcast->value > bestAttack.damageDiff())
-		{
-			// return because spellcast value is damage dealt and score is dps reduce
-			activeActionMade = true;
-			return BattleAction::makeCreatureSpellcast(stack, bestSpellcast->dest, bestSpellcast->spell->id);
-		}
-
-		if(evaluationResult.score > score)
-		{
-			score = evaluationResult.score;
-
-			logAi->debug("BattleAI: %s -> %s x %d, from %d curpos %d dist %d speed %d: +%lld -%lld = %lld",
-				bestAttack.attackerState->unitType()->getJsonKey(),
-				bestAttack.affectedUnits[0]->unitType()->getJsonKey(),
-				(int)bestAttack.affectedUnits[0]->getCount(),
-				(int)bestAttack.from,
-				(int)bestAttack.attack.attacker->getPosition().hex,
-				bestAttack.attack.chargeDistance,
-				bestAttack.attack.attacker->speed(0, true),
-				bestAttack.defenderDamageReduce,
-				bestAttack.attackerDamageReduce, bestAttack.attackValue()
-			);
-
-			if (moveTarget.scorePerTurn <= score)
-			{
-				if(evaluationResult.wait)
-				{
-					return BattleAction::makeWait(stack);
-				}
-				else if(bestAttack.attack.shooting)
-				{
-					activeActionMade = true;
-					return BattleAction::makeShotAttack(stack, bestAttack.attack.defender);
-				}
-				else
-				{
-					activeActionMade = true;
-					return BattleAction::makeMeleeAttack(stack, bestAttack.attack.defender->getPosition(), bestAttack.from);
-				}
-			}
-		}
-	}
-
-	//ThreatMap threatsToUs(stack); // These lines may be usefull but they are't used in the code.
-	if(moveTarget.scorePerTurn > score)
-	{
-		score = moveTarget.score;
-		cachedAttack = moveTarget.cachedAttack;
-		cachedScore = score;
-
-		if(stack->waited())
-		{
-			return goTowardsNearest(stack, moveTarget.positions);
-		}
-		else
-		{
-			return BattleAction::makeWait(stack);
-		}
-	}
-
-	if(score <= EvaluationResult::INEFFECTIVE_SCORE
-		&& !stack->hasBonusOfType(BonusType::FLYING)
-		&& stack->unitSide() == BattleSide::ATTACKER
-		&& cb->battleGetSiegeLevel() >= CGTownInstance::CITADEL)
-	{
-		auto brokenWallMoat = getBrokenWallMoatHexes();
-
-		if(brokenWallMoat.size())
-		{
-			activeActionMade = true;
-
-			if(stack->doubleWide() && vstd::contains(brokenWallMoat, stack->getPosition()))
-				return BattleAction::makeMove(stack, stack->getPosition().cloneInDirection(BattleHex::RIGHT));
-			else
-				return goTowardsNearest(stack, brokenWallMoat);
-		}
-	}
-
-	return BattleAction::makeDefend(stack);
-}
-
 void CBattleAI::yourTacticPhase(int distance)
 {
 	cb->battleMakeTacticAction(BattleAction::makeEndOFTacticPhase(cb->battleGetTacticsSide()));
 }
 
-uint64_t timeElapsed(std::chrono::time_point<std::chrono::high_resolution_clock> start)
-{
-	auto end = std::chrono::high_resolution_clock::now();
-
-	return std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
-}
-
 void CBattleAI::activeStack( const CStack * stack )
 {
 	LOG_TRACE_PARAMS(logAi, "stack: %s", stack->nodeName());
 
+	auto timeElapsed = [](std::chrono::time_point<std::chrono::high_resolution_clock> start) -> uint64_t
+	{
+		auto end = std::chrono::high_resolution_clock::now();
+
+		return std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
+	};
+
 	BattleAction result = BattleAction::makeDefend(stack);
 	setCbc(cb); //TODO: make solid sure that AIs always use their callbacks (need to take care of event handlers too)
 
@@ -332,12 +110,19 @@ void CBattleAI::activeStack( const CStack * stack )
 			return;
 		}
 
-		BattleEvaluator evaluator(env, cb, stack, playerID, side);
+		BattleEvaluator evaluator(env, cb, stack, playerID, side, strengthRatio);
 
 		result = evaluator.selectStackAction(stack);
 
-		if(evaluator.attemptCastingSpell(stack))
-			return;
+		if(!skipCastUntilNextBattle && evaluator.canCastSpell())
+		{
+			auto spelCasted = evaluator.attemptCastingSpell(stack);
+
+			if(spelCasted)
+				return;
+			
+			skipCastUntilNextBattle = true;
+		}
 
 		logAi->trace("Spellcast attempt completed in %lld", timeElapsed(start));
 
@@ -370,103 +155,6 @@ void CBattleAI::activeStack( const CStack * stack )
 	cb->battleMakeUnitAction(result);
 }
 
-BattleAction BattleEvaluator::goTowardsNearest(const CStack * stack, std::vector<BattleHex> hexes)
-{
-	auto reachability = cb->getReachability(stack);
-	auto avHexes = cb->battleGetAvailableHexes(reachability, stack, false);
-
-	if(!avHexes.size() || !hexes.size()) //we are blocked or dest is blocked
-	{
-		return BattleAction::makeDefend(stack);
-	}
-
-	std::sort(hexes.begin(), hexes.end(), [&](BattleHex h1, BattleHex h2) -> bool
-	{
-		return reachability.distances[h1] < reachability.distances[h2];
-	});
-
-	for(auto hex : hexes)
-	{
-		if(vstd::contains(avHexes, hex))
-		{
-			return BattleAction::makeMove(stack, hex);
-		}
-
-		if(stack->coversPos(hex))
-		{
-			logAi->warn("Warning: already standing on neighbouring tile!");
-			//We shouldn't even be here...
-			return BattleAction::makeDefend(stack);
-		}
-	}
-
-	BattleHex bestNeighbor = hexes.front();
-
-	if(reachability.distances[bestNeighbor] > GameConstants::BFIELD_SIZE)
-	{
-		return BattleAction::makeDefend(stack);
-	}
-
-	scoreEvaluator.updateReachabilityMap(hb);
-
-	if(stack->hasBonusOfType(BonusType::FLYING))
-	{
-		std::set<BattleHex> obstacleHexes;
-
-		auto insertAffected = [](const CObstacleInstance & spellObst, std::set<BattleHex> obstacleHexes) {
-			auto affectedHexes = spellObst.getAffectedTiles();
-			obstacleHexes.insert(affectedHexes.cbegin(), affectedHexes.cend());
-		};
-
-		const auto & obstacles = hb->battleGetAllObstacles();
-
-		for (const auto & obst: obstacles) {
-
-			if(obst->triggersEffects())
-			{
-				auto triggerAbility =  VLC->spells()->getById(obst->getTrigger());
-				auto triggerIsNegative = triggerAbility->isNegative() || triggerAbility->isDamage();
-
-				if(triggerIsNegative)
-					insertAffected(*obst, obstacleHexes);
-			}
-		}
-		// Flying stack doesn't go hex by hex, so we can't backtrack using predecessors.
-		// We just check all available hexes and pick the one closest to the target.
-		auto nearestAvailableHex = vstd::minElementByFun(avHexes, [&](BattleHex hex) -> int
-		{
-			const int NEGATIVE_OBSTACLE_PENALTY = 100; // avoid landing on negative obstacle (moat, fire wall, etc)
-			const int BLOCKED_STACK_PENALTY = 100; // avoid landing on moat
-
-			auto distance = BattleHex::getDistance(bestNeighbor, hex);
-
-			if(vstd::contains(obstacleHexes, hex))
-				distance += NEGATIVE_OBSTACLE_PENALTY;
-
-			return scoreEvaluator.checkPositionBlocksOurStacks(*hb, stack, hex) ? BLOCKED_STACK_PENALTY + distance : distance;
-		});
-
-		return BattleAction::makeMove(stack, *nearestAvailableHex);
-	}
-	else
-	{
-		BattleHex currentDest = bestNeighbor;
-		while(1)
-		{
-			if(!currentDest.isValid())
-			{
-				return BattleAction::makeDefend(stack);
-			}
-
-			if(vstd::contains(avHexes, currentDest)
-				&& !scoreEvaluator.checkPositionBlocksOurStacks(*hb, stack, currentDest))
-				return BattleAction::makeMove(stack, currentDest);
-
-			currentDest = reachability.predecessors[currentDest];
-		}
-	}
-}
-
 BattleAction CBattleAI::useCatapult(const CStack * stack)
 {
 	BattleAction attack;
@@ -515,356 +203,19 @@ BattleAction CBattleAI::useCatapult(const CStack * stack)
 	return attack;
 }
 
-<<<<<<< HEAD
-bool CBattleAI::attemptCastingSpell()
-=======
-void BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
->>>>>>> ea22737e9 (BattleAI: damage cache and switch to different model of spells evaluation)
-{
-	auto hero = cb->battleGetMyHero();
-	if(!hero)
-		return false;
-
-	if(cb->battleCanCastSpell(hero, spells::Mode::HERO) != ESpellCastProblem::OK)
-		return false;
-
-	LOGL("Casting spells sounds like fun. Let's see...");
-	//Get all spells we can cast
-	std::vector<const CSpell*> possibleSpells;
-	vstd::copy_if(VLC->spellh->objects, std::back_inserter(possibleSpells), [hero, this](const CSpell *s) -> bool
-	{
-		return s->canBeCast(cb.get(), spells::Mode::HERO, hero);
-	});
-	LOGFL("I can cast %d spells.", possibleSpells.size());
-
-	vstd::erase_if(possibleSpells, [](const CSpell *s)
-	{
-		return spellType(s) != SpellTypes::BATTLE || s->getTargetType() == spells::AimType::LOCATION;
-	});
-
-	LOGFL("I know how %d of them works.", possibleSpells.size());
-
-	//Get possible spell-target pairs
-	std::vector<PossibleSpellcast> possibleCasts;
-	for(auto spell : possibleSpells)
-	{
-		spells::BattleCast temp(cb.get(), hero, spells::Mode::HERO, spell);
-
-		if(spell->getTargetType() == spells::AimType::LOCATION)
-			continue;
-		
-		const bool FAST = true;
-
-		for(auto & target : temp.findPotentialTargets(FAST))
-		{
-			PossibleSpellcast ps;
-			ps.dest = target;
-			ps.spell = spell;
-			possibleCasts.push_back(ps);
-		}
-	}
-	LOGFL("Found %d spell-target combinations.", possibleCasts.size());
-	if(possibleCasts.empty())
-		return false;
-
-	using ValueMap = PossibleSpellcast::ValueMap;
-
-	auto evaluateQueue = [&](ValueMap & values, const std::vector<battle::Units> & queue, std::shared_ptr<HypotheticBattle> state, size_t minTurnSpan, bool * enemyHadTurnOut) -> bool
-	{
-		bool firstRound = true;
-		bool enemyHadTurn = false;
-		size_t ourTurnSpan = 0;
-
-		bool stop = false;
-
-		for(auto & round : queue)
-		{
-			if(!firstRound)
-				state->nextRound(0);//todo: set actual value?
-			for(auto unit : round)
-			{
-				if(!vstd::contains(values, unit->unitId()))
-					values[unit->unitId()] = 0;
-
-				if(!unit->alive())
-					continue;
-
-				if(state->battleGetOwner(unit) != playerID)
-				{
-					enemyHadTurn = true;
-
-					if(!firstRound || state->battleCastSpells(unit->unitSide()) == 0)
-					{
-						//enemy could counter our spell at this point
-						//anyway, we do not know what enemy will do
-						//just stop evaluation
-						stop = true;
-						break;
-					}
-				}
-				else if(!enemyHadTurn)
-				{
-					ourTurnSpan++;
-				}
-
-				state->nextTurn(unit->unitId());
-
-				PotentialTargets pt(unit, damageCache, state);
-
-				if(!pt.possibleAttacks.empty())
-				{
-					AttackPossibility ap = pt.bestAction();
-
-					auto swb = state->getForUpdate(unit->unitId());
-					*swb = *ap.attackerState;
-
-					if(ap.defenderDamageReduce > 0)
-						swb->removeUnitBonus(Bonus::UntilAttack);
-					if(ap.attackerDamageReduce > 0)
-						swb->removeUnitBonus(Bonus::UntilBeingAttacked);
-
-					for(auto affected : ap.affectedUnits)
-					{
-						swb = state->getForUpdate(affected->unitId());
-						*swb = *affected;
-
-						if(ap.defenderDamageReduce > 0)
-							swb->removeUnitBonus(Bonus::UntilBeingAttacked);
-						if(ap.attackerDamageReduce > 0 && ap.attack.defender->unitId() == affected->unitId())
-							swb->removeUnitBonus(Bonus::UntilAttack);
-					}
-				}
-
-				auto bav = pt.bestActionValue();
-
-				//best action is from effective owner`s point if view, we need to convert to our point if view
-				if(state->battleGetOwner(unit) != playerID)
-					bav = -bav;
-				values[unit->unitId()] += bav;
-			}
-
-			firstRound = false;
-
-			if(stop)
-				break;
-		}
-
-		if(enemyHadTurnOut)
-			*enemyHadTurnOut = enemyHadTurn;
-
-		return ourTurnSpan >= minTurnSpan;
-	};
-
-	ValueMap valueOfStack;
-	ValueMap healthOfStack;
-
-	TStacks all = cb->battleGetAllStacks(false);
-
-	size_t ourRemainingTurns = 0;
-
-	for(auto unit : all)
-	{
-		healthOfStack[unit->unitId()] = unit->getAvailableHealth();
-		valueOfStack[unit->unitId()] = 0;
-
-		if(cb->battleGetOwner(unit) == playerID && unit->canMove() && !unit->moved())
-			ourRemainingTurns++;
-	}
-
-	LOGFL("I have %d turns left in this round", ourRemainingTurns);
-
-	const bool castNow = ourRemainingTurns <= 1;
-
-	if(castNow)
-		print("I should try to cast a spell now");
-	else
-		print("I could wait better moment to cast a spell");
-
-	auto amount = all.size();
-
-	std::vector<battle::Units> turnOrder;
-
-	cb->battleGetTurnOrder(turnOrder, amount, 2); //no more than 1 turn after current, each unit at least once
-
-	{
-		bool enemyHadTurn = false;
-
-		auto state = std::make_shared<HypotheticBattle>(env.get(), cb);
-
-		evaluateQueue(valueOfStack, turnOrder, state, 0, &enemyHadTurn);
-
-		if(!enemyHadTurn)
-		{
-			auto battleIsFinishedOpt = state->battleIsFinished();
-
-			if(battleIsFinishedOpt)
-			{
-				print("No need to cast a spell. Battle will finish soon.");
-				return false;
-			}
-		}
-	}
-
-	CStopWatch timer;
-
-	tbb::parallel_for(tbb::blocked_range<size_t>(0, possibleCasts.size()), [&](const tbb::blocked_range<size_t> & r)
-		{
-			for(auto i = r.begin(); i != r.end(); i++)
-			{
-				auto & ps = possibleCasts[i];
-				auto state = std::make_shared<HypotheticBattle>(env.get(), cb);
-
-				spells::BattleCast cast(state.get(), hero, spells::Mode::HERO, ps.spell);
-				cast.castEval(state->getServerCallback(), ps.dest);
-
-				auto allUnits = state->battleGetUnitsIf([](const battle::Unit * u) -> bool { return true; });
-
-				auto needFullEval = vstd::contains_if(allUnits, [&](const battle::Unit * u) -> bool
-					{
-						auto original = cb->battleGetUnitByID(u->unitId());
-						return  !original || u->speed() != original->speed();
-					});
-
-				DamageCache innerCache(&damageCache);
-				innerCache.buildDamageCache(state, side);
-
-				if(needFullEval || !cachedAttack)
-				{
-					PotentialTargets innerTargets(activeStack, damageCache, state);
-					BattleExchangeEvaluator innerEvaluator(state, env);
-
-					if(!innerTargets.possibleAttacks.empty())
-					{
-						innerEvaluator.updateReachabilityMap(state);
-
-						auto newStackAction = innerEvaluator.findBestTarget(activeStack, innerTargets, innerCache, state);
-
-						ps.value = newStackAction.score;
-					}
-					else
-					{
-						ps.value = 0;
-					}
-				}
-				else
-				{
-					ps.value = scoreEvaluator.calculateExchange(*cachedAttack, *targets, innerCache, state);
-				}
-
-				for(auto unit : allUnits)
-				{
-					auto newHealth = unit->getAvailableHealth();
-					auto oldHealth = healthOfStack[unit->unitId()];
-
-					if(oldHealth != newHealth)
-					{
-						auto damage = std::abs(oldHealth - newHealth);
-						auto originalDefender = cb->battleGetUnitByID(unit->unitId());
-						auto dpsReduce = AttackPossibility::calculateDamageReduce(nullptr, originalDefender ? originalDefender : unit, damage, innerCache, state);
-						auto ourUnit = unit->unitSide() == side ? 1 : -1;
-						auto goodEffect = newHealth > oldHealth ? 1 : -1;
-
-						ps.value += ourUnit * goodEffect * dpsReduce;
-					}
-				}
-			}
-		});
-
-	LOGFL("Evaluation took %d ms", timer.getDiff());
-
-	auto pscValue = [](const PossibleSpellcast &ps) -> int64_t
-	{
-		return ps.value;
-	};
-	auto castToPerform = *vstd::maxElementByFun(possibleCasts, pscValue);
-
-	if(castToPerform.value > cachedScore)
-	{
-		LOGFL("Best spell is %s (value %d). Will cast.", castToPerform.spell->getNameTranslated() % castToPerform.value);
-		BattleAction spellcast;
-		spellcast.actionType = EActionType::HERO_SPELL;
-		spellcast.spell = castToPerform.spell->getId();
-		spellcast.setTarget(castToPerform.dest);
-		spellcast.side = side;
-		spellcast.stackNumber = (!side) ? -1 : -2;
-		cb->battleMakeSpellAction(spellcast);
-<<<<<<< HEAD
-		movesSkippedByDefense = 0;
-		return true;
-=======
-		activeActionMade = true;
->>>>>>> ea22737e9 (BattleAI: damage cache and switch to different model of spells evaluation)
-	}
-	else
-	{
-		LOGFL("Best spell is %s. But it is actually useless (value %d).", castToPerform.spell->getNameTranslated() % castToPerform.value);
-		return false;
-	}
-}
-
-//Below method works only for offensive spells
-void BattleEvaluator::evaluateCreatureSpellcast(const CStack * stack, PossibleSpellcast & ps)
-{
-	using ValueMap = PossibleSpellcast::ValueMap;
-
-	RNGStub rngStub;
-	HypotheticBattle state(env.get(), cb);
-	TStacks all = cb->battleGetAllStacks(false);
-
-	ValueMap healthOfStack;
-	ValueMap newHealthOfStack;
-
-	for(auto unit : all)
-	{
-		healthOfStack[unit->unitId()] = unit->getAvailableHealth();
-	}
-
-	spells::BattleCast cast(&state, stack, spells::Mode::CREATURE_ACTIVE, ps.spell);
-	cast.castEval(state.getServerCallback(), ps.dest);
-
-	for(auto unit : all)
-	{
-		auto unitId = unit->unitId();
-		auto localUnit = state.battleGetUnitByID(unitId);
-		newHealthOfStack[unitId] = localUnit->getAvailableHealth();
-	}
-
-	int64_t totalGain = 0;
-
-	for(auto unit : all)
-	{
-		auto unitId = unit->unitId();
-		auto localUnit = state.battleGetUnitByID(unitId);
-
-		auto healthDiff = newHealthOfStack[unitId] - healthOfStack[unitId];
-
-		if(localUnit->unitOwner() != getCbc()->getPlayerID())
-			healthDiff = -healthDiff;
-
-		if(healthDiff < 0)
-		{
-			ps.value = -1;
-			return; //do not damage own units at all
-		}
-
-		totalGain += healthDiff;
-	}
-
-	ps.value = totalGain;
-}
-
 void CBattleAI::battleStart(const CCreatureSet *army1, const CCreatureSet *army2, int3 tile, const CGHeroInstance *hero1, const CGHeroInstance *hero2, bool Side, bool replayAllowed)
 {
 	LOG_TRACE(logAi);
 	side = Side;
-}
+	strengthRatio = static_cast<float>(army1->getArmyStrength()) / static_cast<float>(army2->getArmyStrength());
 
-void CBattleAI::print(const std::string &text) const
-{
-	logAi->trace("%s Battle AI[%p]: %s", playerID.getStr(), this, text);
+	if(side == 1)
+		strengthRatio = 1 / strengthRatio;
+
+	skipCastUntilNextBattle = false;
 }
 
-void BattleEvaluator::print(const std::string & text) const
+void CBattleAI::print(const std::string &text) const
 {
 	logAi->trace("%s Battle AI[%p]: %s", playerID.getStr(), this, text);
 }

+ 2 - 0
AI/BattleAI/BattleAI.h

@@ -62,6 +62,8 @@ class CBattleAI : public CBattleGameInterface
 	bool wasWaitingForRealize;
 	bool wasUnlockingGs;
 	int movesSkippedByDefense;
+	float strengthRatio;
+	bool skipCastUntilNextBattle;
 
 public:
 	CBattleAI();

+ 679 - 0
AI/BattleAI/BattleEvaluator.cpp

@@ -0,0 +1,679 @@
+/*
+ * BattleAI.cpp, part of VCMI engine
+ *
+ * Authors: listed in file AUTHORS in main folder
+ *
+ * License: GNU General Public License v2.0 or later
+ * Full text of license available in license.txt file, in main folder
+ *
+ */
+#include "StdInc.h"
+#include "BattleEvaluator.h"
+#include "BattleExchangeVariant.h"
+
+#include "StackWithBonuses.h"
+#include "EnemyInfo.h"
+#include "tbb/parallel_for.h"
+#include "../../lib/CStopWatch.h"
+#include "../../lib/CThreadHelper.h"
+#include "../../lib/mapObjects/CGTownInstance.h"
+#include "../../lib/spells/CSpellHandler.h"
+#include "../../lib/spells/ISpellMechanics.h"
+#include "../../lib/battle/BattleStateInfoForRetreat.h"
+#include "../../lib/battle/CObstacleInstance.h"
+#include "../../lib/battle/BattleAction.h"
+
+// TODO: remove
+// Eventually only IBattleInfoCallback and battle::Unit should be used,
+// CUnitState should be private and CStack should be removed completely
+#include "../../lib/CStack.h"
+
+#define LOGL(text) print(text)
+#define LOGFL(text, formattingEl) print(boost::str(boost::format(text) % formattingEl))
+
+enum class SpellTypes
+{
+	ADVENTURE, BATTLE, OTHER
+};
+
+SpellTypes spellType(const CSpell * spell)
+{
+	if(!spell->isCombat() || spell->isCreatureAbility())
+		return SpellTypes::OTHER;
+
+	if(spell->isOffensive() || spell->hasEffects() || spell->hasBattleEffects())
+		return SpellTypes::BATTLE;
+
+	return SpellTypes::OTHER;
+}
+
+std::vector<BattleHex> BattleEvaluator::getBrokenWallMoatHexes() const
+{
+	std::vector<BattleHex> result;
+
+	for(EWallPart wallPart : { EWallPart::BOTTOM_WALL, EWallPart::BELOW_GATE, EWallPart::OVER_GATE, EWallPart::UPPER_WALL })
+	{
+		auto state = cb->battleGetWallState(wallPart);
+
+		if(state != EWallState::DESTROYED)
+			continue;
+
+		auto wallHex = cb->wallPartToBattleHex((EWallPart)wallPart);
+		auto moatHex = wallHex.cloneInDirection(BattleHex::LEFT);
+
+		result.push_back(moatHex);
+	}
+
+	return result;
+}
+
+std::optional<PossibleSpellcast> BattleEvaluator::findBestCreatureSpell(const CStack *stack)
+{
+	//TODO: faerie dragon type spell should be selected by server
+	SpellID creatureSpellToCast = cb->battleGetRandomStackSpell(CRandomGenerator::getDefault(), stack, CBattleInfoCallback::RANDOM_AIMED);
+	if(stack->hasBonusOfType(BonusType::SPELLCASTER) && stack->canCast() && creatureSpellToCast != SpellID::NONE)
+	{
+		const CSpell * spell = creatureSpellToCast.toSpell();
+
+		if(spell->canBeCast(getCbc().get(), spells::Mode::CREATURE_ACTIVE, stack))
+		{
+			std::vector<PossibleSpellcast> possibleCasts;
+			spells::BattleCast temp(getCbc().get(), stack, spells::Mode::CREATURE_ACTIVE, spell);
+			for(auto & target : temp.findPotentialTargets())
+			{
+				PossibleSpellcast ps;
+				ps.dest = target;
+				ps.spell = spell;
+				evaluateCreatureSpellcast(stack, ps);
+				possibleCasts.push_back(ps);
+			}
+
+			std::sort(possibleCasts.begin(), possibleCasts.end(), [&](const PossibleSpellcast & lhs, const PossibleSpellcast & rhs) { return lhs.value > rhs.value; });
+			if(!possibleCasts.empty() && possibleCasts.front().value > 0)
+			{
+				return possibleCasts.front();
+			}
+		}
+	}
+	return std::nullopt;
+}
+
+BattleAction BattleEvaluator::selectStackAction(const CStack * stack)
+{
+	//evaluate casting spell for spellcasting stack
+	std::optional<PossibleSpellcast> bestSpellcast = findBestCreatureSpell(stack);
+
+	auto moveTarget = scoreEvaluator.findMoveTowardsUnreachable(stack, *targets, damageCache, hb);
+	auto score = EvaluationResult::INEFFECTIVE_SCORE;
+
+	if(targets->possibleAttacks.empty() && bestSpellcast.has_value())
+	{
+		activeActionMade = true;
+		return BattleAction::makeCreatureSpellcast(stack, bestSpellcast->dest, bestSpellcast->spell->id);
+	}
+
+	if(!targets->possibleAttacks.empty())
+	{
+#if BATTLE_TRACE_LEVEL>=1
+		logAi->trace("Evaluating attack for %s", stack->getDescription());
+#endif
+
+		auto evaluationResult = scoreEvaluator.findBestTarget(stack, *targets, damageCache, hb);
+		auto & bestAttack = evaluationResult.bestAttack;
+
+		cachedAttack = bestAttack;
+		cachedScore = evaluationResult.score;
+
+		//TODO: consider more complex spellcast evaluation, f.e. because "re-retaliation" during enemy move in same turn for melee attack etc.
+		if(bestSpellcast.has_value() && bestSpellcast->value > bestAttack.damageDiff())
+		{
+			// return because spellcast value is damage dealt and score is dps reduce
+			activeActionMade = true;
+			return BattleAction::makeCreatureSpellcast(stack, bestSpellcast->dest, bestSpellcast->spell->id);
+		}
+
+		if(evaluationResult.score > score)
+		{
+			score = evaluationResult.score;
+
+			logAi->debug("BattleAI: %s -> %s x %d, from %d curpos %d dist %d speed %d: +%lld -%lld = %lld",
+				bestAttack.attackerState->unitType()->getJsonKey(),
+				bestAttack.affectedUnits[0]->unitType()->getJsonKey(),
+				(int)bestAttack.affectedUnits[0]->getCount(),
+				(int)bestAttack.from,
+				(int)bestAttack.attack.attacker->getPosition().hex,
+				bestAttack.attack.chargeDistance,
+				bestAttack.attack.attacker->speed(0, true),
+				bestAttack.defenderDamageReduce,
+				bestAttack.attackerDamageReduce, bestAttack.attackValue()
+			);
+
+			if (moveTarget.scorePerTurn <= score)
+			{
+				if(evaluationResult.wait)
+				{
+					return BattleAction::makeWait(stack);
+				}
+				else if(bestAttack.attack.shooting)
+				{
+					activeActionMade = true;
+					return BattleAction::makeShotAttack(stack, bestAttack.attack.defender);
+				}
+				else
+				{
+					if(bestAttack.collateralDamageReduce
+						&& bestAttack.collateralDamageReduce >= bestAttack.defenderDamageReduce / 2
+						&& score < 0)
+					{
+						return BattleAction::makeDefend(stack);
+					}
+					else
+					{
+						activeActionMade = true;
+						return BattleAction::makeMeleeAttack(stack, bestAttack.attack.defender->getPosition(), bestAttack.from);
+					}
+				}
+			}
+		}
+	}
+
+	//ThreatMap threatsToUs(stack); // These lines may be usefull but they are't used in the code.
+	if(moveTarget.scorePerTurn > score)
+	{
+		score = moveTarget.score;
+		cachedAttack = moveTarget.cachedAttack;
+		cachedScore = score;
+
+		if(stack->waited())
+		{
+			return goTowardsNearest(stack, moveTarget.positions);
+		}
+		else
+		{
+			return BattleAction::makeWait(stack);
+		}
+	}
+
+	if(score <= EvaluationResult::INEFFECTIVE_SCORE
+		&& !stack->hasBonusOfType(BonusType::FLYING)
+		&& stack->unitSide() == BattleSide::ATTACKER
+		&& cb->battleGetSiegeLevel() >= CGTownInstance::CITADEL)
+	{
+		auto brokenWallMoat = getBrokenWallMoatHexes();
+
+		if(brokenWallMoat.size())
+		{
+			activeActionMade = true;
+
+			if(stack->doubleWide() && vstd::contains(brokenWallMoat, stack->getPosition()))
+				return BattleAction::makeMove(stack, stack->getPosition().cloneInDirection(BattleHex::RIGHT));
+			else
+				return goTowardsNearest(stack, brokenWallMoat);
+		}
+	}
+
+	return BattleAction::makeDefend(stack);
+}
+
+uint64_t timeElapsed(std::chrono::time_point<std::chrono::high_resolution_clock> start)
+{
+	auto end = std::chrono::high_resolution_clock::now();
+
+	return std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
+}
+
+BattleAction BattleEvaluator::goTowardsNearest(const CStack * stack, std::vector<BattleHex> hexes)
+{
+	auto reachability = cb->getReachability(stack);
+	auto avHexes = cb->battleGetAvailableHexes(reachability, stack, false);
+
+	if(!avHexes.size() || !hexes.size()) //we are blocked or dest is blocked
+	{
+		return BattleAction::makeDefend(stack);
+	}
+
+	std::sort(hexes.begin(), hexes.end(), [&](BattleHex h1, BattleHex h2) -> bool
+	{
+		return reachability.distances[h1] < reachability.distances[h2];
+	});
+
+	for(auto hex : hexes)
+	{
+		if(vstd::contains(avHexes, hex))
+		{
+			return BattleAction::makeMove(stack, hex);
+		}
+
+		if(stack->coversPos(hex))
+		{
+			logAi->warn("Warning: already standing on neighbouring tile!");
+			//We shouldn't even be here...
+			return BattleAction::makeDefend(stack);
+		}
+	}
+
+	BattleHex bestNeighbor = hexes.front();
+
+	if(reachability.distances[bestNeighbor] > GameConstants::BFIELD_SIZE)
+	{
+		return BattleAction::makeDefend(stack);
+	}
+
+	scoreEvaluator.updateReachabilityMap(hb);
+
+	if(stack->hasBonusOfType(BonusType::FLYING))
+	{
+		std::set<BattleHex> obstacleHexes;
+
+		auto insertAffected = [](const CObstacleInstance & spellObst, std::set<BattleHex> obstacleHexes) {
+			auto affectedHexes = spellObst.getAffectedTiles();
+			obstacleHexes.insert(affectedHexes.cbegin(), affectedHexes.cend());
+		};
+
+		const auto & obstacles = hb->battleGetAllObstacles();
+
+		for (const auto & obst: obstacles) {
+
+			if(obst->triggersEffects())
+			{
+				auto triggerAbility =  VLC->spells()->getById(obst->getTrigger());
+				auto triggerIsNegative = triggerAbility->isNegative() || triggerAbility->isDamage();
+
+				if(triggerIsNegative)
+					insertAffected(*obst, obstacleHexes);
+			}
+		}
+		// Flying stack doesn't go hex by hex, so we can't backtrack using predecessors.
+		// We just check all available hexes and pick the one closest to the target.
+		auto nearestAvailableHex = vstd::minElementByFun(avHexes, [&](BattleHex hex) -> int
+		{
+			const int NEGATIVE_OBSTACLE_PENALTY = 100; // avoid landing on negative obstacle (moat, fire wall, etc)
+			const int BLOCKED_STACK_PENALTY = 100; // avoid landing on moat
+
+			auto distance = BattleHex::getDistance(bestNeighbor, hex);
+
+			if(vstd::contains(obstacleHexes, hex))
+				distance += NEGATIVE_OBSTACLE_PENALTY;
+
+			return scoreEvaluator.checkPositionBlocksOurStacks(*hb, stack, hex) ? BLOCKED_STACK_PENALTY + distance : distance;
+		});
+
+		return BattleAction::makeMove(stack, *nearestAvailableHex);
+	}
+	else
+	{
+		BattleHex currentDest = bestNeighbor;
+		while(1)
+		{
+			if(!currentDest.isValid())
+			{
+				return BattleAction::makeDefend(stack);
+			}
+
+			if(vstd::contains(avHexes, currentDest)
+				&& !scoreEvaluator.checkPositionBlocksOurStacks(*hb, stack, currentDest))
+				return BattleAction::makeMove(stack, currentDest);
+
+			currentDest = reachability.predecessors[currentDest];
+		}
+	}
+}
+
+bool BattleEvaluator::canCastSpell()
+{
+	auto hero = cb->battleGetMyHero();
+	if(!hero)
+		return false;
+
+	return cb->battleCanCastSpell(hero, spells::Mode::HERO) == ESpellCastProblem::OK;
+}
+
+bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
+{
+	auto hero = cb->battleGetMyHero();
+	if(!hero)
+		return false;
+
+	LOGL("Casting spells sounds like fun. Let's see...");
+	//Get all spells we can cast
+	std::vector<const CSpell*> possibleSpells;
+	vstd::copy_if(VLC->spellh->objects, std::back_inserter(possibleSpells), [hero, this](const CSpell *s) -> bool
+	{
+		return s->canBeCast(cb.get(), spells::Mode::HERO, hero);
+	});
+	LOGFL("I can cast %d spells.", possibleSpells.size());
+
+	vstd::erase_if(possibleSpells, [](const CSpell *s)
+	{
+		return spellType(s) != SpellTypes::BATTLE || s->getTargetType() == spells::AimType::LOCATION;
+	});
+
+	LOGFL("I know how %d of them works.", possibleSpells.size());
+
+	//Get possible spell-target pairs
+	std::vector<PossibleSpellcast> possibleCasts;
+	for(auto spell : possibleSpells)
+	{
+		spells::BattleCast temp(cb.get(), hero, spells::Mode::HERO, spell);
+
+		if(spell->getTargetType() == spells::AimType::LOCATION)
+			continue;
+		
+		const bool FAST = true;
+
+		for(auto & target : temp.findPotentialTargets(FAST))
+		{
+			PossibleSpellcast ps;
+			ps.dest = target;
+			ps.spell = spell;
+			possibleCasts.push_back(ps);
+		}
+	}
+	LOGFL("Found %d spell-target combinations.", possibleCasts.size());
+	if(possibleCasts.empty())
+		return false;
+
+	using ValueMap = PossibleSpellcast::ValueMap;
+
+	auto evaluateQueue = [&](ValueMap & values, const std::vector<battle::Units> & queue, std::shared_ptr<HypotheticBattle> state, size_t minTurnSpan, bool * enemyHadTurnOut) -> bool
+	{
+		bool firstRound = true;
+		bool enemyHadTurn = false;
+		size_t ourTurnSpan = 0;
+
+		bool stop = false;
+
+		for(auto & round : queue)
+		{
+			if(!firstRound)
+				state->nextRound(0);//todo: set actual value?
+			for(auto unit : round)
+			{
+				if(!vstd::contains(values, unit->unitId()))
+					values[unit->unitId()] = 0;
+
+				if(!unit->alive())
+					continue;
+
+				if(state->battleGetOwner(unit) != playerID)
+				{
+					enemyHadTurn = true;
+
+					if(!firstRound || state->battleCastSpells(unit->unitSide()) == 0)
+					{
+						//enemy could counter our spell at this point
+						//anyway, we do not know what enemy will do
+						//just stop evaluation
+						stop = true;
+						break;
+					}
+				}
+				else if(!enemyHadTurn)
+				{
+					ourTurnSpan++;
+				}
+
+				state->nextTurn(unit->unitId());
+
+				PotentialTargets pt(unit, damageCache, state);
+
+				if(!pt.possibleAttacks.empty())
+				{
+					AttackPossibility ap = pt.bestAction();
+
+					auto swb = state->getForUpdate(unit->unitId());
+					*swb = *ap.attackerState;
+
+					if(ap.defenderDamageReduce > 0)
+						swb->removeUnitBonus(Bonus::UntilAttack);
+					if(ap.attackerDamageReduce > 0)
+						swb->removeUnitBonus(Bonus::UntilBeingAttacked);
+
+					for(auto affected : ap.affectedUnits)
+					{
+						swb = state->getForUpdate(affected->unitId());
+						*swb = *affected;
+
+						if(ap.defenderDamageReduce > 0)
+							swb->removeUnitBonus(Bonus::UntilBeingAttacked);
+						if(ap.attackerDamageReduce > 0 && ap.attack.defender->unitId() == affected->unitId())
+							swb->removeUnitBonus(Bonus::UntilAttack);
+					}
+				}
+
+				auto bav = pt.bestActionValue();
+
+				//best action is from effective owner`s point if view, we need to convert to our point if view
+				if(state->battleGetOwner(unit) != playerID)
+					bav = -bav;
+				values[unit->unitId()] += bav;
+			}
+
+			firstRound = false;
+
+			if(stop)
+				break;
+		}
+
+		if(enemyHadTurnOut)
+			*enemyHadTurnOut = enemyHadTurn;
+
+		return ourTurnSpan >= minTurnSpan;
+	};
+
+	ValueMap valueOfStack;
+	ValueMap healthOfStack;
+
+	TStacks all = cb->battleGetAllStacks(false);
+
+	size_t ourRemainingTurns = 0;
+
+	for(auto unit : all)
+	{
+		healthOfStack[unit->unitId()] = unit->getAvailableHealth();
+		valueOfStack[unit->unitId()] = 0;
+
+		if(cb->battleGetOwner(unit) == playerID && unit->canMove() && !unit->moved())
+			ourRemainingTurns++;
+	}
+
+	LOGFL("I have %d turns left in this round", ourRemainingTurns);
+
+	const bool castNow = ourRemainingTurns <= 1;
+
+	if(castNow)
+		print("I should try to cast a spell now");
+	else
+		print("I could wait better moment to cast a spell");
+
+	auto amount = all.size();
+
+	std::vector<battle::Units> turnOrder;
+
+	cb->battleGetTurnOrder(turnOrder, amount, 2); //no more than 1 turn after current, each unit at least once
+
+	{
+		bool enemyHadTurn = false;
+
+		auto state = std::make_shared<HypotheticBattle>(env.get(), cb);
+
+		evaluateQueue(valueOfStack, turnOrder, state, 0, &enemyHadTurn);
+
+		if(!enemyHadTurn)
+		{
+			auto battleIsFinishedOpt = state->battleIsFinished();
+
+			if(battleIsFinishedOpt)
+			{
+				print("No need to cast a spell. Battle will finish soon.");
+				return false;
+			}
+		}
+	}
+
+	CStopWatch timer;
+
+	tbb::parallel_for(tbb::blocked_range<size_t>(0, possibleCasts.size()), [&](const tbb::blocked_range<size_t> & r)
+		{
+			for(auto i = r.begin(); i != r.end(); i++)
+			{
+				auto & ps = possibleCasts[i];
+				auto state = std::make_shared<HypotheticBattle>(env.get(), cb);
+
+				spells::BattleCast cast(state.get(), hero, spells::Mode::HERO, ps.spell);
+				cast.castEval(state->getServerCallback(), ps.dest);
+
+				auto allUnits = state->battleGetUnitsIf([](const battle::Unit * u) -> bool { return true; });
+
+				auto needFullEval = vstd::contains_if(allUnits, [&](const battle::Unit * u) -> bool
+					{
+						auto original = cb->battleGetUnitByID(u->unitId());
+						return  !original || u->speed() != original->speed();
+					});
+
+				DamageCache innerCache(&damageCache);
+				innerCache.buildDamageCache(state, side);
+
+				if(needFullEval || !cachedAttack)
+				{
+					PotentialTargets innerTargets(activeStack, damageCache, state);
+					BattleExchangeEvaluator innerEvaluator(state, env, strengthRatio);
+
+					if(!innerTargets.possibleAttacks.empty())
+					{
+						innerEvaluator.updateReachabilityMap(state);
+
+						auto newStackAction = innerEvaluator.findBestTarget(activeStack, innerTargets, innerCache, state);
+
+						ps.value = newStackAction.score;
+					}
+					else
+					{
+						ps.value = 0;
+					}
+				}
+				else
+				{
+					ps.value = scoreEvaluator.calculateExchange(*cachedAttack, *targets, innerCache, state);
+				}
+
+				for(auto unit : allUnits)
+				{
+					auto newHealth = unit->getAvailableHealth();
+					auto oldHealth = healthOfStack[unit->unitId()];
+
+					if(oldHealth != newHealth)
+					{
+						auto damage = std::abs(oldHealth - newHealth);
+						auto originalDefender = cb->battleGetUnitByID(unit->unitId());
+
+						auto dpsReduce = AttackPossibility::calculateDamageReduce(
+							nullptr,
+							originalDefender &&  originalDefender->alive() ? originalDefender : unit,
+							damage,
+							innerCache,
+							state);
+
+						auto ourUnit = unit->unitSide() == side ? 1 : -1;
+						auto goodEffect = newHealth > oldHealth ? 1 : -1;
+
+						if(ourUnit * goodEffect == 1)
+						{
+							if(ourUnit && goodEffect && (unit->isClone() || unit->isGhost() || !unit->unitSlot().validSlot()))
+								continue;
+
+							ps.value += dpsReduce * scoreEvaluator.getPositiveEffectMultiplier();
+						}
+						else
+							ps.value -= dpsReduce * scoreEvaluator.getNegativeEffectMultiplier();
+					}
+				}
+			}
+		});
+
+	LOGFL("Evaluation took %d ms", timer.getDiff());
+
+	auto pscValue = [](const PossibleSpellcast &ps) -> int64_t
+	{
+		return ps.value;
+	};
+	auto castToPerform = *vstd::maxElementByFun(possibleCasts, pscValue);
+
+	if(castToPerform.value > cachedScore)
+	{
+		LOGFL("Best spell is %s (value %d). Will cast.", castToPerform.spell->getNameTranslated() % castToPerform.value);
+		BattleAction spellcast;
+		spellcast.actionType = EActionType::HERO_SPELL;
+		spellcast.spell = castToPerform.spell->id;
+		spellcast.setTarget(castToPerform.dest);
+		spellcast.side = side;
+		spellcast.stackNumber = (!side) ? -1 : -2;
+		cb->battleMakeSpellAction(spellcast);
+		activeActionMade = true;
+
+		return true;
+	}
+
+	LOGFL("Best spell is %s. But it is actually useless (value %d).", castToPerform.spell->getNameTranslated() % castToPerform.value);
+
+	return false;
+}
+
+//Below method works only for offensive spells
+void BattleEvaluator::evaluateCreatureSpellcast(const CStack * stack, PossibleSpellcast & ps)
+{
+	using ValueMap = PossibleSpellcast::ValueMap;
+
+	RNGStub rngStub;
+	HypotheticBattle state(env.get(), cb);
+	TStacks all = cb->battleGetAllStacks(false);
+
+	ValueMap healthOfStack;
+	ValueMap newHealthOfStack;
+
+	for(auto unit : all)
+	{
+		healthOfStack[unit->unitId()] = unit->getAvailableHealth();
+	}
+
+	spells::BattleCast cast(&state, stack, spells::Mode::CREATURE_ACTIVE, ps.spell);
+	cast.castEval(state.getServerCallback(), ps.dest);
+
+	for(auto unit : all)
+	{
+		auto unitId = unit->unitId();
+		auto localUnit = state.battleGetUnitByID(unitId);
+		newHealthOfStack[unitId] = localUnit->getAvailableHealth();
+	}
+
+	int64_t totalGain = 0;
+
+	for(auto unit : all)
+	{
+		auto unitId = unit->unitId();
+		auto localUnit = state.battleGetUnitByID(unitId);
+
+		auto healthDiff = newHealthOfStack[unitId] - healthOfStack[unitId];
+
+		if(localUnit->unitOwner() != getCbc()->getPlayerID())
+			healthDiff = -healthDiff;
+
+		if(healthDiff < 0)
+		{
+			ps.value = -1;
+			return; //do not damage own units at all
+		}
+
+		totalGain += healthDiff;
+	}
+
+	ps.value = totalGain;
+}
+
+void BattleEvaluator::print(const std::string & text) const
+{
+	logAi->trace("%s Battle AI[%p]: %s", playerID.getStr(), this, text);
+}
+
+
+

+ 80 - 0
AI/BattleAI/BattleEvaluator.h

@@ -0,0 +1,80 @@
+/*
+ * BattleEvaluator.h, part of VCMI engine
+ *
+ * Authors: listed in file AUTHORS in main folder
+ *
+ * License: GNU General Public License v2.0 or later
+ * Full text of license available in license.txt file, in main folder
+ *
+ */
+#pragma once
+#include "../../lib/AI_Base.h"
+#include "../../lib/battle/ReachabilityInfo.h"
+#include "PossibleSpellcast.h"
+#include "PotentialTargets.h"
+#include "BattleExchangeVariant.h"
+
+VCMI_LIB_NAMESPACE_BEGIN
+
+class CSpell;
+
+VCMI_LIB_NAMESPACE_END
+
+class EnemyInfo;
+
+class BattleEvaluator
+{
+	std::unique_ptr<PotentialTargets> targets;
+	std::shared_ptr<HypotheticBattle> hb;
+	BattleExchangeEvaluator scoreEvaluator;
+	std::shared_ptr<CBattleCallback> cb;
+	std::shared_ptr<Environment> env;
+	bool activeActionMade = false;
+	std::optional<AttackPossibility> cachedAttack;
+	PlayerColor playerID;
+	int side;
+	int64_t cachedScore;
+	DamageCache damageCache;
+	float strengthRatio;
+
+public:
+	BattleAction selectStackAction(const CStack * stack);
+	bool attemptCastingSpell(const CStack * stack);
+	bool canCastSpell();
+	std::optional<PossibleSpellcast> findBestCreatureSpell(const CStack * stack);
+	BattleAction goTowardsNearest(const CStack * stack, std::vector<BattleHex> hexes);
+	std::vector<BattleHex> getBrokenWallMoatHexes() const;
+	void evaluateCreatureSpellcast(const CStack * stack, PossibleSpellcast & ps); //for offensive damaging spells only
+	void print(const std::string & text) const;
+
+	BattleEvaluator(
+		std::shared_ptr<Environment> env,
+		std::shared_ptr<CBattleCallback> cb,
+		const battle::Unit * activeStack,
+		PlayerColor playerID,
+		int side,
+		float strengthRatio)
+		:scoreEvaluator(cb, env, strengthRatio), cachedAttack(), playerID(playerID), side(side), env(env), cb(cb), strengthRatio(strengthRatio)
+	{
+		hb = std::make_shared<HypotheticBattle>(env.get(), cb);
+		damageCache.buildDamageCache(hb, side);
+
+		targets = std::make_unique<PotentialTargets>(activeStack, damageCache, hb);
+		cachedScore = EvaluationResult::INEFFECTIVE_SCORE;
+	}
+
+	BattleEvaluator(
+		std::shared_ptr<Environment> env,
+		std::shared_ptr<CBattleCallback> cb,
+		std::shared_ptr<HypotheticBattle> hb,
+		DamageCache & damageCache,
+		const battle::Unit * activeStack,
+		PlayerColor playerID,
+		int side,
+		float strengthRatio)
+		:scoreEvaluator(cb, env, strengthRatio), cachedAttack(), playerID(playerID), side(side), env(env), cb(cb), hb(hb), damageCache(damageCache), strengthRatio(strengthRatio)
+	{
+		targets = std::make_unique<PotentialTargets>(activeStack, damageCache, hb);
+		cachedScore = EvaluationResult::INEFFECTIVE_SCORE;
+	}
+};

+ 30 - 20
AI/BattleAI/BattleExchangeVariant.cpp

@@ -41,7 +41,7 @@ int64_t BattleExchangeVariant::trackAttack(const AttackPossibility & ap, Hypothe
 		unitToUpdate->movedThisRound = affectedUnit->movedThisRound;
 	}
 
-	auto attackValue = ap.attackValue();
+	auto attackValue = ap.damageDiff(positiveEffectMultiplier, negativeEffectMultiplier);
 
 	dpsScore += attackValue;
 
@@ -97,11 +97,11 @@ int64_t BattleExchangeVariant::trackAttack(
 
 		if(isOurAttack)
 		{
-			dpsScore += defenderDamageReduce;
+			dpsScore += defenderDamageReduce * positiveEffectMultiplier;
 			attackerValue[attacker->unitId()].value += defenderDamageReduce;
 		}
 		else
-			dpsScore -= defenderDamageReduce;
+			dpsScore -= defenderDamageReduce * negativeEffectMultiplier;
 
 		defender->damage(attackDamage);
 		attacker->afterAttack(shooting, false);
@@ -125,12 +125,12 @@ int64_t BattleExchangeVariant::trackAttack(
 
 			if(isOurAttack)
 			{
-				dpsScore -= attackerDamageReduce;
+				dpsScore -= attackerDamageReduce * negativeEffectMultiplier;
 				attackerValue[attacker->unitId()].isRetalitated = true;
 			}
 			else
 			{
-				dpsScore += attackerDamageReduce;
+				dpsScore += attackerDamageReduce * positiveEffectMultiplier;
 				attackerValue[defender->unitId()].value += attackerDamageReduce;
 			}
 
@@ -206,7 +206,7 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
 	std::shared_ptr<HypotheticBattle> hb)
 {
 	MoveTarget result;
-	BattleExchangeVariant ev;
+	BattleExchangeVariant ev(getPositiveEffectMultiplier(), getNegativeEffectMultiplier());
 
 	if(targets.unreachableEnemies.empty())
 		return result;
@@ -353,6 +353,11 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getExchangeUnits(
 		}
 	}
 
+	vstd::erase_if(exchangeUnits, [&](const battle::Unit * u) -> bool
+		{
+			return !hb->battleGetUnitByID(u->unitId())->alive();
+		});
+
 	return exchangeUnits;
 }
 
@@ -376,7 +381,8 @@ int64_t BattleExchangeEvaluator::calculateExchange(
 	std::vector<const battle::Unit *> ourStacks;
 	std::vector<const battle::Unit *> enemyStacks;
 
-	enemyStacks.push_back(ap.attack.defender);
+	if(hb->battleGetUnitByID(ap.attack.defender->unitId())->alive())
+		enemyStacks.push_back(ap.attack.defender);
 
 	std::vector<const battle::Unit *> exchangeUnits = getExchangeUnits(ap, targets, hb);
 
@@ -386,14 +392,7 @@ int64_t BattleExchangeEvaluator::calculateExchange(
 	}
 
 	auto exchangeBattle = std::make_shared<HypotheticBattle>(env.get(), hb);
-	BattleExchangeVariant v;
-	auto melleeAttackers = ourStacks;
-
-	vstd::removeDuplicates(melleeAttackers);
-	vstd::erase_if(melleeAttackers, [&](const battle::Unit * u) -> bool
-		{
-			return !cb->battleCanShoot(u);
-		});
+	BattleExchangeVariant v(getPositiveEffectMultiplier(), getNegativeEffectMultiplier());
 
 	for(auto unit : exchangeUnits)
 	{
@@ -403,12 +402,20 @@ int64_t BattleExchangeEvaluator::calculateExchange(
 		bool isOur = exchangeBattle->battleMatchOwner(ap.attack.attacker, unit, true);
 		auto & attackerQueue = isOur ? ourStacks : enemyStacks;
 
-		if(!vstd::contains(attackerQueue, unit))
+		if(exchangeBattle->getForUpdate(unit->unitId())->alive() && !vstd::contains(attackerQueue, unit))
 		{
 			attackerQueue.push_back(unit);
 		}
 	}
 
+	auto melleeAttackers = ourStacks;
+
+	vstd::removeDuplicates(melleeAttackers);
+	vstd::erase_if(melleeAttackers, [&](const battle::Unit * u) -> bool
+		{
+			return !cb->battleCanShoot(u);
+		});
+
 	bool canUseAp = true;
 
 	for(auto activeUnit : exchangeUnits)
@@ -430,7 +437,7 @@ int64_t BattleExchangeEvaluator::calculateExchange(
 
 		auto targetUnit = ap.attack.defender;
 
-		if(!isOur || !exchangeBattle->getForUpdate(targetUnit->unitId())->alive())
+		if(!isOur || !exchangeBattle->battleGetUnitByID(targetUnit->unitId())->alive())
 		{
 			auto estimateAttack = [&](const battle::Unit * u) -> int64_t
 			{
@@ -459,7 +466,10 @@ int64_t BattleExchangeEvaluator::calculateExchange(
 			{
 				auto reachable = exchangeBattle->battleGetUnitsIf([&](const battle::Unit * u) -> bool
 					{
-						if(!u->alive() || u->unitSide() == attacker->unitSide())
+						if(u->unitSide() == attacker->unitSide())
+							return false;
+
+						if(!exchangeBattle->getForUpdate(u->unitId())->alive())
 							return false;
 
 						return vstd::contains_if(reachabilityMap[u->getPosition()], [&](const battle::Unit * other) -> bool
@@ -506,12 +516,12 @@ int64_t BattleExchangeEvaluator::calculateExchange(
 
 		vstd::erase_if(attackerQueue, [&](const battle::Unit * u) -> bool
 			{
-				return !exchangeBattle->getForUpdate(u->unitId())->alive();
+				return !exchangeBattle->battleGetUnitByID(u->unitId())->alive();
 			});
 
 		vstd::erase_if(oppositeQueue, [&](const battle::Unit * u) -> bool
 			{
-				return !exchangeBattle->getForUpdate(u->unitId())->alive();
+				return !exchangeBattle->battleGetUnitByID(u->unitId())->alive();
 			});
 	}
 

+ 14 - 2
AI/BattleAI/BattleExchangeVariant.h

@@ -59,7 +59,8 @@ struct EvaluationResult
 class BattleExchangeVariant
 {
 public:
-	BattleExchangeVariant(): dpsScore(0) {}
+	BattleExchangeVariant(float positiveEffectMultiplier, float negativeEffectMultiplier)
+		: dpsScore(0), positiveEffectMultiplier(positiveEffectMultiplier), negativeEffectMultiplier(negativeEffectMultiplier) {}
 
 	int64_t trackAttack(const AttackPossibility & ap, HypotheticBattle & state);
 
@@ -80,6 +81,8 @@ public:
 		std::map<BattleHex, battle::Units> & reachabilityMap);
 
 private:
+	float positiveEffectMultiplier;
+	float negativeEffectMultiplier;
 	int64_t dpsScore;
 	std::map<uint32_t, AttackerValue> attackerValue;
 };
@@ -91,9 +94,15 @@ private:
 	std::shared_ptr<Environment> env;
 	std::map<BattleHex, std::vector<const battle::Unit *>> reachabilityMap;
 	std::vector<battle::Units> turnOrder;
+	float negativeEffectMultiplier;
 
 public:
-	BattleExchangeEvaluator(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env): cb(cb), env(env) {}
+	BattleExchangeEvaluator(
+		std::shared_ptr<CBattleInfoCallback> cb,
+		std::shared_ptr<Environment> env,
+		float strengthRatio): cb(cb), env(env) {
+		negativeEffectMultiplier = strengthRatio;
+	}
 
 	EvaluationResult findBestTarget(
 		const battle::Unit * activeStack,
@@ -118,4 +127,7 @@ public:
 		std::shared_ptr<HypotheticBattle> hb);
 
 	std::vector<const battle::Unit *> getAdjacentUnits(const battle::Unit * unit);
+
+	float getPositiveEffectMultiplier() { return 1; }
+	float getNegativeEffectMultiplier() { return negativeEffectMultiplier; }
 };

+ 2 - 0
AI/BattleAI/CMakeLists.txt

@@ -1,6 +1,7 @@
 set(battleAI_SRCS
 		AttackPossibility.cpp
 		BattleAI.cpp
+		BattleEvaluator.cpp
 		common.cpp
 		EnemyInfo.cpp
 		PossibleSpellcast.cpp
@@ -15,6 +16,7 @@ set(battleAI_HEADERS
 
 		AttackPossibility.h
 		BattleAI.h
+		BattleEvaluator.h
 		common.h
 		EnemyInfo.h
 		PotentialTargets.h