BattleExchangeVariant.cpp 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026
  1. /*
  2. * BattleAI.cpp, part of VCMI engine
  3. *
  4. * Authors: listed in file AUTHORS in main folder
  5. *
  6. * License: GNU General Public License v2.0 or later
  7. * Full text of license available in license.txt file, in main folder
  8. *
  9. */
  10. #include "StdInc.h"
  11. #include "BattleExchangeVariant.h"
  12. #include "../../lib/CStack.h"
  13. AttackerValue::AttackerValue()
  14. : value(0),
  15. isRetaliated(false)
  16. {
  17. }
  18. MoveTarget::MoveTarget()
  19. : positions(), cachedAttack(), score(EvaluationResult::INEFFECTIVE_SCORE)
  20. {
  21. turnsToRich = 1;
  22. }
  23. float BattleExchangeVariant::trackAttack(
  24. const AttackPossibility & ap,
  25. std::shared_ptr<HypotheticBattle> hb,
  26. DamageCache & damageCache)
  27. {
  28. if(!ap.attackerState)
  29. {
  30. logAi->trace("Skipping fake ap attack");
  31. return 0;
  32. }
  33. auto attacker = hb->getForUpdate(ap.attack.attacker->unitId());
  34. float attackValue = ap.attackValue();
  35. auto affectedUnits = ap.affectedUnits;
  36. dpsScore.ourDamageReduce += ap.attackerDamageReduce + ap.collateralDamageReduce;
  37. dpsScore.enemyDamageReduce += ap.defenderDamageReduce + ap.shootersBlockedDmg;
  38. attackerValue[attacker->unitId()].value = attackValue;
  39. affectedUnits.push_back(ap.attackerState);
  40. for(auto affectedUnit : affectedUnits)
  41. {
  42. auto unitToUpdate = hb->getForUpdate(affectedUnit->unitId());
  43. auto damageDealt = unitToUpdate->getAvailableHealth() - affectedUnit->getAvailableHealth();
  44. if(damageDealt > 0)
  45. {
  46. unitToUpdate->damage(damageDealt);
  47. }
  48. if(unitToUpdate->unitSide() == attacker->unitSide())
  49. {
  50. if(unitToUpdate->unitId() == attacker->unitId())
  51. {
  52. unitToUpdate->afterAttack(ap.attack.shooting, false);
  53. #if BATTLE_TRACE_LEVEL>=1
  54. logAi->trace(
  55. "%s -> %s, ap retaliation, %s, dps: %lld",
  56. hb->getForUpdate(ap.attack.defender->unitId())->getDescription(),
  57. ap.attack.attacker->getDescription(),
  58. ap.attack.shooting ? "shot" : "mellee",
  59. damageDealt);
  60. #endif
  61. }
  62. else
  63. {
  64. #if BATTLE_TRACE_LEVEL>=1
  65. logAi->trace(
  66. "%s, ap collateral, dps: %lld",
  67. unitToUpdate->getDescription(),
  68. damageDealt);
  69. #endif
  70. }
  71. }
  72. else
  73. {
  74. if(unitToUpdate->unitId() == ap.attack.defender->unitId())
  75. {
  76. if(unitToUpdate->ableToRetaliate() && !affectedUnit->ableToRetaliate())
  77. {
  78. unitToUpdate->afterAttack(ap.attack.shooting, true);
  79. }
  80. #if BATTLE_TRACE_LEVEL>=1
  81. logAi->trace(
  82. "%s -> %s, ap attack, %s, dps: %lld",
  83. attacker->getDescription(),
  84. ap.attack.defender->getDescription(),
  85. ap.attack.shooting ? "shot" : "mellee",
  86. damageDealt);
  87. #endif
  88. }
  89. else
  90. {
  91. #if BATTLE_TRACE_LEVEL>=1
  92. logAi->trace(
  93. "%s, ap enemy collateral, dps: %lld",
  94. unitToUpdate->getDescription(),
  95. damageDealt);
  96. #endif
  97. }
  98. }
  99. }
  100. #if BATTLE_TRACE_LEVEL >= 1
  101. logAi->trace(
  102. "ap score: our: %2f, enemy: %2f, collateral: %2f, blocked: %2f",
  103. ap.attackerDamageReduce,
  104. ap.defenderDamageReduce,
  105. ap.collateralDamageReduce,
  106. ap.shootersBlockedDmg);
  107. #endif
  108. return attackValue;
  109. }
  110. float BattleExchangeVariant::trackAttack(
  111. std::shared_ptr<StackWithBonuses> attacker,
  112. std::shared_ptr<StackWithBonuses> defender,
  113. bool shooting,
  114. bool isOurAttack,
  115. DamageCache & damageCache,
  116. std::shared_ptr<HypotheticBattle> hb,
  117. bool evaluateOnly)
  118. {
  119. const std::string cachingStringBlocksRetaliation = "type_BLOCKS_RETALIATION";
  120. static const auto selectorBlocksRetaliation = Selector::type()(BonusType::BLOCKS_RETALIATION);
  121. const bool counterAttacksBlocked = attacker->hasBonus(selectorBlocksRetaliation, cachingStringBlocksRetaliation);
  122. int64_t attackDamage = damageCache.getDamage(attacker.get(), defender.get(), hb);
  123. float defenderDamageReduce = AttackPossibility::calculateDamageReduce(attacker.get(), defender.get(), attackDamage, damageCache, hb);
  124. float attackerDamageReduce = 0;
  125. if(!evaluateOnly)
  126. {
  127. #if BATTLE_TRACE_LEVEL>=1
  128. logAi->trace(
  129. "%s -> %s, normal attack, %s, dps: %lld, %2f",
  130. attacker->getDescription(),
  131. defender->getDescription(),
  132. shooting ? "shot" : "mellee",
  133. attackDamage,
  134. defenderDamageReduce);
  135. #endif
  136. if(isOurAttack)
  137. {
  138. dpsScore.enemyDamageReduce += defenderDamageReduce;
  139. attackerValue[attacker->unitId()].value += defenderDamageReduce;
  140. }
  141. else
  142. dpsScore.ourDamageReduce += defenderDamageReduce;
  143. defender->damage(attackDamage);
  144. attacker->afterAttack(shooting, false);
  145. }
  146. if(!evaluateOnly && defender->alive() && defender->ableToRetaliate() && !counterAttacksBlocked && !shooting)
  147. {
  148. auto retaliationDamage = damageCache.getDamage(defender.get(), attacker.get(), hb);
  149. attackerDamageReduce = AttackPossibility::calculateDamageReduce(defender.get(), attacker.get(), retaliationDamage, damageCache, hb);
  150. #if BATTLE_TRACE_LEVEL>=1
  151. logAi->trace(
  152. "%s -> %s, retaliation, dps: %lld, %2f",
  153. defender->getDescription(),
  154. attacker->getDescription(),
  155. retaliationDamage,
  156. attackerDamageReduce);
  157. #endif
  158. if(isOurAttack)
  159. {
  160. dpsScore.ourDamageReduce += attackerDamageReduce;
  161. attackerValue[attacker->unitId()].isRetaliated = true;
  162. }
  163. else
  164. {
  165. dpsScore.enemyDamageReduce += attackerDamageReduce;
  166. attackerValue[defender->unitId()].value += attackerDamageReduce;
  167. }
  168. attacker->damage(retaliationDamage);
  169. defender->afterAttack(false, true);
  170. }
  171. auto score = defenderDamageReduce - attackerDamageReduce;
  172. #if BATTLE_TRACE_LEVEL>=1
  173. if(!score)
  174. {
  175. logAi->trace("Attack has zero score def:%2f att:%2f", defenderDamageReduce, attackerDamageReduce);
  176. }
  177. #endif
  178. return score;
  179. }
  180. float BattleExchangeEvaluator::scoreValue(const BattleScore & score) const
  181. {
  182. return score.enemyDamageReduce * getPositiveEffectMultiplier() - score.ourDamageReduce * getNegativeEffectMultiplier();
  183. }
  184. EvaluationResult BattleExchangeEvaluator::findBestTarget(
  185. const battle::Unit * activeStack,
  186. PotentialTargets & targets,
  187. DamageCache & damageCache,
  188. std::shared_ptr<HypotheticBattle> hb)
  189. {
  190. EvaluationResult result(targets.bestAction());
  191. if(!activeStack->waited() && !activeStack->acquireState()->hadMorale)
  192. {
  193. #if BATTLE_TRACE_LEVEL>=1
  194. logAi->trace("Evaluating waited attack for %s", activeStack->getDescription());
  195. #endif
  196. auto hbWaited = std::make_shared<HypotheticBattle>(env.get(), hb);
  197. hbWaited->makeWait(activeStack);
  198. updateReachabilityMap(hbWaited);
  199. for(auto & ap : targets.possibleAttacks)
  200. {
  201. float score = evaluateExchange(ap, 0, targets, damageCache, hbWaited);
  202. if(score > result.score)
  203. {
  204. result.score = score;
  205. result.bestAttack = ap;
  206. result.wait = true;
  207. #if BATTLE_TRACE_LEVEL >= 1
  208. logAi->trace("New high score %2f", result.score);
  209. #endif
  210. }
  211. }
  212. }
  213. #if BATTLE_TRACE_LEVEL>=1
  214. logAi->trace("Evaluating normal attack for %s", activeStack->getDescription());
  215. #endif
  216. updateReachabilityMap(hb);
  217. if(result.bestAttack.attack.shooting
  218. && !result.bestAttack.defenderDead
  219. && !activeStack->waited()
  220. && hb->battleHasShootingPenalty(activeStack, result.bestAttack.dest))
  221. {
  222. if(!canBeHitThisTurn(result.bestAttack))
  223. return result; // lets wait
  224. }
  225. for(auto & ap : targets.possibleAttacks)
  226. {
  227. float score = evaluateExchange(ap, 0, targets, damageCache, hb);
  228. bool sameScoreButWaited = vstd::isAlmostEqual(score, result.score) && result.wait;
  229. if(score > result.score || sameScoreButWaited)
  230. {
  231. result.score = score;
  232. result.bestAttack = ap;
  233. result.wait = false;
  234. #if BATTLE_TRACE_LEVEL >= 1
  235. logAi->trace("New high score %2f", result.score);
  236. #endif
  237. }
  238. }
  239. return result;
  240. }
  241. ReachabilityInfo getReachabilityWithEnemyBypass(
  242. const battle::Unit * activeStack,
  243. DamageCache & damageCache,
  244. std::shared_ptr<HypotheticBattle> state)
  245. {
  246. ReachabilityInfo::Parameters params(activeStack, activeStack->getPosition());
  247. if(!params.flying)
  248. {
  249. for(const auto * unit : state->battleAliveUnits())
  250. {
  251. if(unit->unitSide() == activeStack->unitSide())
  252. continue;
  253. auto dmg = damageCache.getOriginalDamage(activeStack, unit, state);
  254. auto turnsToKill = unit->getAvailableHealth() / std::max(dmg, (int64_t)1);
  255. vstd::amin(turnsToKill, 100);
  256. for(auto & hex : unit->getHexes())
  257. if(hex.isAvailable()) //towers can have <0 pos; we don't also want to overwrite side columns
  258. params.destructibleEnemyTurns[hex] = turnsToKill * unit->getMovementRange();
  259. }
  260. params.bypassEnemyStacks = true;
  261. }
  262. return state->getReachability(params);
  263. }
  264. MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
  265. const battle::Unit * activeStack,
  266. PotentialTargets & targets,
  267. DamageCache & damageCache,
  268. std::shared_ptr<HypotheticBattle> hb)
  269. {
  270. MoveTarget result;
  271. BattleExchangeVariant ev;
  272. logAi->trace("Find move towards unreachable. Enemies count %d", targets.unreachableEnemies.size());
  273. if(targets.unreachableEnemies.empty())
  274. return result;
  275. auto speed = activeStack->getMovementRange();
  276. if(speed == 0)
  277. return result;
  278. updateReachabilityMap(hb);
  279. auto dists = getReachabilityWithEnemyBypass(activeStack, damageCache, hb);
  280. auto flying = activeStack->hasBonusOfType(BonusType::FLYING);
  281. for(const battle::Unit * enemy : targets.unreachableEnemies)
  282. {
  283. logAi->trace(
  284. "Checking movement towards %d of %s",
  285. enemy->getCount(),
  286. enemy->creatureId().toCreature()->getNameSingularTranslated());
  287. auto distance = dists.distToNearestNeighbour(activeStack, enemy);
  288. if(distance >= GameConstants::BFIELD_SIZE)
  289. continue;
  290. if(distance <= speed)
  291. continue;
  292. auto turnsToRich = (distance - 1) / speed + 1;
  293. auto hexes = enemy->getSurroundingHexes();
  294. auto enemySpeed = enemy->getMovementRange();
  295. auto speedRatio = speed / static_cast<float>(enemySpeed);
  296. auto multiplier = speedRatio > 1 ? 1 : speedRatio;
  297. for(auto & hex : hexes)
  298. {
  299. // FIXME: provide distance info for Jousting bonus
  300. auto bai = BattleAttackInfo(activeStack, enemy, 0, cb->battleCanShoot(activeStack));
  301. auto attack = AttackPossibility::evaluate(bai, hex, damageCache, hb);
  302. attack.shootersBlockedDmg = 0; // we do not want to count on it, it is not for sure
  303. auto score = calculateExchange(attack, turnsToRich, targets, damageCache, hb);
  304. score.enemyDamageReduce *= multiplier;
  305. #if BATTLE_TRACE_LEVEL >= 1
  306. logAi->trace("Multiplier: %f, turns: %d, current score %f, new score %f", multiplier, turnsToRich, result.score, scoreValue(score));
  307. #endif
  308. if(result.score < scoreValue(score)
  309. || (result.turnsToRich > turnsToRich && vstd::isAlmostEqual(result.score, scoreValue(score))))
  310. {
  311. result.score = scoreValue(score);
  312. result.positions.clear();
  313. #if BATTLE_TRACE_LEVEL >= 1
  314. logAi->trace("New high score");
  315. #endif
  316. for(const BattleHex & initialEnemyHex : enemy->getAttackableHexes(activeStack))
  317. {
  318. BattleHex enemyHex = initialEnemyHex;
  319. while(!flying && dists.distances[enemyHex] > speed && dists.predecessors.at(enemyHex).isValid())
  320. {
  321. enemyHex = dists.predecessors.at(enemyHex);
  322. if(dists.accessibility[enemyHex] == EAccessibility::ALIVE_STACK)
  323. {
  324. auto defenderToBypass = hb->battleGetUnitByPos(enemyHex);
  325. if(defenderToBypass)
  326. {
  327. #if BATTLE_TRACE_LEVEL >= 1
  328. logAi->trace("Found target to bypass at %d", enemyHex.hex);
  329. #endif
  330. auto attackHex = dists.predecessors[enemyHex];
  331. auto baiBypass = BattleAttackInfo(activeStack, defenderToBypass, 0, cb->battleCanShoot(activeStack));
  332. auto attackBypass = AttackPossibility::evaluate(baiBypass, attackHex, damageCache, hb);
  333. auto adjacentStacks = getAdjacentUnits(enemy);
  334. adjacentStacks.push_back(defenderToBypass);
  335. vstd::removeDuplicates(adjacentStacks);
  336. auto bypassScore = calculateExchange(
  337. attackBypass,
  338. dists.distances[attackHex],
  339. targets,
  340. damageCache,
  341. hb,
  342. adjacentStacks);
  343. if(scoreValue(bypassScore) > result.score)
  344. {
  345. result.score = scoreValue(bypassScore);
  346. #if BATTLE_TRACE_LEVEL >= 1
  347. logAi->trace("New high score after bypass %f", scoreValue(bypassScore));
  348. #endif
  349. }
  350. }
  351. }
  352. }
  353. result.positions.push_back(enemyHex);
  354. }
  355. result.cachedAttack = attack;
  356. result.turnsToRich = turnsToRich;
  357. }
  358. }
  359. }
  360. return result;
  361. }
  362. std::vector<const battle::Unit *> BattleExchangeEvaluator::getAdjacentUnits(const battle::Unit * blockerUnit) const
  363. {
  364. std::queue<const battle::Unit *> queue;
  365. std::vector<const battle::Unit *> checkedStacks;
  366. queue.push(blockerUnit);
  367. while(!queue.empty())
  368. {
  369. auto stack = queue.front();
  370. queue.pop();
  371. checkedStacks.push_back(stack);
  372. auto hexes = stack->getSurroundingHexes();
  373. for(auto hex : hexes)
  374. {
  375. auto neighbor = cb->battleGetUnitByPos(hex);
  376. if(neighbor && neighbor->unitSide() == stack->unitSide() && !vstd::contains(checkedStacks, neighbor))
  377. {
  378. queue.push(neighbor);
  379. checkedStacks.push_back(neighbor);
  380. }
  381. }
  382. }
  383. return checkedStacks;
  384. }
  385. ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
  386. const AttackPossibility & ap,
  387. uint8_t turn,
  388. PotentialTargets & targets,
  389. std::shared_ptr<HypotheticBattle> hb,
  390. std::vector<const battle::Unit *> additionalUnits) const
  391. {
  392. ReachabilityData result;
  393. auto hexes = ap.attack.defender->getSurroundingHexes();
  394. if(!ap.attack.shooting) hexes.push_back(ap.from);
  395. std::vector<const battle::Unit *> allReachableUnits = additionalUnits;
  396. for(auto hex : hexes)
  397. {
  398. vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap.at(hex) : getOneTurnReachableUnits(turn, hex));
  399. }
  400. if(!ap.attack.attacker->isTurret())
  401. {
  402. for(auto hex : ap.attack.attacker->getHexes())
  403. {
  404. auto unitsReachingAttacker = turn == 0 ? reachabilityMap.at(hex) : getOneTurnReachableUnits(turn, hex);
  405. for(auto unit : unitsReachingAttacker)
  406. {
  407. if(unit->unitSide() != ap.attack.attacker->unitSide())
  408. {
  409. allReachableUnits.push_back(unit);
  410. result.enemyUnitsReachingAttacker.insert(unit->unitId());
  411. }
  412. }
  413. }
  414. }
  415. vstd::removeDuplicates(allReachableUnits);
  416. auto copy = allReachableUnits;
  417. for(auto unit : copy)
  418. {
  419. for(auto adjacentUnit : getAdjacentUnits(unit))
  420. {
  421. auto unitWithBonuses = hb->battleGetUnitByID(adjacentUnit->unitId());
  422. if(vstd::contains(targets.unreachableEnemies, adjacentUnit)
  423. && !vstd::contains(allReachableUnits, unitWithBonuses))
  424. {
  425. allReachableUnits.push_back(unitWithBonuses);
  426. }
  427. }
  428. }
  429. vstd::removeDuplicates(allReachableUnits);
  430. if(!vstd::contains(allReachableUnits, ap.attack.attacker))
  431. {
  432. allReachableUnits.push_back(ap.attack.attacker);
  433. }
  434. if(allReachableUnits.size() < 2)
  435. {
  436. #if BATTLE_TRACE_LEVEL>=1
  437. logAi->trace("Reachability map contains only %d stacks", allReachableUnits.size());
  438. #endif
  439. return result;
  440. }
  441. for(auto unit : allReachableUnits)
  442. {
  443. auto accessible = !unit->canShoot() || vstd::contains(additionalUnits, unit);
  444. if(!accessible)
  445. {
  446. for(auto hex : unit->getSurroundingHexes())
  447. {
  448. if(ap.attack.defender->coversPos(hex))
  449. {
  450. accessible = true;
  451. }
  452. }
  453. }
  454. if(accessible)
  455. result.melleeAccessible.push_back(unit);
  456. else
  457. result.shooters.push_back(unit);
  458. }
  459. for(int turn = 0; turn < turnOrder.size(); turn++)
  460. {
  461. for(auto unit : turnOrder[turn])
  462. {
  463. if(vstd::contains(allReachableUnits, unit))
  464. result.units[turn].push_back(unit);
  465. }
  466. vstd::erase_if(result.units[turn], [&](const battle::Unit * u) -> bool
  467. {
  468. return !hb->battleGetUnitByID(u->unitId())->alive();
  469. });
  470. }
  471. return result;
  472. }
  473. float BattleExchangeEvaluator::evaluateExchange(
  474. const AttackPossibility & ap,
  475. uint8_t turn,
  476. PotentialTargets & targets,
  477. DamageCache & damageCache,
  478. std::shared_ptr<HypotheticBattle> hb) const
  479. {
  480. BattleScore score = calculateExchange(ap, turn, targets, damageCache, hb);
  481. #if BATTLE_TRACE_LEVEL >= 1
  482. logAi->trace(
  483. "calculateExchange score +%2f -%2fx%2f = %2f",
  484. score.enemyDamageReduce,
  485. score.ourDamageReduce,
  486. getNegativeEffectMultiplier(),
  487. scoreValue(score));
  488. #endif
  489. return scoreValue(score);
  490. }
  491. BattleScore BattleExchangeEvaluator::calculateExchange(
  492. const AttackPossibility & ap,
  493. uint8_t turn,
  494. PotentialTargets & targets,
  495. DamageCache & damageCache,
  496. std::shared_ptr<HypotheticBattle> hb,
  497. std::vector<const battle::Unit *> additionalUnits) const
  498. {
  499. #if BATTLE_TRACE_LEVEL>=1
  500. logAi->trace("Battle exchange at %d", ap.attack.shooting ? ap.dest.hex : ap.from.hex);
  501. #endif
  502. if(cb->battleGetMySide() == BattleSide::LEFT_SIDE
  503. && cb->battleGetGateState() == EGateState::BLOCKED
  504. && ap.attack.defender->coversPos(BattleHex::GATE_BRIDGE))
  505. {
  506. return BattleScore(EvaluationResult::INEFFECTIVE_SCORE, 0);
  507. }
  508. std::vector<const battle::Unit *> ourStacks;
  509. std::vector<const battle::Unit *> enemyStacks;
  510. if(hb->battleGetUnitByID(ap.attack.defender->unitId())->alive())
  511. enemyStacks.push_back(ap.attack.defender);
  512. ReachabilityData exchangeUnits = getExchangeUnits(ap, turn, targets, hb, additionalUnits);
  513. if(exchangeUnits.units.empty())
  514. {
  515. return BattleScore();
  516. }
  517. auto exchangeBattle = std::make_shared<HypotheticBattle>(env.get(), hb);
  518. BattleExchangeVariant v;
  519. for(int exchangeTurn = 0; exchangeTurn < exchangeUnits.units.size(); exchangeTurn++)
  520. {
  521. for(auto unit : exchangeUnits.units.at(exchangeTurn))
  522. {
  523. if(unit->isTurret())
  524. continue;
  525. bool isOur = exchangeBattle->battleMatchOwner(ap.attack.attacker, unit, true);
  526. auto & attackerQueue = isOur ? ourStacks : enemyStacks;
  527. auto u = exchangeBattle->getForUpdate(unit->unitId());
  528. if(u->alive() && !vstd::contains(attackerQueue, unit))
  529. {
  530. attackerQueue.push_back(unit);
  531. #if BATTLE_TRACE_LEVEL
  532. logAi->trace("Exchanging: %s", u->getDescription());
  533. #endif
  534. }
  535. }
  536. }
  537. auto melleeAttackers = ourStacks;
  538. vstd::removeDuplicates(melleeAttackers);
  539. vstd::erase_if(melleeAttackers, [&](const battle::Unit * u) -> bool
  540. {
  541. return cb->battleCanShoot(u);
  542. });
  543. bool canUseAp = true;
  544. std::set<uint32_t> blockedShooters;
  545. int totalTurnsCount = simulationTurnsCount >= turn + turnOrder.size()
  546. ? simulationTurnsCount
  547. : turn + turnOrder.size();
  548. for(int exchangeTurn = 0; exchangeTurn < simulationTurnsCount; exchangeTurn++)
  549. {
  550. bool isMovingTurm = exchangeTurn < turn;
  551. int queueTurn = exchangeTurn >= exchangeUnits.units.size()
  552. ? exchangeUnits.units.size() - 1
  553. : exchangeTurn;
  554. for(auto activeUnit : exchangeUnits.units.at(queueTurn))
  555. {
  556. bool isOur = exchangeBattle->battleMatchOwner(ap.attack.attacker, activeUnit, true);
  557. battle::Units & attackerQueue = isOur ? ourStacks : enemyStacks;
  558. battle::Units & oppositeQueue = isOur ? enemyStacks : ourStacks;
  559. auto attacker = exchangeBattle->getForUpdate(activeUnit->unitId());
  560. auto shooting = exchangeBattle->battleCanShoot(attacker.get())
  561. && !vstd::contains(blockedShooters, attacker->unitId());
  562. if(!attacker->alive())
  563. {
  564. #if BATTLE_TRACE_LEVEL>=1
  565. logAi->trace("Attacker is dead");
  566. #endif
  567. continue;
  568. }
  569. if(isMovingTurm && !shooting
  570. && !vstd::contains(exchangeUnits.enemyUnitsReachingAttacker, attacker->unitId()))
  571. {
  572. #if BATTLE_TRACE_LEVEL>=1
  573. logAi->trace("Attacker is moving");
  574. #endif
  575. continue;
  576. }
  577. auto targetUnit = ap.attack.defender;
  578. if(!isOur || !exchangeBattle->battleGetUnitByID(targetUnit->unitId())->alive())
  579. {
  580. #if BATTLE_TRACE_LEVEL>=2
  581. logAi->trace("Best target selector for %s", attacker->getDescription());
  582. #endif
  583. auto estimateAttack = [&](const battle::Unit * u) -> float
  584. {
  585. auto stackWithBonuses = exchangeBattle->getForUpdate(u->unitId());
  586. auto score = v.trackAttack(
  587. attacker,
  588. stackWithBonuses,
  589. exchangeBattle->battleCanShoot(stackWithBonuses.get()),
  590. isOur,
  591. damageCache,
  592. hb,
  593. true);
  594. #if BATTLE_TRACE_LEVEL>=2
  595. logAi->trace("Best target selector %s->%s score = %2f", attacker->getDescription(), stackWithBonuses->getDescription(), score);
  596. #endif
  597. return score;
  598. };
  599. auto unitsInOppositeQueueExceptInaccessible = oppositeQueue;
  600. vstd::erase_if(unitsInOppositeQueueExceptInaccessible, [&](const battle::Unit * u)->bool
  601. {
  602. return vstd::contains(exchangeUnits.shooters, u);
  603. });
  604. if(!isOur
  605. && exchangeTurn == 0
  606. && exchangeUnits.units.at(exchangeTurn).at(0)->unitId() != ap.attack.attacker->unitId()
  607. && !vstd::contains(exchangeUnits.enemyUnitsReachingAttacker, attacker->unitId()))
  608. {
  609. vstd::erase_if(unitsInOppositeQueueExceptInaccessible, [&](const battle::Unit * u) -> bool
  610. {
  611. return u->unitId() == ap.attack.attacker->unitId();
  612. });
  613. }
  614. if(!unitsInOppositeQueueExceptInaccessible.empty())
  615. {
  616. targetUnit = *vstd::maxElementByFun(unitsInOppositeQueueExceptInaccessible, estimateAttack);
  617. }
  618. else
  619. {
  620. auto reachable = exchangeBattle->battleGetUnitsIf([this, &exchangeBattle, &attacker](const battle::Unit * u) -> bool
  621. {
  622. if(u->unitSide() == attacker->unitSide())
  623. return false;
  624. if(!exchangeBattle->getForUpdate(u->unitId())->alive())
  625. return false;
  626. if(!u->getPosition().isValid())
  627. return false; // e.g. tower shooters
  628. return vstd::contains_if(reachabilityMap.at(u->getPosition()), [&attacker](const battle::Unit * other) -> bool
  629. {
  630. return attacker->unitId() == other->unitId();
  631. });
  632. });
  633. if(!reachable.empty())
  634. {
  635. targetUnit = *vstd::maxElementByFun(reachable, estimateAttack);
  636. }
  637. else
  638. {
  639. #if BATTLE_TRACE_LEVEL>=1
  640. logAi->trace("Battle queue is empty and no reachable enemy.");
  641. #endif
  642. continue;
  643. }
  644. }
  645. }
  646. auto defender = exchangeBattle->getForUpdate(targetUnit->unitId());
  647. const int totalAttacks = attacker->getTotalAttacks(shooting);
  648. if(canUseAp && activeUnit->unitId() == ap.attack.attacker->unitId()
  649. && targetUnit->unitId() == ap.attack.defender->unitId())
  650. {
  651. v.trackAttack(ap, exchangeBattle, damageCache);
  652. }
  653. else
  654. {
  655. for(int i = 0; i < totalAttacks; i++)
  656. {
  657. v.trackAttack(attacker, defender, shooting, isOur, damageCache, exchangeBattle);
  658. if(!attacker->alive() || !defender->alive())
  659. break;
  660. }
  661. }
  662. if(!shooting)
  663. blockedShooters.insert(defender->unitId());
  664. canUseAp = false;
  665. vstd::erase_if(attackerQueue, [&](const battle::Unit * u) -> bool
  666. {
  667. return !exchangeBattle->battleGetUnitByID(u->unitId())->alive();
  668. });
  669. vstd::erase_if(oppositeQueue, [&](const battle::Unit * u) -> bool
  670. {
  671. return !exchangeBattle->battleGetUnitByID(u->unitId())->alive();
  672. });
  673. }
  674. exchangeBattle->nextRound();
  675. }
  676. // avoid blocking path for stronger stack by weaker stack
  677. // the method checks if all stacks can be placed around enemy
  678. std::map<BattleHex, battle::Units> reachabilityMap;
  679. auto hexes = ap.attack.defender->getSurroundingHexes();
  680. for(auto hex : hexes)
  681. reachabilityMap[hex] = getOneTurnReachableUnits(turn, hex);
  682. auto score = v.getScore();
  683. if(simulationTurnsCount < totalTurnsCount)
  684. {
  685. float scalingRatio = simulationTurnsCount / static_cast<float>(totalTurnsCount);
  686. score.enemyDamageReduce *= scalingRatio;
  687. score.ourDamageReduce *= scalingRatio;
  688. }
  689. if(turn > 0)
  690. {
  691. auto turnMultiplier = 1 - std::min(0.2, 0.05 * turn);
  692. score.enemyDamageReduce *= turnMultiplier;
  693. }
  694. #if BATTLE_TRACE_LEVEL>=1
  695. logAi->trace("Exchange score: enemy: %2f, our -%2f", score.enemyDamageReduce, score.ourDamageReduce);
  696. #endif
  697. return score;
  698. }
  699. bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap)
  700. {
  701. for(auto pos : ap.attack.attacker->getSurroundingHexes())
  702. {
  703. for(auto u : reachabilityMap[pos])
  704. {
  705. if(u->unitSide() != ap.attack.attacker->unitSide())
  706. {
  707. return true;
  708. }
  709. }
  710. }
  711. return false;
  712. }
  713. void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb)
  714. {
  715. const int TURN_DEPTH = 2;
  716. turnOrder.clear();
  717. hb->battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH);
  718. for(auto turn : turnOrder)
  719. {
  720. for(auto u : turn)
  721. {
  722. if(!vstd::contains(reachabilityCache, u->unitId()))
  723. {
  724. reachabilityCache[u->unitId()] = hb->getReachability(u);
  725. }
  726. }
  727. }
  728. for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
  729. {
  730. reachabilityMap[hex] = getOneTurnReachableUnits(0, hex);
  731. }
  732. }
  733. std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const
  734. {
  735. std::vector<const battle::Unit *> result;
  736. for(int i = 0; i < turnOrder.size(); i++, turn++)
  737. {
  738. auto & turnQueue = turnOrder[i];
  739. HypotheticBattle turnBattle(env.get(), cb);
  740. for(const battle::Unit * unit : turnQueue)
  741. {
  742. if(unit->isTurret())
  743. continue;
  744. if(turnBattle.battleCanShoot(unit))
  745. {
  746. result.push_back(unit);
  747. continue;
  748. }
  749. auto unitSpeed = unit->getMovementRange(turn);
  750. auto radius = unitSpeed * (turn + 1);
  751. auto reachabilityIter = reachabilityCache.find(unit->unitId());
  752. assert(reachabilityIter != reachabilityCache.end()); // missing updateReachabilityMap call?
  753. ReachabilityInfo unitReachability = reachabilityIter != reachabilityCache.end() ? reachabilityIter->second : turnBattle.getReachability(unit);
  754. bool reachable = unitReachability.distances.at(hex) <= radius;
  755. if(!reachable && unitReachability.accessibility[hex] == EAccessibility::ALIVE_STACK)
  756. {
  757. const battle::Unit * hexStack = cb->battleGetUnitByPos(hex);
  758. if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
  759. {
  760. for(BattleHex neighbor : hex.neighbouringTiles())
  761. {
  762. reachable = unitReachability.distances.at(neighbor) <= radius;
  763. if(reachable) break;
  764. }
  765. }
  766. }
  767. if(reachable)
  768. {
  769. result.push_back(unit);
  770. }
  771. }
  772. }
  773. return result;
  774. }
  775. // avoid blocking path for stronger stack by weaker stack
  776. bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * activeUnit, BattleHex position)
  777. {
  778. const int BLOCKING_THRESHOLD = 70;
  779. const int BLOCKING_OWN_ATTACK_PENALTY = 100;
  780. const int BLOCKING_OWN_MOVE_PENALTY = 1;
  781. float blockingScore = 0;
  782. auto activeUnitDamage = activeUnit->getMinDamage(hb.battleCanShoot(activeUnit)) * activeUnit->getCount();
  783. for(int turn = 0; turn < turnOrder.size(); turn++)
  784. {
  785. auto & turnQueue = turnOrder[turn];
  786. HypotheticBattle turnBattle(env.get(), cb);
  787. auto unitToUpdate = turnBattle.getForUpdate(activeUnit->unitId());
  788. unitToUpdate->setPosition(position);
  789. for(const battle::Unit * unit : turnQueue)
  790. {
  791. if(unit->unitId() == unitToUpdate->unitId() || cb->battleMatchOwner(unit, activeUnit, false))
  792. continue;
  793. auto blockedUnitDamage = unit->getMinDamage(hb.battleCanShoot(unit)) * unit->getCount();
  794. float ratio = blockedUnitDamage / (float)(blockedUnitDamage + activeUnitDamage + 0.01);
  795. auto unitReachability = turnBattle.getReachability(unit);
  796. auto unitSpeed = unit->getMovementRange(turn); // Cached value, to avoid performance hit
  797. for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
  798. {
  799. bool enemyUnit = false;
  800. bool reachable = unitReachability.distances.at(hex) <= unitSpeed;
  801. if(!reachable && unitReachability.accessibility[hex] == EAccessibility::ALIVE_STACK)
  802. {
  803. const battle::Unit * hexStack = turnBattle.battleGetUnitByPos(hex);
  804. if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
  805. {
  806. enemyUnit = true;
  807. for(BattleHex neighbor : hex.neighbouringTiles())
  808. {
  809. reachable = unitReachability.distances.at(neighbor) <= unitSpeed;
  810. if(reachable) break;
  811. }
  812. }
  813. }
  814. if(!reachable && std::count(reachabilityMap[hex].begin(), reachabilityMap[hex].end(), unit) > 1)
  815. {
  816. blockingScore += ratio * (enemyUnit ? BLOCKING_OWN_ATTACK_PENALTY : BLOCKING_OWN_MOVE_PENALTY);
  817. }
  818. }
  819. }
  820. }
  821. #if BATTLE_TRACE_LEVEL>=1
  822. logAi->trace("Position %d, blocking score %f", position.hex, blockingScore);
  823. #endif
  824. return blockingScore > BLOCKING_THRESHOLD;
  825. }