|
@@ -11,6 +11,7 @@
|
|
|
#include "BattleExchangeVariant.h"
|
|
|
#include "BattleEvaluator.h"
|
|
|
#include "../../lib/CStack.h"
|
|
|
+#include "tbb/parallel_for.h"
|
|
|
|
|
|
AttackerValue::AttackerValue()
|
|
|
: value(0),
|
|
@@ -470,10 +471,10 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
|
|
|
return result;
|
|
|
}
|
|
|
|
|
|
-std::vector<const battle::Unit *> BattleExchangeEvaluator::getAdjacentUnits(const battle::Unit * blockerUnit) const
|
|
|
+battle::Units BattleExchangeEvaluator::getAdjacentUnits(const battle::Unit * blockerUnit) const
|
|
|
{
|
|
|
std::queue<const battle::Unit *> queue;
|
|
|
- std::vector<const battle::Unit *> checkedStacks;
|
|
|
+ battle::Units checkedStacks;
|
|
|
|
|
|
queue.push(blockerUnit);
|
|
|
|
|
@@ -505,7 +506,7 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
|
|
|
uint8_t turn,
|
|
|
PotentialTargets & targets,
|
|
|
std::shared_ptr<HypotheticBattle> hb,
|
|
|
- std::vector<const battle::Unit *> additionalUnits) const
|
|
|
+ battle::Units additionalUnits) const
|
|
|
{
|
|
|
ReachabilityData result;
|
|
|
|
|
@@ -514,18 +515,18 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
|
|
|
if(!ap.attack.shooting)
|
|
|
hexes.insert(ap.from);
|
|
|
|
|
|
- std::vector<const battle::Unit *> allReachableUnits = additionalUnits;
|
|
|
+ battle::Units allReachableUnits = additionalUnits;
|
|
|
|
|
|
for(auto hex : hexes)
|
|
|
{
|
|
|
- vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap.at(hex) : getOneTurnReachableUnits(turn, hex));
|
|
|
+ vstd::concatenate(allReachableUnits, getOneTurnReachableUnits(turn, hex));
|
|
|
}
|
|
|
|
|
|
if(!ap.attack.attacker->isTurret())
|
|
|
{
|
|
|
for(auto hex : ap.attack.attacker->getHexes())
|
|
|
{
|
|
|
- auto unitsReachingAttacker = turn == 0 ? reachabilityMap.at(hex) : getOneTurnReachableUnits(turn, hex);
|
|
|
+ auto unitsReachingAttacker = getOneTurnReachableUnits(turn, hex);
|
|
|
for(auto unit : unitsReachingAttacker)
|
|
|
{
|
|
|
if(unit->unitSide() != ap.attack.attacker->unitSide())
|
|
@@ -635,7 +636,7 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
|
|
|
PotentialTargets & targets,
|
|
|
DamageCache & damageCache,
|
|
|
std::shared_ptr<HypotheticBattle> hb,
|
|
|
- std::vector<const battle::Unit *> additionalUnits) const
|
|
|
+ battle::Units additionalUnits) const
|
|
|
{
|
|
|
#if BATTLE_TRACE_LEVEL>=1
|
|
|
logAi->trace("Battle exchange at %d", ap.attack.shooting ? ap.dest.hex : ap.from.hex);
|
|
@@ -648,8 +649,8 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
|
|
|
return BattleScore(EvaluationResult::INEFFECTIVE_SCORE, 0);
|
|
|
}
|
|
|
|
|
|
- std::vector<const battle::Unit *> ourStacks;
|
|
|
- std::vector<const battle::Unit *> enemyStacks;
|
|
|
+ battle::Units ourStacks;
|
|
|
+ battle::Units enemyStacks;
|
|
|
|
|
|
if(hb->battleGetUnitByID(ap.attack.defender->unitId())->alive())
|
|
|
enemyStacks.push_back(ap.attack.defender);
|
|
@@ -799,7 +800,9 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
|
|
|
if(!u->getPosition().isValid())
|
|
|
return false; // e.g. tower shooters
|
|
|
|
|
|
- return vstd::contains_if(reachabilityMap.at(u->getPosition()), [&attacker](const battle::Unit * other) -> bool
|
|
|
+ const auto & reachableUnits = getOneTurnReachableUnits(0, u->getPosition());
|
|
|
+
|
|
|
+ return vstd::contains_if(reachableUnits, [&attacker](const battle::Unit * other) -> bool
|
|
|
{
|
|
|
return attacker->unitId() == other->unitId();
|
|
|
});
|
|
@@ -886,7 +889,7 @@ bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap)
|
|
|
{
|
|
|
for(auto pos : ap.attack.attacker->getSurroundingHexes())
|
|
|
{
|
|
|
- for(auto u : reachabilityMap[pos])
|
|
|
+ for(auto u : getOneTurnReachableUnits(0, pos))
|
|
|
{
|
|
|
if(u->unitSide() != ap.attack.attacker->unitSide())
|
|
|
{
|
|
@@ -898,33 +901,48 @@ bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap)
|
|
|
return false;
|
|
|
}
|
|
|
|
|
|
-void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb)
|
|
|
+void ReachabilityMapCache::update(const std::vector<battle::Units> & turnOrder, std::shared_ptr<HypotheticBattle> hb)
|
|
|
{
|
|
|
- const int TURN_DEPTH = 2;
|
|
|
-
|
|
|
- turnOrder.clear();
|
|
|
-
|
|
|
- hb->battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH);
|
|
|
-
|
|
|
for(auto turn : turnOrder)
|
|
|
{
|
|
|
for(auto u : turn)
|
|
|
{
|
|
|
- if(!vstd::contains(reachabilityCache, u->unitId()))
|
|
|
+ if(!vstd::contains(unitReachabilityMap, u->unitId()))
|
|
|
{
|
|
|
- reachabilityCache[u->unitId()] = hb->getReachability(u);
|
|
|
+ unitReachabilityMap[u->unitId()] = hb->getReachability(u);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); ++hex)
|
|
|
+
|
|
|
+ hexReachabilityPerTurn.clear();
|
|
|
+}
|
|
|
+
|
|
|
+void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb)
|
|
|
+{
|
|
|
+ const int TURN_DEPTH = 2;
|
|
|
+
|
|
|
+ turnOrder.clear();
|
|
|
+
|
|
|
+ hb->battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH);
|
|
|
+ reachabilityMap.update(turnOrder, hb);
|
|
|
+}
|
|
|
+
|
|
|
+const battle::Units & ReachabilityMapCache::getOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, BattleHex hex)
|
|
|
+{
|
|
|
+ auto & turnData = hexReachabilityPerTurn[turn];
|
|
|
+
|
|
|
+ if (!turnData.isValid[hex.toInt()])
|
|
|
{
|
|
|
- reachabilityMap[hex] = getOneTurnReachableUnits(0, hex);
|
|
|
+ turnData.hexes[hex.toInt()] = computeOneTurnReachableUnits(cb, env, turnOrder, turn, hex);
|
|
|
+ turnData.isValid.set(hex.toInt());
|
|
|
}
|
|
|
+
|
|
|
+ return turnData.hexes[hex.toInt()];
|
|
|
}
|
|
|
|
|
|
-std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const
|
|
|
+battle::Units ReachabilityMapCache::computeOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, BattleHex hex)
|
|
|
{
|
|
|
- std::vector<const battle::Unit *> result;
|
|
|
+ battle::Units result;
|
|
|
|
|
|
for(int i = 0; i < turnOrder.size(); i++, turn++)
|
|
|
{
|
|
@@ -946,10 +964,10 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUn
|
|
|
auto unitSpeed = unit->getMovementRange(turn);
|
|
|
auto radius = unitSpeed * (turn + 1);
|
|
|
|
|
|
- auto reachabilityIter = reachabilityCache.find(unit->unitId());
|
|
|
- assert(reachabilityIter != reachabilityCache.end()); // missing updateReachabilityMap call?
|
|
|
+ auto reachabilityIter = unitReachabilityMap.find(unit->unitId());
|
|
|
+ assert(reachabilityIter != unitReachabilityMap.end()); // missing updateReachabilityMap call?
|
|
|
|
|
|
- ReachabilityInfo unitReachability = reachabilityIter != reachabilityCache.end() ? reachabilityIter->second : turnBattle.getReachability(unit);
|
|
|
+ ReachabilityInfo unitReachability = reachabilityIter != unitReachabilityMap.end() ? reachabilityIter->second : turnBattle.getReachability(unit);
|
|
|
|
|
|
bool reachable = unitReachability.distances.at(hex.toInt()) <= radius;
|
|
|
|
|
@@ -978,6 +996,11 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUn
|
|
|
return result;
|
|
|
}
|
|
|
|
|
|
+const battle::Units & BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const
|
|
|
+{
|
|
|
+ return reachabilityMap.getOneTurnReachableUnits(cb, env, turnOrder, turn, hex);
|
|
|
+}
|
|
|
+
|
|
|
// avoid blocking path for stronger stack by weaker stack
|
|
|
bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * activeUnit, BattleHex position)
|
|
|
{
|
|
@@ -1029,9 +1052,11 @@ bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(HypotheticBattle & hb
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if(!reachable && std::count(reachabilityMap[hex].begin(), reachabilityMap[hex].end(), unit) > 1)
|
|
|
+ if(!reachable)
|
|
|
{
|
|
|
- blockingScore += ratio * (enemyUnit ? BLOCKING_OWN_ATTACK_PENALTY : BLOCKING_OWN_MOVE_PENALTY);
|
|
|
+ auto reachableUnits = getOneTurnReachableUnits(0, hex);
|
|
|
+ if (std::count(reachableUnits.begin(), reachableUnits.end(), unit) > 1)
|
|
|
+ blockingScore += ratio * (enemyUnit ? BLOCKING_OWN_ATTACK_PENALTY : BLOCKING_OWN_MOVE_PENALTY);
|
|
|
}
|
|
|
}
|
|
|
}
|