Browse Source

BattleAI: optional simulation depth

Andrii Danylchenko 1 year ago
parent
commit
ff8a745a50

+ 10 - 1
AI/BattleAI/BattleAI.cpp

@@ -23,6 +23,7 @@
 #include "../../lib/battle/BattleAction.h"
 #include "../../lib/battle/BattleStateInfoForRetreat.h"
 #include "../../lib/battle/CObstacleInstance.h"
+#include "../../lib/StartInfo.h"
 #include "../../lib/CStack.h" // TODO: remove
                               // Eventually only IBattleInfoCallback and battle::Unit should be used,
                               // CUnitState should be private and CStack should be removed completely
@@ -122,6 +123,11 @@ static float getStrengthRatio(std::shared_ptr<CBattleInfoCallback> cb, BattleSid
 	return enemy == 0 ? 1.0f : static_cast<float>(our) / enemy;
 }
 
+int getSimulationTurnsCount(const StartInfo * startInfo)
+{
+	return startInfo->difficulty < 4 ? 2 : 10;
+}
+
 void CBattleAI::activeStack(const BattleID & battleID, const CStack * stack )
 {
 	LOG_TRACE_PARAMS(logAi, "stack: %s", stack->nodeName());
@@ -154,7 +160,10 @@ void CBattleAI::activeStack(const BattleID & battleID, const CStack * stack )
 		logAi->trace("Build evaluator and targets");
 #endif
 
-		BattleEvaluator evaluator(env, cb, stack, playerID, battleID, side, getStrengthRatio(cb->getBattle(battleID), side));
+		BattleEvaluator evaluator(
+			env, cb, stack, playerID, battleID, side, 
+			getStrengthRatio(cb->getBattle(battleID), side),
+			getSimulationTurnsCount(env->game()->getStartInfo()));
 
 		result = evaluator.selectStackAction(stack);
 

+ 40 - 1
AI/BattleAI/BattleEvaluator.cpp

@@ -49,6 +49,45 @@ SpellTypes spellType(const CSpell * spell)
 	return SpellTypes::OTHER;
 }
 
+BattleEvaluator::BattleEvaluator(
+	std::shared_ptr<Environment> env,
+	std::shared_ptr<CBattleCallback> cb,
+	const battle::Unit * activeStack,
+	PlayerColor playerID,
+	BattleID battleID,
+	BattleSide side,
+	float strengthRatio,
+	int simulationTurnsCount)
+	:scoreEvaluator(cb->getBattle(battleID), env, strengthRatio, simulationTurnsCount),
+	cachedAttack(), playerID(playerID), side(side), env(env),
+	cb(cb), strengthRatio(strengthRatio), battleID(battleID), simulationTurnsCount(simulationTurnsCount)
+{
+	hb = std::make_shared<HypotheticBattle>(env.get(), cb->getBattle(battleID));
+	damageCache.buildDamageCache(hb, side);
+
+	targets = std::make_unique<PotentialTargets>(activeStack, damageCache, hb);
+	cachedScore = EvaluationResult::INEFFECTIVE_SCORE;
+}
+
+BattleEvaluator::BattleEvaluator(
+	std::shared_ptr<Environment> env,
+	std::shared_ptr<CBattleCallback> cb,
+	std::shared_ptr<HypotheticBattle> hb,
+	DamageCache & damageCache,
+	const battle::Unit * activeStack,
+	PlayerColor playerID,
+	BattleID battleID,
+	BattleSide side,
+	float strengthRatio,
+	int simulationTurnsCount)
+	:scoreEvaluator(cb->getBattle(battleID), env, strengthRatio, simulationTurnsCount),
+	cachedAttack(), playerID(playerID), side(side), env(env), cb(cb), hb(hb),
+	damageCache(damageCache), strengthRatio(strengthRatio), battleID(battleID), simulationTurnsCount(simulationTurnsCount)
+{
+	targets = std::make_unique<PotentialTargets>(activeStack, damageCache, hb);
+	cachedScore = EvaluationResult::INEFFECTIVE_SCORE;
+}
+
 std::vector<BattleHex> BattleEvaluator::getBrokenWallMoatHexes() const
 {
 	std::vector<BattleHex> result;
@@ -618,7 +657,7 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
 #endif
 
 					PotentialTargets innerTargets(activeStack, innerCache, state);
-					BattleExchangeEvaluator innerEvaluator(state, env, strengthRatio);
+					BattleExchangeEvaluator innerEvaluator(state, env, strengthRatio, simulationTurnsCount);
 
 					if(!innerTargets.possibleAttacks.empty())
 					{

+ 5 - 15
AI/BattleAI/BattleEvaluator.h

@@ -37,6 +37,7 @@ class BattleEvaluator
 	float cachedScore;
 	DamageCache damageCache;
 	float strengthRatio;
+	int simulationTurnsCount;
 
 public:
 	BattleAction selectStackAction(const CStack * stack);
@@ -56,15 +57,8 @@ public:
 		PlayerColor playerID,
 		BattleID battleID,
 		BattleSide side,
-		float strengthRatio)
-		:scoreEvaluator(cb->getBattle(battleID), env, strengthRatio), cachedAttack(), playerID(playerID), side(side), env(env), cb(cb), strengthRatio(strengthRatio), battleID(battleID)
-	{
-		hb = std::make_shared<HypotheticBattle>(env.get(), cb->getBattle(battleID));
-		damageCache.buildDamageCache(hb, side);
-
-		targets = std::make_unique<PotentialTargets>(activeStack, damageCache, hb);
-		cachedScore = EvaluationResult::INEFFECTIVE_SCORE;
-	}
+		float strengthRatio,
+		int simulationTurnsCount);
 
 	BattleEvaluator(
 		std::shared_ptr<Environment> env,
@@ -75,10 +69,6 @@ public:
 		PlayerColor playerID,
 		BattleID battleID,
 		BattleSide side,
-		float strengthRatio)
-		:scoreEvaluator(cb->getBattle(battleID), env, strengthRatio), cachedAttack(), playerID(playerID), side(side), env(env), cb(cb), hb(hb), damageCache(damageCache), strengthRatio(strengthRatio), battleID(battleID)
-	{
-		targets = std::make_unique<PotentialTargets>(activeStack, damageCache, hb);
-		cachedScore = EvaluationResult::INEFFECTIVE_SCORE;
-	}
+		float strengthRatio,
+		int simulationTurnsCount);
 };

+ 13 - 2
AI/BattleAI/BattleExchangeVariant.cpp

@@ -655,11 +655,14 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
 		});
 
 	bool canUseAp = true;
-	const int totalTurnsCount = 10;
 
 	std::set<uint32_t> blockedShooters;
 
-	for(int exchangeTurn = 0; exchangeTurn < totalTurnsCount; exchangeTurn++)
+	int totalTurnsCount = simulationTurnsCount >= turn + turnOrder.size()
+		? simulationTurnsCount
+		: turn + turnOrder.size();
+
+	for(int exchangeTurn = 0; exchangeTurn < simulationTurnsCount; exchangeTurn++)
 	{
 		bool isMovingTurm = exchangeTurn < turn;
 		int queueTurn = exchangeTurn >= exchangeUnits.units.size()
@@ -826,6 +829,14 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
 
 	auto score = v.getScore();
 
+	if(simulationTurnsCount < totalTurnsCount)
+	{
+		float scalingRatio = simulationTurnsCount / static_cast<float>(totalTurnsCount);
+
+		score.enemyDamageReduce *= scalingRatio;
+		score.ourDamageReduce *= scalingRatio;
+	}
+
 	if(turn > 0)
 	{
 		auto turnMultiplier = 1 - std::min(0.2, 0.05 * turn);

+ 3 - 1
AI/BattleAI/BattleExchangeVariant.h

@@ -132,6 +132,7 @@ private:
 	std::map<BattleHex, std::vector<const battle::Unit *>> reachabilityMap;
 	std::vector<battle::Units> turnOrder;
 	float negativeEffectMultiplier;
+	int simulationTurnsCount;
 
 	float scoreValue(const BattleScore & score) const;
 
@@ -149,7 +150,8 @@ public:
 	BattleExchangeEvaluator(
 		std::shared_ptr<CBattleInfoCallback> cb,
 		std::shared_ptr<Environment> env,
-		float strengthRatio): cb(cb), env(env) {
+		float strengthRatio,
+		int simulationTurnsCount): cb(cb), env(env), simulationTurnsCount(simulationTurnsCount){
 		negativeEffectMultiplier = strengthRatio >= 1 ? 1 : strengthRatio * strengthRatio;
 	}
 

+ 2 - 2
lib/CGameInfoCallback.h

@@ -56,7 +56,7 @@ public:
 
 //	//various
 	virtual int getDate(Date mode=Date::DAY) const = 0; //mode=0 - total days in game, mode=1 - day of week, mode=2 - current week, mode=3 - current month
-//	const StartInfo * getStartInfo(bool beforeRandomization = false)const;
+	virtual const StartInfo * getStartInfo(bool beforeRandomization = false) const = 0;
 	virtual bool isAllowed(SpellID id) const = 0;
 	virtual bool isAllowed(ArtifactID id) const = 0;
 	virtual bool isAllowed(SecondarySkill id) const = 0;
@@ -143,7 +143,7 @@ protected:
 public:
 	//various
 	int getDate(Date mode=Date::DAY)const override; //mode=0 - total days in game, mode=1 - day of week, mode=2 - current week, mode=3 - current month
-	virtual const StartInfo * getStartInfo(bool beforeRandomization = false)const;
+	const StartInfo * getStartInfo(bool beforeRandomization = false) const override;
 	bool isAllowed(SpellID id) const override;
 	bool isAllowed(ArtifactID id) const override;
 	bool isAllowed(SecondarySkill id) const override;

+ 1 - 0
test/mock/mock_IGameInfoCallback.h

@@ -17,6 +17,7 @@ class IGameInfoCallbackMock : public IGameInfoCallback
 public:
 	//various
 	MOCK_CONST_METHOD1(getDate, int(Date));
+	MOCK_CONST_METHOD1(getStartInfo, const StartInfo *(bool));
 
 	MOCK_CONST_METHOD1(isAllowed, bool(SpellID));
 	MOCK_CONST_METHOD1(isAllowed, bool(ArtifactID));