BattleExchangeVariant.cpp 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048
  1. /*
  2. * BattleExchangeVariant.cpp, part of VCMI engine
  3. *
  4. * Authors: listed in file AUTHORS in main folder
  5. *
  6. * License: GNU General Public License v2.0 or later
  7. * Full text of license available in license.txt file, in main folder
  8. *
  9. */
  10. #include "StdInc.h"
  11. #include "BattleExchangeVariant.h"
  12. #include "BattleEvaluator.h"
  13. #include "../../lib/CStack.h"
  14. #include "tbb/parallel_for.h"
  15. AttackerValue::AttackerValue()
  16. : value(0),
  17. isRetaliated(false)
  18. {
  19. }
  20. MoveTarget::MoveTarget()
  21. : positions(), cachedAttack(), score(EvaluationResult::INEFFECTIVE_SCORE)
  22. {
  23. turnsToReach = 1;
  24. }
  25. float BattleExchangeVariant::trackAttack(
  26. const AttackPossibility & ap,
  27. std::shared_ptr<HypotheticBattle> hb,
  28. DamageCache & damageCache)
  29. {
  30. if(!ap.attackerState)
  31. {
  32. logAi->trace("Skipping fake ap attack");
  33. return 0;
  34. }
  35. auto attacker = hb->getForUpdate(ap.attack.attacker->unitId());
  36. float attackValue = ap.attackValue();
  37. auto affectedUnits = ap.affectedUnits;
  38. dpsScore.ourDamageReduce += ap.attackerDamageReduce + ap.collateralDamageReduce;
  39. dpsScore.enemyDamageReduce += ap.defenderDamageReduce + ap.shootersBlockedDmg;
  40. attackerValue[attacker->unitId()].value = attackValue;
  41. affectedUnits.push_back(ap.attackerState);
  42. for(auto affectedUnit : affectedUnits)
  43. {
  44. auto unitToUpdate = hb->getForUpdate(affectedUnit->unitId());
  45. auto damageDealt = unitToUpdate->getAvailableHealth() - affectedUnit->getAvailableHealth();
  46. if(damageDealt > 0)
  47. {
  48. unitToUpdate->damage(damageDealt);
  49. }
  50. if(unitToUpdate->unitSide() == attacker->unitSide())
  51. {
  52. if(unitToUpdate->unitId() == attacker->unitId())
  53. {
  54. unitToUpdate->afterAttack(ap.attack.shooting, false);
  55. #if BATTLE_TRACE_LEVEL>=1
  56. logAi->trace(
  57. "%s -> %s, ap retaliation, %s, dps: %lld",
  58. hb->getForUpdate(ap.attack.defender->unitId())->getDescription(),
  59. ap.attack.attacker->getDescription(),
  60. ap.attack.shooting ? "shot" : "mellee",
  61. damageDealt);
  62. #endif
  63. }
  64. else
  65. {
  66. #if BATTLE_TRACE_LEVEL>=1
  67. logAi->trace(
  68. "%s, ap collateral, dps: %lld",
  69. unitToUpdate->getDescription(),
  70. damageDealt);
  71. #endif
  72. }
  73. }
  74. else
  75. {
  76. if(unitToUpdate->unitId() == ap.attack.defender->unitId())
  77. {
  78. if(unitToUpdate->ableToRetaliate() && !affectedUnit->ableToRetaliate())
  79. {
  80. unitToUpdate->afterAttack(ap.attack.shooting, true);
  81. }
  82. #if BATTLE_TRACE_LEVEL>=1
  83. logAi->trace(
  84. "%s -> %s, ap attack, %s, dps: %lld",
  85. attacker->getDescription(),
  86. ap.attack.defender->getDescription(),
  87. ap.attack.shooting ? "shot" : "mellee",
  88. damageDealt);
  89. #endif
  90. }
  91. else
  92. {
  93. #if BATTLE_TRACE_LEVEL>=1
  94. logAi->trace(
  95. "%s, ap enemy collateral, dps: %lld",
  96. unitToUpdate->getDescription(),
  97. damageDealt);
  98. #endif
  99. }
  100. }
  101. }
  102. #if BATTLE_TRACE_LEVEL >= 1
  103. logAi->trace(
  104. "ap score: our: %2f, enemy: %2f, collateral: %2f, blocked: %2f",
  105. ap.attackerDamageReduce,
  106. ap.defenderDamageReduce,
  107. ap.collateralDamageReduce,
  108. ap.shootersBlockedDmg);
  109. #endif
  110. return attackValue;
  111. }
  112. float BattleExchangeVariant::trackAttack(
  113. std::shared_ptr<StackWithBonuses> attacker,
  114. std::shared_ptr<StackWithBonuses> defender,
  115. bool shooting,
  116. bool isOurAttack,
  117. DamageCache & damageCache,
  118. std::shared_ptr<HypotheticBattle> hb,
  119. bool evaluateOnly)
  120. {
  121. const std::string cachingStringBlocksRetaliation = "type_BLOCKS_RETALIATION";
  122. static const auto selectorBlocksRetaliation = Selector::type()(BonusType::BLOCKS_RETALIATION);
  123. const bool counterAttacksBlocked = attacker->hasBonus(selectorBlocksRetaliation, cachingStringBlocksRetaliation);
  124. int64_t attackDamage = damageCache.getDamage(attacker.get(), defender.get(), hb);
  125. float defenderDamageReduce = AttackPossibility::calculateDamageReduce(attacker.get(), defender.get(), attackDamage, damageCache, hb);
  126. float attackerDamageReduce = 0;
  127. if(!evaluateOnly)
  128. {
  129. #if BATTLE_TRACE_LEVEL>=1
  130. logAi->trace(
  131. "%s -> %s, normal attack, %s, dps: %lld, %2f",
  132. attacker->getDescription(),
  133. defender->getDescription(),
  134. shooting ? "shot" : "mellee",
  135. attackDamage,
  136. defenderDamageReduce);
  137. #endif
  138. if(isOurAttack)
  139. {
  140. dpsScore.enemyDamageReduce += defenderDamageReduce;
  141. attackerValue[attacker->unitId()].value += defenderDamageReduce;
  142. }
  143. else
  144. dpsScore.ourDamageReduce += defenderDamageReduce;
  145. defender->damage(attackDamage);
  146. attacker->afterAttack(shooting, false);
  147. }
  148. if(!evaluateOnly && defender->alive() && defender->ableToRetaliate() && !counterAttacksBlocked && !shooting)
  149. {
  150. auto retaliationDamage = damageCache.getDamage(defender.get(), attacker.get(), hb);
  151. attackerDamageReduce = AttackPossibility::calculateDamageReduce(defender.get(), attacker.get(), retaliationDamage, damageCache, hb);
  152. #if BATTLE_TRACE_LEVEL>=1
  153. logAi->trace(
  154. "%s -> %s, retaliation, dps: %lld, %2f",
  155. defender->getDescription(),
  156. attacker->getDescription(),
  157. retaliationDamage,
  158. attackerDamageReduce);
  159. #endif
  160. if(isOurAttack)
  161. {
  162. dpsScore.ourDamageReduce += attackerDamageReduce;
  163. attackerValue[attacker->unitId()].isRetaliated = true;
  164. }
  165. else
  166. {
  167. dpsScore.enemyDamageReduce += attackerDamageReduce;
  168. attackerValue[defender->unitId()].value += attackerDamageReduce;
  169. }
  170. attacker->damage(retaliationDamage);
  171. defender->afterAttack(false, true);
  172. }
  173. auto score = defenderDamageReduce - attackerDamageReduce;
  174. #if BATTLE_TRACE_LEVEL>=1
  175. if(!score)
  176. {
  177. logAi->trace("Attack has zero score def:%2f att:%2f", defenderDamageReduce, attackerDamageReduce);
  178. }
  179. #endif
  180. return score;
  181. }
  182. float BattleExchangeEvaluator::scoreValue(const BattleScore & score) const
  183. {
  184. return score.enemyDamageReduce * getPositiveEffectMultiplier() - score.ourDamageReduce * getNegativeEffectMultiplier();
  185. }
  186. EvaluationResult BattleExchangeEvaluator::findBestTarget(
  187. const battle::Unit * activeStack,
  188. PotentialTargets & targets,
  189. DamageCache & damageCache,
  190. std::shared_ptr<HypotheticBattle> hb,
  191. bool siegeDefense)
  192. {
  193. EvaluationResult result(targets.bestAction());
  194. if(!activeStack->waited() && !activeStack->acquireState()->hadMorale)
  195. {
  196. #if BATTLE_TRACE_LEVEL>=1
  197. logAi->trace("Evaluating waited attack for %s", activeStack->getDescription());
  198. #endif
  199. auto hbWaited = std::make_shared<HypotheticBattle>(env.get(), hb);
  200. hbWaited->makeWait(activeStack);
  201. updateReachabilityMap(hbWaited);
  202. for(auto & ap : targets.possibleAttacks)
  203. {
  204. if (siegeDefense && !hb->battleIsInsideWalls(ap.from))
  205. continue;
  206. float score = evaluateExchange(ap, 0, targets, damageCache, hbWaited);
  207. if(score > result.score)
  208. {
  209. result.score = score;
  210. result.bestAttack = ap;
  211. result.wait = true;
  212. #if BATTLE_TRACE_LEVEL >= 1
  213. logAi->trace("New high score %2f", result.score);
  214. #endif
  215. }
  216. }
  217. }
  218. #if BATTLE_TRACE_LEVEL>=1
  219. logAi->trace("Evaluating normal attack for %s", activeStack->getDescription());
  220. #endif
  221. updateReachabilityMap(hb);
  222. if(result.bestAttack.attack.shooting
  223. && !result.bestAttack.defenderDead
  224. && !activeStack->waited()
  225. && hb->battleHasShootingPenalty(activeStack, result.bestAttack.dest))
  226. {
  227. if(!canBeHitThisTurn(result.bestAttack))
  228. return result; // lets wait
  229. }
  230. for(auto & ap : targets.possibleAttacks)
  231. {
  232. if (siegeDefense && !hb->battleIsInsideWalls(ap.from))
  233. continue;
  234. float score = evaluateExchange(ap, 0, targets, damageCache, hb);
  235. bool sameScoreButWaited = vstd::isAlmostEqual(score, result.score) && result.wait;
  236. if(score > result.score || sameScoreButWaited)
  237. {
  238. result.score = score;
  239. result.bestAttack = ap;
  240. result.wait = false;
  241. #if BATTLE_TRACE_LEVEL >= 1
  242. logAi->trace("New high score %2f", result.score);
  243. #endif
  244. }
  245. }
  246. return result;
  247. }
  248. ReachabilityInfo getReachabilityWithEnemyBypass(
  249. const battle::Unit * activeStack,
  250. DamageCache & damageCache,
  251. std::shared_ptr<HypotheticBattle> state)
  252. {
  253. ReachabilityInfo::Parameters params(activeStack, activeStack->getPosition());
  254. if(!params.flying)
  255. {
  256. for(const auto * unit : state->battleAliveUnits())
  257. {
  258. if(unit->unitSide() == activeStack->unitSide())
  259. continue;
  260. auto dmg = damageCache.getOriginalDamage(activeStack, unit, state);
  261. auto turnsToKill = unit->getAvailableHealth() / std::max(dmg, (int64_t)1);
  262. vstd::amin(turnsToKill, 100);
  263. for(auto & hex : unit->getHexes())
  264. if(hex.isAvailable()) //towers can have <0 pos; we don't also want to overwrite side columns
  265. params.destructibleEnemyTurns[hex.toInt()] = turnsToKill * unit->getMovementRange();
  266. }
  267. params.bypassEnemyStacks = true;
  268. }
  269. return state->getReachability(params);
  270. }
  271. MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
  272. const battle::Unit * activeStack,
  273. PotentialTargets & targets,
  274. DamageCache & damageCache,
  275. std::shared_ptr<HypotheticBattle> hb)
  276. {
  277. MoveTarget result;
  278. BattleExchangeVariant ev;
  279. logAi->trace("Find move towards unreachable. Enemies count %d", targets.unreachableEnemies.size());
  280. if(targets.unreachableEnemies.empty())
  281. return result;
  282. auto speed = activeStack->getMovementRange();
  283. if(speed == 0)
  284. return result;
  285. updateReachabilityMap(hb);
  286. auto dists = getReachabilityWithEnemyBypass(activeStack, damageCache, hb);
  287. auto flying = activeStack->hasBonusOfType(BonusType::FLYING);
  288. for(const battle::Unit * enemy : targets.unreachableEnemies)
  289. {
  290. logAi->trace(
  291. "Checking movement towards %d of %s",
  292. enemy->getCount(),
  293. enemy->creatureId().toCreature()->getNameSingularTranslated());
  294. auto distance = dists.distToNearestNeighbour(activeStack, enemy);
  295. if(distance >= GameConstants::BFIELD_SIZE)
  296. continue;
  297. if(distance <= speed)
  298. continue;
  299. float penaltyMultiplier = 1.0f; // Default multiplier, no penalty
  300. float closestAllyDistance = std::numeric_limits<float>::max();
  301. for (const battle::Unit* ally : hb->battleAliveUnits())
  302. {
  303. if (ally == activeStack)
  304. continue;
  305. if (ally->unitSide() != activeStack->unitSide())
  306. continue;
  307. float allyDistance = dists.distToNearestNeighbour(ally, enemy);
  308. if (allyDistance < closestAllyDistance)
  309. {
  310. closestAllyDistance = allyDistance;
  311. }
  312. }
  313. // If an ally is closer to the enemy, compute the penaltyMultiplier
  314. if (closestAllyDistance < distance)
  315. {
  316. penaltyMultiplier = closestAllyDistance / distance; // Ratio of distances
  317. }
  318. auto turnsToReach = (distance - 1) / speed + 1;
  319. const BattleHexArray & hexes = enemy->getSurroundingHexes();
  320. auto enemySpeed = enemy->getMovementRange();
  321. auto speedRatio = speed / static_cast<float>(enemySpeed);
  322. auto multiplier = (speedRatio > 1 ? 1 : speedRatio) * penaltyMultiplier;
  323. for(auto & hex : hexes)
  324. {
  325. // FIXME: provide distance info for Jousting bonus
  326. auto bai = BattleAttackInfo(activeStack, enemy, 0, cb->battleCanShoot(activeStack));
  327. auto attack = AttackPossibility::evaluate(bai, hex, damageCache, hb);
  328. attack.shootersBlockedDmg = 0; // we do not want to count on it, it is not for sure
  329. auto score = calculateExchange(attack, turnsToReach, targets, damageCache, hb);
  330. score.enemyDamageReduce *= multiplier;
  331. #if BATTLE_TRACE_LEVEL >= 1
  332. logAi->trace("Multiplier: %f, turns: %d, current score %f, new score %f", multiplier, turnsToReach, result.score, scoreValue(score));
  333. #endif
  334. if(result.score < scoreValue(score)
  335. || (result.turnsToReach > turnsToReach && vstd::isAlmostEqual(result.score, scoreValue(score))))
  336. {
  337. result.score = scoreValue(score);
  338. result.positions.clear();
  339. #if BATTLE_TRACE_LEVEL >= 1
  340. logAi->trace("New high score");
  341. #endif
  342. for(BattleHex enemyHex : enemy->getAttackableHexes(activeStack))
  343. {
  344. while(!flying && dists.distances[enemyHex.toInt()] > speed && dists.predecessors.at(enemyHex.toInt()).isValid())
  345. {
  346. enemyHex = dists.predecessors.at(enemyHex.toInt());
  347. if(dists.accessibility[enemyHex.toInt()] == EAccessibility::ALIVE_STACK)
  348. {
  349. auto defenderToBypass = hb->battleGetUnitByPos(enemyHex);
  350. if(defenderToBypass)
  351. {
  352. #if BATTLE_TRACE_LEVEL >= 1
  353. logAi->trace("Found target to bypass at %d", enemyHex.hex);
  354. #endif
  355. auto attackHex = dists.predecessors[enemyHex.toInt()];
  356. auto baiBypass = BattleAttackInfo(activeStack, defenderToBypass, 0, cb->battleCanShoot(activeStack));
  357. auto attackBypass = AttackPossibility::evaluate(baiBypass, attackHex, damageCache, hb);
  358. auto adjacentStacks = getAdjacentUnits(enemy);
  359. adjacentStacks.push_back(defenderToBypass);
  360. vstd::removeDuplicates(adjacentStacks);
  361. auto bypassScore = calculateExchange(
  362. attackBypass,
  363. dists.distances[attackHex.toInt()],
  364. targets,
  365. damageCache,
  366. hb,
  367. adjacentStacks);
  368. if(scoreValue(bypassScore) > result.score)
  369. {
  370. result.score = scoreValue(bypassScore);
  371. #if BATTLE_TRACE_LEVEL >= 1
  372. logAi->trace("New high score after bypass %f", scoreValue(bypassScore));
  373. #endif
  374. }
  375. }
  376. }
  377. }
  378. result.positions.insert(enemyHex);
  379. }
  380. result.cachedAttack = attack;
  381. result.turnsToReach = turnsToReach;
  382. }
  383. }
  384. }
  385. return result;
  386. }
  387. std::vector<const battle::Unit *> BattleExchangeEvaluator::getAdjacentUnits(const battle::Unit * blockerUnit) const
  388. {
  389. std::queue<const battle::Unit *> queue;
  390. std::vector<const battle::Unit *> checkedStacks;
  391. queue.push(blockerUnit);
  392. while(!queue.empty())
  393. {
  394. auto stack = queue.front();
  395. queue.pop();
  396. checkedStacks.push_back(stack);
  397. auto const & hexes = stack->getSurroundingHexes();
  398. for(auto hex : hexes)
  399. {
  400. auto neighbour = cb->battleGetUnitByPos(hex);
  401. if(neighbour && neighbour->unitSide() == stack->unitSide() && !vstd::contains(checkedStacks, neighbour))
  402. {
  403. queue.push(neighbour);
  404. checkedStacks.push_back(neighbour);
  405. }
  406. }
  407. }
  408. return checkedStacks;
  409. }
  410. ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
  411. const AttackPossibility & ap,
  412. uint8_t turn,
  413. PotentialTargets & targets,
  414. std::shared_ptr<HypotheticBattle> hb,
  415. std::vector<const battle::Unit *> additionalUnits) const
  416. {
  417. ReachabilityData result;
  418. auto hexes = ap.attack.defender->getSurroundingHexes();
  419. if(!ap.attack.shooting)
  420. hexes.insert(ap.from);
  421. std::vector<const battle::Unit *> allReachableUnits = additionalUnits;
  422. for(auto hex : hexes)
  423. {
  424. vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap.at(hex.toInt()) : getOneTurnReachableUnits(turn, hex));
  425. }
  426. if(!ap.attack.attacker->isTurret())
  427. {
  428. for(auto hex : ap.attack.attacker->getHexes())
  429. {
  430. auto unitsReachingAttacker = turn == 0 ? reachabilityMap.at(hex.toInt()) : getOneTurnReachableUnits(turn, hex);
  431. for(auto unit : unitsReachingAttacker)
  432. {
  433. if(unit->unitSide() != ap.attack.attacker->unitSide())
  434. {
  435. allReachableUnits.push_back(unit);
  436. result.enemyUnitsReachingAttacker.insert(unit->unitId());
  437. }
  438. }
  439. }
  440. }
  441. vstd::removeDuplicates(allReachableUnits);
  442. auto copy = allReachableUnits;
  443. for(auto unit : copy)
  444. {
  445. for(auto adjacentUnit : getAdjacentUnits(unit))
  446. {
  447. auto unitWithBonuses = hb->battleGetUnitByID(adjacentUnit->unitId());
  448. if(vstd::contains(targets.unreachableEnemies, adjacentUnit)
  449. && !vstd::contains(allReachableUnits, unitWithBonuses))
  450. {
  451. allReachableUnits.push_back(unitWithBonuses);
  452. }
  453. }
  454. }
  455. vstd::removeDuplicates(allReachableUnits);
  456. if(!vstd::contains(allReachableUnits, ap.attack.attacker))
  457. {
  458. allReachableUnits.push_back(ap.attack.attacker);
  459. }
  460. if(allReachableUnits.size() < 2)
  461. {
  462. #if BATTLE_TRACE_LEVEL>=1
  463. logAi->trace("Reachability map contains only %d stacks", allReachableUnits.size());
  464. #endif
  465. return result;
  466. }
  467. for(auto unit : allReachableUnits)
  468. {
  469. auto accessible = !unit->canShoot() || vstd::contains(additionalUnits, unit);
  470. if(!accessible)
  471. {
  472. for(auto hex : unit->getSurroundingHexes())
  473. {
  474. if(ap.attack.defender->coversPos(hex))
  475. {
  476. accessible = true;
  477. }
  478. }
  479. }
  480. if(accessible)
  481. result.melleeAccessible.push_back(unit);
  482. else
  483. result.shooters.push_back(unit);
  484. }
  485. for(int turn = 0; turn < turnOrder.size(); turn++)
  486. {
  487. for(auto unit : turnOrder[turn])
  488. {
  489. if(vstd::contains(allReachableUnits, unit))
  490. result.units[turn].push_back(unit);
  491. }
  492. vstd::erase_if(result.units[turn], [&](const battle::Unit * u) -> bool
  493. {
  494. return !hb->battleGetUnitByID(u->unitId())->alive();
  495. });
  496. }
  497. return result;
  498. }
  499. float BattleExchangeEvaluator::evaluateExchange(
  500. const AttackPossibility & ap,
  501. uint8_t turn,
  502. PotentialTargets & targets,
  503. DamageCache & damageCache,
  504. std::shared_ptr<HypotheticBattle> hb) const
  505. {
  506. BattleScore score = calculateExchange(ap, turn, targets, damageCache, hb);
  507. #if BATTLE_TRACE_LEVEL >= 1
  508. logAi->trace(
  509. "calculateExchange score +%2f -%2fx%2f = %2f",
  510. score.enemyDamageReduce,
  511. score.ourDamageReduce,
  512. getNegativeEffectMultiplier(),
  513. scoreValue(score));
  514. #endif
  515. return scoreValue(score);
  516. }
/// Simulates the fight that follows attack `ap` and returns the accumulated damage-reduction score.
///
/// Participants come from getExchangeUnits(). Attacks are applied to exchangeBattle, a
/// HypotheticBattle built on top of `hb`. The simulation runs for simulationTurnsCount rounds.
/// In each round, units act in the order of the matching exchangeUnits.units bucket; the last
/// bucket is reused once the buckets run out. The precomputed `ap` is replayed only when the
/// very first simulated action is the original attacker hitting the original defender; every
/// other attack goes through BattleExchangeVariant::trackAttack.
///
/// @param ap              the attack being evaluated
/// @param turn            rounds before the attacker is in position. Rounds with index < turn
///                        are "moving" rounds (see isMovingTurm), and for turn > 0 enemy damage
///                        reduction is discounted by 5% per turn, at most 20%.
/// @param targets         potential targets (forwarded to getExchangeUnits)
/// @param damageCache     damage lookup cache
/// @param hb              battle state to evaluate
/// @param additionalUnits units forced into the exchange (see getExchangeUnits)
/// @return BattleScore(INEFFECTIVE_SCORE, 0) when we are LEFT_SIDE, the gate is BLOCKED and the
///         defender covers GATE_BRIDGE; a default BattleScore when exchangeUnits.units is empty;
///         otherwise the score collected by the variant, scaled as described above
BattleScore BattleExchangeEvaluator::calculateExchange(
	const AttackPossibility & ap,
	uint8_t turn,
	PotentialTargets & targets,
	DamageCache & damageCache,
	std::shared_ptr<HypotheticBattle> hb,
	std::vector<const battle::Unit *> additionalUnits) const
{
#if BATTLE_TRACE_LEVEL>=1
	logAi->trace("Battle exchange at %d", ap.attack.shooting ? ap.dest.hex : ap.from.hex);
#endif

	// Attacking a defender on the gate bridge while the gate is blocked is never worthwhile for LEFT_SIDE.
	if(cb->battleGetMySide() == BattleSide::LEFT_SIDE
		&& cb->battleGetGateState() == EGateState::BLOCKED
		&& ap.attack.defender->coversPos(BattleHex::GATE_BRIDGE))
	{
		return BattleScore(EvaluationResult::INEFFECTIVE_SCORE, 0);
	}

	std::vector<const battle::Unit *> ourStacks;
	std::vector<const battle::Unit *> enemyStacks;

	// The enemy queue starts with the defender while it is still alive in hb.
	if(hb->battleGetUnitByID(ap.attack.defender->unitId())->alive())
		enemyStacks.push_back(ap.attack.defender);

	ReachabilityData exchangeUnits = getExchangeUnits(ap, turn, targets, hb, additionalUnits);

	if(exchangeUnits.units.empty())
	{
		return BattleScore();
	}

	auto exchangeBattle = std::make_shared<HypotheticBattle>(env.get(), hb);
	BattleExchangeVariant v;

	// Fill the per-side queues with every alive, non-turret participant (each unit once).
	for(int exchangeTurn = 0; exchangeTurn < exchangeUnits.units.size(); exchangeTurn++)
	{
		for(auto unit : exchangeUnits.units.at(exchangeTurn))
		{
			if(unit->isTurret())
				continue;

			bool isOur = exchangeBattle->battleMatchOwner(ap.attack.attacker, unit, true);
			auto & attackerQueue = isOur ? ourStacks : enemyStacks;
			auto u = exchangeBattle->getForUpdate(unit->unitId());

			if(u->alive() && !vstd::contains(attackerQueue, unit))
			{
				attackerQueue.push_back(unit);

#if BATTLE_TRACE_LEVEL
				logAi->trace("Exchanging: %s", u->getDescription());
#endif
			}
		}
	}

	// NOTE(review): melleeAttackers is computed here but never read later in this function.
	auto melleeAttackers = ourStacks;

	vstd::removeDuplicates(melleeAttackers);
	vstd::erase_if(melleeAttackers, [&](const battle::Unit * u) -> bool
		{
			return cb->battleCanShoot(u);
		});

	// `ap` may only be replayed as the first simulated action; cleared after any unit acts.
	bool canUseAp = true;

	// Units that received a melee attack; from then on they are not treated as shooting.
	std::set<uint32_t> blockedShooters;

	// Rounds the whole exchange would need; used to scale the score when the simulation
	// is cut short at simulationTurnsCount.
	int totalTurnsCount = simulationTurnsCount >= turn + turnOrder.size()
		? simulationTurnsCount
		: turn + turnOrder.size();

	for(int exchangeTurn = 0; exchangeTurn < simulationTurnsCount; exchangeTurn++)
	{
		// Before round `turn`, non-shooting units other than enemies already reaching the
		// attacker skip their action (presumably still moving into position).
		bool isMovingTurm = exchangeTurn < turn;
		// Reuse the last turn-order bucket once the simulation outlasts the buckets.
		int queueTurn = exchangeTurn >= exchangeUnits.units.size()
			? exchangeUnits.units.size() - 1
			: exchangeTurn;

		for(auto activeUnit : exchangeUnits.units.at(queueTurn))
		{
			bool isOur = exchangeBattle->battleMatchOwner(ap.attack.attacker, activeUnit, true);
			battle::Units & attackerQueue = isOur ? ourStacks : enemyStacks;
			battle::Units & oppositeQueue = isOur ? enemyStacks : ourStacks;

			auto attacker = exchangeBattle->getForUpdate(activeUnit->unitId());
			auto shooting = exchangeBattle->battleCanShoot(attacker.get())
				&& !vstd::contains(blockedShooters, attacker->unitId());

			if(!attacker->alive())
			{
#if BATTLE_TRACE_LEVEL>=1
				logAi->trace("Attacker is dead");
#endif

				continue;
			}

			if(isMovingTurm && !shooting
				&& !vstd::contains(exchangeUnits.enemyUnitsReachingAttacker, attacker->unitId()))
			{
#if BATTLE_TRACE_LEVEL>=1
				logAi->trace("Attacker is moving");
#endif

				continue;
			}

			// Our units focus the original defender while it lives; enemy units, and ours once
			// the defender is dead, pick the opposite unit with the best estimated trade.
			auto targetUnit = ap.attack.defender;

			if(!isOur || !exchangeBattle->battleGetUnitByID(targetUnit->unitId())->alive())
			{
#if BATTLE_TRACE_LEVEL>=2
				logAi->trace("Best target selector for %s", attacker->getDescription());
#endif

				// Scores a candidate target without mutating the simulation (evaluateOnly = true).
				// NOTE(review): the `shooting` argument below is whether the *candidate target* can
				// shoot, not whether `attacker` shoots; confirm this is intended.
				auto estimateAttack = [&](const battle::Unit * u) -> float
				{
					auto stackWithBonuses = exchangeBattle->getForUpdate(u->unitId());
					auto score = v.trackAttack(
						attacker,
						stackWithBonuses,
						exchangeBattle->battleCanShoot(stackWithBonuses.get()),
						isOur,
						damageCache,
						hb,
						true);

#if BATTLE_TRACE_LEVEL>=2
					logAi->trace("Best target selector %s->%s score = %2f", attacker->getDescription(), stackWithBonuses->getDescription(), score);
#endif

					return score;
				};

				// Candidates exclude units classified as far shooters by getExchangeUnits.
				auto unitsInOppositeQueueExceptInaccessible = oppositeQueue;

				vstd::erase_if(unitsInOppositeQueueExceptInaccessible, [&](const battle::Unit * u)->bool
					{
						return vstd::contains(exchangeUnits.shooters, u);
					});

				// In round 0, when our attacker is not the first unit to act, enemies not known to
				// reach our attacker cannot choose it as a target.
				if(!isOur
					&& exchangeTurn == 0
					&& exchangeUnits.units.at(exchangeTurn).at(0)->unitId() != ap.attack.attacker->unitId()
					&& !vstd::contains(exchangeUnits.enemyUnitsReachingAttacker, attacker->unitId()))
				{
					vstd::erase_if(unitsInOppositeQueueExceptInaccessible, [&](const battle::Unit * u) -> bool
						{
							return u->unitId() == ap.attack.attacker->unitId();
						});
				}

				if(!unitsInOppositeQueueExceptInaccessible.empty())
				{
					targetUnit = *vstd::maxElementByFun(unitsInOppositeQueueExceptInaccessible, estimateAttack);
				}
				else
				{
					// Fallback: any alive, positioned opposite unit this attacker can reach per reachabilityMap.
					auto reachable = exchangeBattle->battleGetUnitsIf([this, &exchangeBattle, &attacker](const battle::Unit * u) -> bool
						{
							if(u->unitSide() == attacker->unitSide())
								return false;

							if(!exchangeBattle->getForUpdate(u->unitId())->alive())
								return false;

							if(!u->getPosition().isValid())
								return false; // e.g. tower shooters

							return vstd::contains_if(reachabilityMap.at(u->getPosition().toInt()), [&attacker](const battle::Unit * other) -> bool
								{
									return attacker->unitId() == other->unitId();
								});
						});

					if(!reachable.empty())
					{
						targetUnit = *vstd::maxElementByFun(reachable, estimateAttack);
					}
					else
					{
#if BATTLE_TRACE_LEVEL>=1
						logAi->trace("Battle queue is empty and no reachable enemy.");
#endif

						continue;
					}
				}
			}

			auto defender = exchangeBattle->getForUpdate(targetUnit->unitId());
			const int totalAttacks = attacker->getTotalAttacks(shooting);

			if(canUseAp && activeUnit->unitId() == ap.attack.attacker->unitId()
				&& targetUnit->unitId() == ap.attack.defender->unitId())
			{
				v.trackAttack(ap, exchangeBattle, damageCache);
			}
			else
			{
				for(int i = 0; i < totalAttacks; i++)
				{
					v.trackAttack(attacker, defender, shooting, isOur, damageCache, exchangeBattle);

					if(!attacker->alive() || !defender->alive())
						break;
				}
			}

			// A unit that received a melee attack no longer counts as shooting afterwards.
			if(!shooting)
				blockedShooters.insert(defender->unitId());

			canUseAp = false;

			// Drop units killed by this action from both queues.
			vstd::erase_if(attackerQueue, [&](const battle::Unit * u) -> bool
				{
					return !exchangeBattle->battleGetUnitByID(u->unitId())->alive();
				});

			vstd::erase_if(oppositeQueue, [&](const battle::Unit * u) -> bool
				{
					return !exchangeBattle->battleGetUnitByID(u->unitId())->alive();
				});
		}

		exchangeBattle->nextRound();
	}

	auto score = v.getScore();

	// Scale proportionally when fewer rounds were simulated than the exchange would need.
	if(simulationTurnsCount < totalTurnsCount)
	{
		float scalingRatio = simulationTurnsCount / static_cast<float>(totalTurnsCount);

		score.enemyDamageReduce *= scalingRatio;
		score.ourDamageReduce *= scalingRatio;
	}

	// Discount enemy losses by 5% per approach turn, at most 20%.
	if(turn > 0)
	{
		auto turnMultiplier = 1 - std::min(0.2, 0.05 * turn);

		score.enemyDamageReduce *= turnMultiplier;
	}

#if BATTLE_TRACE_LEVEL>=1
	logAi->trace("Exchange score: enemy: %2f, our -%2f", score.enemyDamageReduce, score.ourDamageReduce);
#endif

	return score;
}
  719. bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap)
  720. {
  721. for(auto pos : ap.attack.attacker->getSurroundingHexes())
  722. {
  723. for(auto u : reachabilityMap.at(pos.toInt()))
  724. {
  725. if(u->unitSide() != ap.attack.attacker->unitSide())
  726. {
  727. return true;
  728. }
  729. }
  730. }
  731. return false;
  732. }
  733. void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb)
  734. {
  735. const int TURN_DEPTH = 2;
  736. turnOrder.clear();
  737. hb->battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH);
  738. for(auto turn : turnOrder)
  739. {
  740. for(auto u : turn)
  741. {
  742. if(!vstd::contains(reachabilityCache, u->unitId()))
  743. {
  744. reachabilityCache[u->unitId()] = hb->getReachability(u);
  745. }
  746. }
  747. }
  748. tbb::parallel_for(tbb::blocked_range<size_t>(0, reachabilityMap.size()), [&](const tbb::blocked_range<size_t> & r)
  749. {
  750. for(auto i = r.begin(); i != r.end(); i++)
  751. reachabilityMap[i] = getOneTurnReachableUnits(0, BattleHex(i));
  752. });
  753. }
  754. std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const
  755. {
  756. std::vector<const battle::Unit *> result;
  757. for(int i = 0; i < turnOrder.size(); i++, turn++)
  758. {
  759. auto & turnQueue = turnOrder[i];
  760. HypotheticBattle turnBattle(env.get(), cb);
  761. for(const battle::Unit * unit : turnQueue)
  762. {
  763. if(unit->isTurret())
  764. continue;
  765. if(turnBattle.battleCanShoot(unit))
  766. {
  767. result.push_back(unit);
  768. continue;
  769. }
  770. auto unitSpeed = unit->getMovementRange(turn);
  771. auto radius = unitSpeed * (turn + 1);
  772. auto reachabilityIter = reachabilityCache.find(unit->unitId());
  773. assert(reachabilityIter != reachabilityCache.end()); // missing updateReachabilityMap call?
  774. ReachabilityInfo unitReachability = reachabilityIter != reachabilityCache.end() ? reachabilityIter->second : turnBattle.getReachability(unit);
  775. bool reachable = unitReachability.distances.at(hex.toInt()) <= radius;
  776. if(!reachable && unitReachability.accessibility[hex.toInt()] == EAccessibility::ALIVE_STACK)
  777. {
  778. const battle::Unit * hexStack = cb->battleGetUnitByPos(hex);
  779. if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
  780. {
  781. for(BattleHex neighbour : hex.getNeighbouringTiles())
  782. {
  783. reachable = unitReachability.distances.at(neighbour.toInt()) <= radius;
  784. if(reachable) break;
  785. }
  786. }
  787. }
  788. if(reachable)
  789. {
  790. result.push_back(unit);
  791. }
  792. }
  793. }
  794. return result;
  795. }
  796. // avoid blocking path for stronger stack by weaker stack
  797. bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * activeUnit, BattleHex position)
  798. {
  799. const int BLOCKING_THRESHOLD = 70;
  800. const int BLOCKING_OWN_ATTACK_PENALTY = 100;
  801. const int BLOCKING_OWN_MOVE_PENALTY = 1;
  802. float blockingScore = 0;
  803. auto activeUnitDamage = activeUnit->getMinDamage(hb.battleCanShoot(activeUnit)) * activeUnit->getCount();
  804. for(int turn = 0; turn < turnOrder.size(); turn++)
  805. {
  806. auto & turnQueue = turnOrder[turn];
  807. HypotheticBattle turnBattle(env.get(), cb);
  808. auto unitToUpdate = turnBattle.getForUpdate(activeUnit->unitId());
  809. unitToUpdate->setPosition(position);
  810. for(const battle::Unit * unit : turnQueue)
  811. {
  812. if(unit->unitId() == unitToUpdate->unitId() || cb->battleMatchOwner(unit, activeUnit, false))
  813. continue;
  814. auto blockedUnitDamage = unit->getMinDamage(hb.battleCanShoot(unit)) * unit->getCount();
  815. float ratio = blockedUnitDamage / (float)(blockedUnitDamage + activeUnitDamage + 0.01);
  816. auto unitReachability = turnBattle.getReachability(unit);
  817. auto unitSpeed = unit->getMovementRange(turn); // Cached value, to avoid performance hit
  818. for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex++)
  819. {
  820. bool enemyUnit = false;
  821. bool reachable = unitReachability.distances.at(hex.toInt()) <= unitSpeed;
  822. if(!reachable && unitReachability.accessibility[hex.toInt()] == EAccessibility::ALIVE_STACK)
  823. {
  824. const battle::Unit * hexStack = turnBattle.battleGetUnitByPos(hex);
  825. if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
  826. {
  827. enemyUnit = true;
  828. for(BattleHex neighbour : hex.getNeighbouringTiles())
  829. {
  830. reachable = unitReachability.distances.at(neighbour.toInt()) <= unitSpeed;
  831. if(reachable) break;
  832. }
  833. }
  834. }
  835. if(!reachable && std::count(reachabilityMap[hex.toInt()].begin(), reachabilityMap[hex.toInt()].end(), unit) > 1)
  836. {
  837. blockingScore += ratio * (enemyUnit ? BLOCKING_OWN_ATTACK_PENALTY : BLOCKING_OWN_MOVE_PENALTY);
  838. }
  839. }
  840. }
  841. }
  842. #if BATTLE_TRACE_LEVEL>=1
  843. logAi->trace("Position %d, blocking score %f", position.hex, blockingScore);
  844. #endif
  845. return blockingScore > BLOCKING_THRESHOLD;
  846. }