浏览代码

Merge pull request #4507 from vcmi/fix-battle-ai

Battle AI: fix firewall, fix haste spellcast evaluation for waits and…
Ivan Savenko 1 年之前
父节点
当前提交
644d6f4529

+ 59 - 28
AI/BattleAI/BattleEvaluator.cpp

@@ -66,7 +66,6 @@ BattleEvaluator::BattleEvaluator(
 	damageCache.buildDamageCache(hb, side);
 
 	targets = std::make_unique<PotentialTargets>(activeStack, damageCache, hb);
-	cachedScore = EvaluationResult::INEFFECTIVE_SCORE;
 }
 
 BattleEvaluator::BattleEvaluator(
@@ -85,7 +84,6 @@ BattleEvaluator::BattleEvaluator(
 	damageCache(damageCache), strengthRatio(strengthRatio), battleID(battleID), simulationTurnsCount(simulationTurnsCount)
 {
 	targets = std::make_unique<PotentialTargets>(activeStack, damageCache, hb);
-	cachedScore = EvaluationResult::INEFFECTIVE_SCORE;
 }
 
 std::vector<BattleHex> BattleEvaluator::getBrokenWallMoatHexes() const
@@ -178,8 +176,10 @@ BattleAction BattleEvaluator::selectStackAction(const CStack * stack)
 		auto evaluationResult = scoreEvaluator.findBestTarget(stack, *targets, damageCache, hb);
 		auto & bestAttack = evaluationResult.bestAttack;
 
-		cachedAttack = bestAttack;
-		cachedScore = evaluationResult.score;
+		cachedAttack.ap = bestAttack;
+		cachedAttack.score = evaluationResult.score;
+		cachedAttack.turn = 0;
+		cachedAttack.waited = evaluationResult.wait;
 
 		//TODO: consider more complex spellcast evaluation, f.e. because "re-retaliation" during enemy move in same turn for melee attack etc.
 		if(bestSpellcast.has_value() && bestSpellcast->value > bestAttack.damageDiff())
@@ -239,8 +239,9 @@ BattleAction BattleEvaluator::selectStackAction(const CStack * stack)
 	if(moveTarget.score > score)
 	{
 		score = moveTarget.score;
-		cachedAttack = moveTarget.cachedAttack;
-		cachedScore = score;
+		cachedAttack.ap = moveTarget.cachedAttack;
+		cachedAttack.score = score;
+		cachedAttack.turn = moveTarget.turnsToRich;
 
 		if(stack->waited())
 		{
@@ -255,6 +256,8 @@ BattleAction BattleEvaluator::selectStackAction(const CStack * stack)
 		}
 		else
 		{
+			cachedAttack.waited = true;
+
 			return BattleAction::makeWait(stack);
 		}
 	}
@@ -448,7 +451,7 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
 
 	vstd::erase_if(possibleSpells, [](const CSpell *s)
 	{
-		return spellType(s) != SpellTypes::BATTLE || s->getTargetType() == spells::AimType::LOCATION;
+		return spellType(s) != SpellTypes::BATTLE;
 	});
 
 	LOGFL("I know how %d of them works.", possibleSpells.size());
@@ -459,9 +462,6 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
 	{
 		spells::BattleCast temp(cb->getBattle(battleID).get(), hero, spells::Mode::HERO, spell);
 
-		if(spell->getTargetType() == spells::AimType::LOCATION)
-			continue;
-		
 		const bool FAST = true;
 
 		for(auto & target : temp.findPotentialTargets(FAST))
@@ -630,7 +630,15 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
 				auto & ps = possibleCasts[i];
 
 #if BATTLE_TRACE_LEVEL >= 1
-				logAi->trace("Evaluating %s", ps.spell->getNameTranslated());
+				if(ps.dest.empty())
+					logAi->trace("Evaluating %s", ps.spell->getNameTranslated());
+				else
+				{
+					auto psFirst = ps.dest.front();
+					auto strWhere = psFirst.unitValue ? psFirst.unitValue->getDescription() : std::to_string(psFirst.hexValue.hex);
+
+					logAi->trace("Evaluating %s at %s", ps.spell->getNameTranslated(), strWhere);
+				}
 #endif
 
 				auto state = std::make_shared<HypotheticBattle>(env.get(), cb->getBattle(battleID));
@@ -648,9 +656,15 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
 
 				DamageCache safeCopy = damageCache;
 				DamageCache innerCache(&safeCopy);
+
 				innerCache.buildDamageCache(state, side);
 
-				if(needFullEval || !cachedAttack)
+				if(cachedAttack.ap && cachedAttack.waited)
+				{
+					state->makeWait(activeStack);
+				}
+
+				if(needFullEval || !cachedAttack.ap)
 				{
 #if BATTLE_TRACE_LEVEL >= 1
 					logAi->trace("Full evaluation is started due to stack speed affected.");
@@ -659,29 +673,41 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
 					PotentialTargets innerTargets(activeStack, innerCache, state);
 					BattleExchangeEvaluator innerEvaluator(state, env, strengthRatio, simulationTurnsCount);
 
+					innerEvaluator.updateReachabilityMap(state);
+
+					auto moveTarget = innerEvaluator.findMoveTowardsUnreachable(activeStack, innerTargets, innerCache, state);
+
 					if(!innerTargets.possibleAttacks.empty())
 					{
-						innerEvaluator.updateReachabilityMap(state);
-
 						auto newStackAction = innerEvaluator.findBestTarget(activeStack, innerTargets, innerCache, state);
 
-						ps.value = newStackAction.score;
+						ps.value = std::max(moveTarget.score, newStackAction.score);
 					}
 					else
 					{
-						ps.value = 0;
+						ps.value = moveTarget.score;
 					}
 				}
 				else
 				{
-					ps.value = scoreEvaluator.evaluateExchange(*cachedAttack, 0, *targets, innerCache, state);
+					auto updatedAttacker = state->getForUpdate(cachedAttack.ap->attack.attacker->unitId());
+					auto updatedDefender = state->getForUpdate(cachedAttack.ap->attack.defender->unitId());
+					auto updatedBai = BattleAttackInfo(
+						updatedAttacker.get(),
+						updatedDefender.get(),
+						cachedAttack.ap->attack.chargeDistance,
+						cachedAttack.ap->attack.shooting);
+
+					auto updatedAttack = AttackPossibility::evaluate(updatedBai, cachedAttack.ap->from, innerCache, state);
+
+					ps.value = scoreEvaluator.evaluateExchange(updatedAttack, cachedAttack.turn, *targets, innerCache, state);
 				}
 
 				for(const auto & unit : allUnits)
 				{
-					if (!unit->isValidTarget())
+					if(!unit->isValidTarget(true))
 						continue;
-					
+
 					auto newHealth = unit->getAvailableHealth();
 					auto oldHealth = vstd::find_or(healthOfStack, unit->unitId(), 0); // old health value may not exist for newly summoned units
 
@@ -692,7 +718,7 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
 
 						auto dpsReduce = AttackPossibility::calculateDamageReduce(
 							nullptr,
-							originalDefender &&  originalDefender->alive() ? originalDefender : unit,
+							originalDefender && originalDefender->alive() ? originalDefender : unit,
 							damage,
 							innerCache,
 							state);
@@ -702,13 +728,18 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
 
 						if(ourUnit * goodEffect == 1)
 						{
-							if(ourUnit && goodEffect && (unit->isClone() || unit->isGhost()))
+							auto isMagical = state->getForUpdate(unit->unitId())->summoned
+								|| unit->isClone()
+								|| unit->isGhost();
+
+							if(ourUnit && goodEffect && isMagical)
 								continue;
 
 							ps.value += dpsReduce * scoreEvaluator.getPositiveEffectMultiplier();
 						}
 						else
-							ps.value -= dpsReduce * scoreEvaluator.getNegativeEffectMultiplier();
+							// discourage AI making collateral damage with spells
+							ps.value -= 4 * dpsReduce * scoreEvaluator.getNegativeEffectMultiplier();
 
 #if BATTLE_TRACE_LEVEL >= 1
 						logAi->trace(
@@ -719,6 +750,7 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
 #endif
 					}
 				}
+
 #if BATTLE_TRACE_LEVEL >= 1
 				logAi->trace("Total score: %2f", ps.value);
 #endif
@@ -729,13 +761,12 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
 
 	LOGFL("Evaluation took %d ms", timer.getDiff());
 
-	auto pscValue = [](const PossibleSpellcast &ps) -> float
-	{
-		return ps.value;
-	};
-	auto castToPerform = *vstd::maxElementByFun(possibleCasts, pscValue);
+	auto castToPerform = *vstd::maxElementByFun(possibleCasts, [](const PossibleSpellcast & ps) -> float
+		{
+			return ps.value;
+		});
 
-	if(castToPerform.value > cachedScore)
+	if(castToPerform.value > cachedAttack.score && !vstd::isAlmostEqual(castToPerform.value, cachedAttack.score))
 	{
 		LOGFL("Best spell is %s (value %d). Will cast.", castToPerform.spell->getNameTranslated() % castToPerform.value);
 		BattleAction spellcast;

+ 9 - 2
AI/BattleAI/BattleEvaluator.h

@@ -22,6 +22,14 @@ VCMI_LIB_NAMESPACE_END
 
 class EnemyInfo;
 
+struct CachedAttack
+{
+	std::optional<AttackPossibility> ap;
+	float score = EvaluationResult::INEFFECTIVE_SCORE;
+	uint8_t turn = 255;
+	bool waited = false;
+};
+
 class BattleEvaluator
 {
 	std::unique_ptr<PotentialTargets> targets;
@@ -30,11 +38,10 @@ class BattleEvaluator
 	std::shared_ptr<CBattleCallback> cb;
 	std::shared_ptr<Environment> env;
 	bool activeActionMade = false;
-	std::optional<AttackPossibility> cachedAttack;
+	CachedAttack cachedAttack;
 	PlayerColor playerID;
 	BattleID battleID;
 	BattleSide side;
-	float cachedScore;
 	DamageCache damageCache;
 	float strengthRatio;
 	int simulationTurnsCount;

+ 1 - 3
AI/BattleAI/BattleExchangeVariant.cpp

@@ -219,9 +219,7 @@ EvaluationResult BattleExchangeEvaluator::findBestTarget(
 
 		auto hbWaited = std::make_shared<HypotheticBattle>(env.get(), hb);
 
-		hbWaited->resetActiveUnit();
-		hbWaited->getForUpdate(activeStack->unitId())->waiting = true;
-		hbWaited->getForUpdate(activeStack->unitId())->waitedThisTurn = true;
+		hbWaited->makeWait(activeStack);
 
 		updateReachabilityMap(hbWaited);
 

+ 9 - 1
AI/BattleAI/StackWithBonuses.cpp

@@ -502,10 +502,18 @@ ServerCallback * HypotheticBattle::getServerCallback()
 	return serverCallback.get();
 }
 
+void HypotheticBattle::makeWait(const battle::Unit * activeStack)
+{
+	auto unit = getForUpdate(activeStack->unitId());
+
+	resetActiveUnit();
+	unit->waiting = true;
+	unit->waitedThisTurn = true;
+}
+
 HypotheticBattle::HypotheticServerCallback::HypotheticServerCallback(HypotheticBattle * owner_)
 	:owner(owner_)
 {
-
 }
 
 void HypotheticBattle::HypotheticServerCallback::complain(const std::string & problem)

+ 2 - 0
AI/BattleAI/StackWithBonuses.h

@@ -164,6 +164,8 @@ public:
 
 	int64_t getTreeVersion() const;
 
+	void makeWait(const battle::Unit * activeStack);
+
 	void resetActiveUnit()
 	{
 		activeUnitId = -1;