BattleExchangeVariant.cpp 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074
  1. /*
  2. * BattleExchangeVariant.cpp, part of VCMI engine
  3. *
  4. * Authors: listed in file AUTHORS in main folder
  5. *
  6. * License: GNU General Public License v2.0 or later
  7. * Full text of license available in license.txt file, in main folder
  8. *
  9. */
  10. #include "StdInc.h"
  11. #include "BattleExchangeVariant.h"
  12. #include "BattleEvaluator.h"
  13. #include "../../lib/CStack.h"
  14. #include "../../lib/GameLibrary.h"
  15. #include <tbb/parallel_for.h>
  16. AttackerValue::AttackerValue()
  17. : value(0),
  18. isRetaliated(false)
  19. {
  20. }
  21. MoveTarget::MoveTarget()
  22. : positions(), cachedAttack(), score(EvaluationResult::INEFFECTIVE_SCORE)
  23. {
  24. turnsToReach = 1;
  25. }
  26. float BattleExchangeVariant::trackAttack(
  27. const AttackPossibility & ap,
  28. std::shared_ptr<HypotheticBattle> hb,
  29. DamageCache & damageCache)
  30. {
  31. if(!ap.attackerState)
  32. {
  33. logAi->trace("Skipping fake ap attack");
  34. return 0;
  35. }
  36. auto attacker = hb->getForUpdate(ap.attack.attacker->unitId());
  37. float attackValue = ap.attackValue();
  38. auto affectedUnits = ap.affectedUnits;
  39. dpsScore.ourDamageReduce += ap.attackerDamageReduce + ap.collateralDamageReduce;
  40. dpsScore.enemyDamageReduce += ap.defenderDamageReduce + ap.shootersBlockedDmg;
  41. attackerValue[attacker->unitId()].value = attackValue;
  42. affectedUnits.push_back(ap.attackerState);
  43. for(auto affectedUnit : affectedUnits)
  44. {
  45. auto unitToUpdate = hb->getForUpdate(affectedUnit->unitId());
  46. auto damageDealt = unitToUpdate->getAvailableHealth() - affectedUnit->getAvailableHealth();
  47. if(damageDealt > 0)
  48. {
  49. unitToUpdate->damage(damageDealt);
  50. }
  51. if(unitToUpdate->unitSide() == attacker->unitSide())
  52. {
  53. if(unitToUpdate->unitId() == attacker->unitId())
  54. {
  55. unitToUpdate->afterAttack(ap.attack.shooting, false);
  56. #if BATTLE_TRACE_LEVEL>=1
  57. logAi->trace(
  58. "%s -> %s, ap retaliation, %s, dps: %lld",
  59. hb->getForUpdate(ap.attack.defender->unitId())->getDescription(),
  60. ap.attack.attacker->getDescription(),
  61. ap.attack.shooting ? "shot" : "mellee",
  62. damageDealt);
  63. #endif
  64. }
  65. else
  66. {
  67. #if BATTLE_TRACE_LEVEL>=1
  68. logAi->trace(
  69. "%s, ap collateral, dps: %lld",
  70. unitToUpdate->getDescription(),
  71. damageDealt);
  72. #endif
  73. }
  74. }
  75. else
  76. {
  77. if(unitToUpdate->unitId() == ap.attack.defender->unitId())
  78. {
  79. if(unitToUpdate->ableToRetaliate() && !affectedUnit->ableToRetaliate())
  80. {
  81. unitToUpdate->afterAttack(ap.attack.shooting, true);
  82. }
  83. #if BATTLE_TRACE_LEVEL>=1
  84. logAi->trace(
  85. "%s -> %s, ap attack, %s, dps: %lld",
  86. attacker->getDescription(),
  87. ap.attack.defender->getDescription(),
  88. ap.attack.shooting ? "shot" : "mellee",
  89. damageDealt);
  90. #endif
  91. }
  92. else
  93. {
  94. #if BATTLE_TRACE_LEVEL>=1
  95. logAi->trace(
  96. "%s, ap enemy collateral, dps: %lld",
  97. unitToUpdate->getDescription(),
  98. damageDealt);
  99. #endif
  100. }
  101. }
  102. }
  103. #if BATTLE_TRACE_LEVEL >= 1
  104. logAi->trace(
  105. "ap score: our: %2f, enemy: %2f, collateral: %2f, blocked: %2f",
  106. ap.attackerDamageReduce,
  107. ap.defenderDamageReduce,
  108. ap.collateralDamageReduce,
  109. ap.shootersBlockedDmg);
  110. #endif
  111. return attackValue;
  112. }
  113. float BattleExchangeVariant::trackAttack(
  114. std::shared_ptr<StackWithBonuses> attacker,
  115. std::shared_ptr<StackWithBonuses> defender,
  116. bool shooting,
  117. bool isOurAttack,
  118. DamageCache & damageCache,
  119. std::shared_ptr<HypotheticBattle> hb,
  120. bool evaluateOnly)
  121. {
  122. const std::string cachingStringBlocksRetaliation = "type_BLOCKS_RETALIATION";
  123. static const auto selectorBlocksRetaliation = Selector::type()(BonusType::BLOCKS_RETALIATION);
  124. const bool counterAttacksBlocked = attacker->hasBonus(selectorBlocksRetaliation, cachingStringBlocksRetaliation);
  125. int64_t attackDamage = damageCache.getDamage(attacker.get(), defender.get(), hb);
  126. float defenderDamageReduce = AttackPossibility::calculateDamageReduce(attacker.get(), defender.get(), attackDamage, damageCache, hb);
  127. float attackerDamageReduce = 0;
  128. if(!evaluateOnly)
  129. {
  130. #if BATTLE_TRACE_LEVEL>=1
  131. logAi->trace(
  132. "%s -> %s, normal attack, %s, dps: %lld, %2f",
  133. attacker->getDescription(),
  134. defender->getDescription(),
  135. shooting ? "shot" : "mellee",
  136. attackDamage,
  137. defenderDamageReduce);
  138. #endif
  139. if(isOurAttack)
  140. {
  141. dpsScore.enemyDamageReduce += defenderDamageReduce;
  142. attackerValue[attacker->unitId()].value += defenderDamageReduce;
  143. }
  144. else
  145. dpsScore.ourDamageReduce += defenderDamageReduce;
  146. defender->damage(attackDamage);
  147. attacker->afterAttack(shooting, false);
  148. }
  149. if(!evaluateOnly && defender->alive() && defender->ableToRetaliate() && !counterAttacksBlocked && !shooting)
  150. {
  151. auto retaliationDamage = damageCache.getDamage(defender.get(), attacker.get(), hb);
  152. attackerDamageReduce = AttackPossibility::calculateDamageReduce(defender.get(), attacker.get(), retaliationDamage, damageCache, hb);
  153. #if BATTLE_TRACE_LEVEL>=1
  154. logAi->trace(
  155. "%s -> %s, retaliation, dps: %lld, %2f",
  156. defender->getDescription(),
  157. attacker->getDescription(),
  158. retaliationDamage,
  159. attackerDamageReduce);
  160. #endif
  161. if(isOurAttack)
  162. {
  163. dpsScore.ourDamageReduce += attackerDamageReduce;
  164. attackerValue[attacker->unitId()].isRetaliated = true;
  165. }
  166. else
  167. {
  168. dpsScore.enemyDamageReduce += attackerDamageReduce;
  169. attackerValue[defender->unitId()].value += attackerDamageReduce;
  170. }
  171. attacker->damage(retaliationDamage);
  172. defender->afterAttack(false, true);
  173. }
  174. auto score = defenderDamageReduce - attackerDamageReduce;
  175. #if BATTLE_TRACE_LEVEL>=1
  176. if(!score)
  177. {
  178. logAi->trace("Attack has zero score def:%2f att:%2f", defenderDamageReduce, attackerDamageReduce);
  179. }
  180. #endif
  181. return score;
  182. }
  183. float BattleExchangeEvaluator::scoreValue(const BattleScore & score) const
  184. {
  185. return score.enemyDamageReduce * getPositiveEffectMultiplier() - score.ourDamageReduce * getNegativeEffectMultiplier();
  186. }
  187. EvaluationResult BattleExchangeEvaluator::findBestTarget(
  188. const battle::Unit * activeStack,
  189. PotentialTargets & targets,
  190. DamageCache & damageCache,
  191. std::shared_ptr<HypotheticBattle> hb,
  192. bool siegeDefense)
  193. {
  194. EvaluationResult result(targets.bestAction());
  195. if(!activeStack->waited() && !activeStack->acquireState()->hadMorale)
  196. {
  197. #if BATTLE_TRACE_LEVEL>=1
  198. logAi->trace("Evaluating waited attack for %s", activeStack->getDescription());
  199. #endif
  200. auto hbWaited = std::make_shared<HypotheticBattle>(env.get(), hb);
  201. hbWaited->makeWait(activeStack);
  202. updateReachabilityMap(hbWaited);
  203. for(auto & ap : targets.possibleAttacks)
  204. {
  205. if (siegeDefense && !hb->battleIsInsideWalls(ap.from))
  206. continue;
  207. float score = evaluateExchange(ap, 0, targets, damageCache, hbWaited);
  208. if(score > result.score)
  209. {
  210. result.score = score;
  211. result.bestAttack = ap;
  212. result.wait = true;
  213. #if BATTLE_TRACE_LEVEL >= 1
  214. logAi->trace("New high score %2f", result.score);
  215. #endif
  216. }
  217. }
  218. }
  219. #if BATTLE_TRACE_LEVEL>=1
  220. logAi->trace("Evaluating normal attack for %s", activeStack->getDescription());
  221. #endif
  222. updateReachabilityMap(hb);
  223. if(result.bestAttack.attack.shooting
  224. && !result.bestAttack.defenderDead
  225. && !activeStack->waited()
  226. && hb->battleHasShootingPenalty(activeStack, result.bestAttack.dest))
  227. {
  228. if(!canBeHitThisTurn(result.bestAttack))
  229. return result; // lets wait
  230. }
  231. for(auto & ap : targets.possibleAttacks)
  232. {
  233. if (siegeDefense && !hb->battleIsInsideWalls(ap.from))
  234. continue;
  235. float score = evaluateExchange(ap, 0, targets, damageCache, hb);
  236. bool sameScoreButWaited = vstd::isAlmostEqual(score, result.score) && result.wait;
  237. if(score > result.score || sameScoreButWaited)
  238. {
  239. result.score = score;
  240. result.bestAttack = ap;
  241. result.wait = false;
  242. #if BATTLE_TRACE_LEVEL >= 1
  243. logAi->trace("New high score %2f", result.score);
  244. #endif
  245. }
  246. }
  247. return result;
  248. }
  249. ReachabilityInfo getReachabilityWithEnemyBypass(
  250. const battle::Unit * activeStack,
  251. DamageCache & damageCache,
  252. std::shared_ptr<HypotheticBattle> state)
  253. {
  254. ReachabilityInfo::Parameters params(activeStack, activeStack->getPosition());
  255. if(!params.flying)
  256. {
  257. for(const auto * unit : state->battleAliveUnits())
  258. {
  259. if(unit->unitSide() == activeStack->unitSide())
  260. continue;
  261. auto dmg = damageCache.getOriginalDamage(activeStack, unit, state);
  262. auto turnsToKill = unit->getAvailableHealth() / std::max(dmg, (int64_t)1);
  263. vstd::amin(turnsToKill, 100);
  264. for(auto & hex : unit->getHexes())
  265. if(hex.isAvailable()) //towers can have <0 pos; we don't also want to overwrite side columns
  266. params.destructibleEnemyTurns[hex.toInt()] = turnsToKill * unit->getMovementRange();
  267. }
  268. params.bypassEnemyStacks = true;
  269. }
  270. return state->getReachability(params);
  271. }
  272. MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
  273. const battle::Unit * activeStack,
  274. PotentialTargets & targets,
  275. DamageCache & damageCache,
  276. std::shared_ptr<HypotheticBattle> hb)
  277. {
  278. MoveTarget result;
  279. BattleExchangeVariant ev;
  280. logAi->trace("Find move towards unreachable. Enemies count %d", targets.unreachableEnemies.size());
  281. if(targets.unreachableEnemies.empty())
  282. return result;
  283. auto speed = activeStack->getMovementRange();
  284. if(speed == 0)
  285. return result;
  286. updateReachabilityMap(hb);
  287. auto dists = getReachabilityWithEnemyBypass(activeStack, damageCache, hb);
  288. auto flying = activeStack->hasBonusOfType(BonusType::FLYING);
  289. for(const battle::Unit * enemy : targets.unreachableEnemies)
  290. {
  291. logAi->trace(
  292. "Checking movement towards %d of %s",
  293. enemy->getCount(),
  294. LIBRARY->creatures()->getById(enemy->creatureId())->getJsonKey());
  295. auto distance = dists.distToNearestNeighbour(activeStack, enemy);
  296. if(distance >= GameConstants::BFIELD_SIZE)
  297. continue;
  298. if(distance <= speed)
  299. continue;
  300. float penaltyMultiplier = 1.0f; // Default multiplier, no penalty
  301. float closestAllyDistance = std::numeric_limits<float>::max();
  302. for (const battle::Unit* ally : hb->battleAliveUnits())
  303. {
  304. if (ally == activeStack)
  305. continue;
  306. if (ally->unitSide() != activeStack->unitSide())
  307. continue;
  308. float allyDistance = dists.distToNearestNeighbour(ally, enemy);
  309. if (allyDistance < closestAllyDistance)
  310. {
  311. closestAllyDistance = allyDistance;
  312. }
  313. }
  314. // If an ally is closer to the enemy, compute the penaltyMultiplier
  315. if (closestAllyDistance < distance)
  316. {
  317. penaltyMultiplier = closestAllyDistance / distance; // Ratio of distances
  318. }
  319. auto turnsToReach = (distance - 1) / speed + 1;
  320. const BattleHexArray & hexes = enemy->getAttackableHexes(activeStack);
  321. auto enemySpeed = enemy->getMovementRange();
  322. auto speedRatio = speed / static_cast<float>(enemySpeed);
  323. auto multiplier = (speedRatio > 1 ? 1 : speedRatio) * penaltyMultiplier;
  324. for(auto & hex : hexes)
  325. {
  326. if (!dists.isReachable(hex))
  327. continue;
  328. // FIXME: provide distance info for Jousting bonus
  329. auto bai = BattleAttackInfo(activeStack, enemy, 0, cb->battleCanShoot(activeStack));
  330. auto attack = AttackPossibility::evaluate(bai, hex, damageCache, hb);
  331. attack.shootersBlockedDmg = 0; // we do not want to count on it, it is not for sure
  332. auto score = calculateExchange(attack, turnsToReach, targets, damageCache, hb);
  333. score.enemyDamageReduce *= multiplier;
  334. #if BATTLE_TRACE_LEVEL >= 1
  335. logAi->trace("Multiplier: %f, turns: %d, current score %f, new score %f", multiplier, turnsToReach, result.score, scoreValue(score));
  336. #endif
  337. if(result.score < scoreValue(score)
  338. || (result.turnsToReach > turnsToReach && vstd::isAlmostEqual(result.score, scoreValue(score))))
  339. {
  340. result.score = scoreValue(score);
  341. result.positions.clear();
  342. #if BATTLE_TRACE_LEVEL >= 1
  343. logAi->trace("New high score");
  344. #endif
  345. BattleHex enemyHex = hex;
  346. while(!flying && dists.distances[enemyHex.toInt()] > speed && dists.predecessors.at(enemyHex.toInt()).isValid())
  347. {
  348. enemyHex = dists.predecessors.at(enemyHex.toInt());
  349. if(dists.accessibility[enemyHex.toInt()] == EAccessibility::ALIVE_STACK)
  350. {
  351. auto defenderToBypass = hb->battleGetUnitByPos(enemyHex);
  352. assert(defenderToBypass != nullptr);
  353. auto attackHex = dists.predecessors[enemyHex.toInt()];
  354. if(defenderToBypass &&
  355. defenderToBypass != enemy &&
  356. vstd::contains(defenderToBypass->getAttackableHexes(activeStack), attackHex))
  357. {
  358. #if BATTLE_TRACE_LEVEL >= 1
  359. logAi->trace("Found target to bypass at %d", enemyHex.toInt());
  360. #endif
  361. auto baiBypass = BattleAttackInfo(activeStack, defenderToBypass, 0, cb->battleCanShoot(activeStack));
  362. auto attackBypass = AttackPossibility::evaluate(baiBypass, attackHex, damageCache, hb);
  363. auto adjacentStacks = getAdjacentUnits(enemy);
  364. adjacentStacks.push_back(defenderToBypass);
  365. vstd::removeDuplicates(adjacentStacks);
  366. auto bypassScore = calculateExchange(
  367. attackBypass,
  368. dists.distances[attackHex.toInt()],
  369. targets,
  370. damageCache,
  371. hb,
  372. adjacentStacks);
  373. if(scoreValue(bypassScore) > result.score)
  374. {
  375. result.score = scoreValue(bypassScore);
  376. #if BATTLE_TRACE_LEVEL >= 1
  377. logAi->trace("New high score after bypass %f", scoreValue(bypassScore));
  378. #endif
  379. }
  380. }
  381. }
  382. }
  383. result.positions.insert(enemyHex);
  384. result.cachedAttack = attack;
  385. result.turnsToReach = turnsToReach;
  386. }
  387. }
  388. }
  389. return result;
  390. }
  391. battle::Units BattleExchangeEvaluator::getAdjacentUnits(const battle::Unit * blockerUnit) const
  392. {
  393. std::queue<const battle::Unit *> queue;
  394. battle::Units checkedStacks;
  395. queue.push(blockerUnit);
  396. while(!queue.empty())
  397. {
  398. auto stack = queue.front();
  399. queue.pop();
  400. checkedStacks.push_back(stack);
  401. auto const & hexes = stack->getSurroundingHexes();
  402. for(const auto & hex : hexes)
  403. {
  404. auto neighbour = cb->battleGetUnitByPos(hex);
  405. if(neighbour && neighbour->unitSide() == stack->unitSide() && !vstd::contains(checkedStacks, neighbour))
  406. {
  407. queue.push(neighbour);
  408. checkedStacks.push_back(neighbour);
  409. }
  410. }
  411. }
  412. return checkedStacks;
  413. }
  414. ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
  415. const AttackPossibility & ap,
  416. uint8_t turn,
  417. PotentialTargets & targets,
  418. std::shared_ptr<HypotheticBattle> hb,
  419. const battle::Units & additionalUnits) const
  420. {
  421. ReachabilityData result;
  422. auto hexes = ap.attack.defender->getSurroundingHexes();
  423. if(!ap.attack.shooting)
  424. hexes.insert(ap.from);
  425. battle::Units allReachableUnits = additionalUnits;
  426. for(const auto & hex : hexes)
  427. {
  428. vstd::concatenate(allReachableUnits, getOneTurnReachableUnits(turn, hex));
  429. }
  430. if(!ap.attack.attacker->isTurret())
  431. {
  432. for(const auto & hex : ap.attack.attacker->getHexes())
  433. {
  434. auto unitsReachingAttacker = getOneTurnReachableUnits(turn, hex);
  435. for(auto unit : unitsReachingAttacker)
  436. {
  437. if(unit->unitSide() != ap.attack.attacker->unitSide())
  438. {
  439. allReachableUnits.push_back(unit);
  440. result.enemyUnitsReachingAttacker.insert(unit->unitId());
  441. }
  442. }
  443. }
  444. }
  445. vstd::removeDuplicates(allReachableUnits);
  446. auto copy = allReachableUnits;
  447. for(auto unit : copy)
  448. {
  449. for(auto adjacentUnit : getAdjacentUnits(unit))
  450. {
  451. auto unitWithBonuses = hb->battleGetUnitByID(adjacentUnit->unitId());
  452. if(vstd::contains(targets.unreachableEnemies, adjacentUnit)
  453. && !vstd::contains(allReachableUnits, unitWithBonuses))
  454. {
  455. allReachableUnits.push_back(unitWithBonuses);
  456. }
  457. }
  458. }
  459. vstd::removeDuplicates(allReachableUnits);
  460. if(!vstd::contains(allReachableUnits, ap.attack.attacker))
  461. {
  462. allReachableUnits.push_back(ap.attack.attacker);
  463. }
  464. if(allReachableUnits.size() < 2)
  465. {
  466. #if BATTLE_TRACE_LEVEL>=1
  467. logAi->trace("Reachability map contains only %d stacks", allReachableUnits.size());
  468. #endif
  469. return result;
  470. }
  471. for(auto unit : allReachableUnits)
  472. {
  473. auto accessible = !unit->canShoot() || vstd::contains(additionalUnits, unit);
  474. if(!accessible)
  475. {
  476. for(const auto & hex : unit->getSurroundingHexes())
  477. {
  478. if(ap.attack.defender->coversPos(hex))
  479. {
  480. accessible = true;
  481. }
  482. }
  483. }
  484. if(accessible)
  485. result.melleeAccessible.push_back(unit);
  486. else
  487. result.shooters.push_back(unit);
  488. }
  489. for(int turn = 0; turn < turnOrder.size(); turn++)
  490. {
  491. for(auto unit : turnOrder[turn])
  492. {
  493. if(vstd::contains(allReachableUnits, unit))
  494. result.units[turn].push_back(unit);
  495. }
  496. vstd::erase_if(result.units[turn], [&](const battle::Unit * u) -> bool
  497. {
  498. return !hb->battleGetUnitByID(u->unitId())->alive();
  499. });
  500. }
  501. return result;
  502. }
  503. float BattleExchangeEvaluator::evaluateExchange(
  504. const AttackPossibility & ap,
  505. uint8_t turn,
  506. PotentialTargets & targets,
  507. DamageCache & damageCache,
  508. std::shared_ptr<HypotheticBattle> hb) const
  509. {
  510. BattleScore score = calculateExchange(ap, turn, targets, damageCache, hb);
  511. #if BATTLE_TRACE_LEVEL >= 1
  512. logAi->trace(
  513. "calculateExchange score +%2f -%2fx%2f = %2f",
  514. score.enemyDamageReduce,
  515. score.ourDamageReduce,
  516. getNegativeEffectMultiplier(),
  517. scoreValue(score));
  518. #endif
  519. return scoreValue(score);
  520. }
  521. BattleScore BattleExchangeEvaluator::calculateExchange(
  522. const AttackPossibility & ap,
  523. uint8_t turn,
  524. PotentialTargets & targets,
  525. DamageCache & damageCache,
  526. std::shared_ptr<HypotheticBattle> hb,
  527. const battle::Units & additionalUnits) const
  528. {
  529. #if BATTLE_TRACE_LEVEL>=1
  530. logAi->trace("Battle exchange at %d", ap.attack.shooting ? ap.dest.toInt() : ap.from.toInt());
  531. #endif
  532. if(cb->battleGetMySide() == BattleSide::LEFT_SIDE
  533. && cb->battleGetGateState() == EGateState::BLOCKED
  534. && ap.attack.defender->coversPos(BattleHex::GATE_BRIDGE))
  535. {
  536. return BattleScore(EvaluationResult::INEFFECTIVE_SCORE, 0);
  537. }
  538. battle::Units ourStacks;
  539. battle::Units enemyStacks;
  540. if(hb->battleGetUnitByID(ap.attack.defender->unitId())->alive())
  541. enemyStacks.push_back(ap.attack.defender);
  542. ReachabilityData exchangeUnits = getExchangeUnits(ap, turn, targets, hb, additionalUnits);
  543. if(exchangeUnits.units.empty())
  544. {
  545. return BattleScore();
  546. }
  547. auto exchangeBattle = std::make_shared<HypotheticBattle>(env.get(), hb);
  548. BattleExchangeVariant v;
  549. for(int exchangeTurn = 0; exchangeTurn < exchangeUnits.units.size(); exchangeTurn++)
  550. {
  551. for(auto unit : exchangeUnits.units.at(exchangeTurn))
  552. {
  553. if(unit->isTurret())
  554. continue;
  555. bool isOur = exchangeBattle->battleMatchOwner(ap.attack.attacker, unit, true);
  556. auto & attackerQueue = isOur ? ourStacks : enemyStacks;
  557. auto u = exchangeBattle->getForUpdate(unit->unitId());
  558. if(u->alive() && !vstd::contains(attackerQueue, unit))
  559. {
  560. attackerQueue.push_back(unit);
  561. #if BATTLE_TRACE_LEVEL
  562. logAi->trace("Exchanging: %s", u->getDescription());
  563. #endif
  564. }
  565. }
  566. }
  567. auto melleeAttackers = ourStacks;
  568. vstd::removeDuplicates(melleeAttackers);
  569. vstd::erase_if(melleeAttackers, [&](const battle::Unit * u) -> bool
  570. {
  571. return cb->battleCanShoot(u);
  572. });
  573. bool canUseAp = true;
  574. std::set<uint32_t> blockedShooters;
  575. int totalTurnsCount = simulationTurnsCount >= turn + turnOrder.size()
  576. ? simulationTurnsCount
  577. : turn + turnOrder.size();
  578. for(int exchangeTurn = 0; exchangeTurn < simulationTurnsCount; exchangeTurn++)
  579. {
  580. bool isMovingTurm = exchangeTurn < turn;
  581. int queueTurn = exchangeTurn >= exchangeUnits.units.size()
  582. ? exchangeUnits.units.size() - 1
  583. : exchangeTurn;
  584. for(auto activeUnit : exchangeUnits.units.at(queueTurn))
  585. {
  586. bool isOur = exchangeBattle->battleMatchOwner(ap.attack.attacker, activeUnit, true);
  587. battle::Units & attackerQueue = isOur ? ourStacks : enemyStacks;
  588. battle::Units & oppositeQueue = isOur ? enemyStacks : ourStacks;
  589. auto attacker = exchangeBattle->getForUpdate(activeUnit->unitId());
  590. auto shooting = exchangeBattle->battleCanShoot(attacker.get())
  591. && !vstd::contains(blockedShooters, attacker->unitId());
  592. if(!attacker->alive())
  593. {
  594. #if BATTLE_TRACE_LEVEL>=1
  595. logAi->trace("Attacker is dead");
  596. #endif
  597. continue;
  598. }
  599. if(isMovingTurm && !shooting
  600. && !vstd::contains(exchangeUnits.enemyUnitsReachingAttacker, attacker->unitId()))
  601. {
  602. #if BATTLE_TRACE_LEVEL>=1
  603. logAi->trace("Attacker is moving");
  604. #endif
  605. continue;
  606. }
  607. auto targetUnit = ap.attack.defender;
  608. if(!isOur || !exchangeBattle->battleGetUnitByID(targetUnit->unitId())->alive())
  609. {
  610. #if BATTLE_TRACE_LEVEL>=2
  611. logAi->trace("Best target selector for %s", attacker->getDescription());
  612. #endif
  613. auto estimateAttack = [&](const battle::Unit * u) -> float
  614. {
  615. auto stackWithBonuses = exchangeBattle->getForUpdate(u->unitId());
  616. auto score = v.trackAttack(
  617. attacker,
  618. stackWithBonuses,
  619. exchangeBattle->battleCanShoot(stackWithBonuses.get()),
  620. isOur,
  621. damageCache,
  622. hb,
  623. true);
  624. #if BATTLE_TRACE_LEVEL>=2
  625. logAi->trace("Best target selector %s->%s score = %2f", attacker->getDescription(), stackWithBonuses->getDescription(), score);
  626. #endif
  627. return score;
  628. };
  629. auto unitsInOppositeQueueExceptInaccessible = oppositeQueue;
  630. vstd::erase_if(unitsInOppositeQueueExceptInaccessible, [&](const battle::Unit * u)->bool
  631. {
  632. return vstd::contains(exchangeUnits.shooters, u);
  633. });
  634. if(!isOur
  635. && exchangeTurn == 0
  636. && exchangeUnits.units.at(exchangeTurn).at(0)->unitId() != ap.attack.attacker->unitId()
  637. && !vstd::contains(exchangeUnits.enemyUnitsReachingAttacker, attacker->unitId()))
  638. {
  639. vstd::erase_if(unitsInOppositeQueueExceptInaccessible, [&](const battle::Unit * u) -> bool
  640. {
  641. return u->unitId() == ap.attack.attacker->unitId();
  642. });
  643. }
  644. if(!unitsInOppositeQueueExceptInaccessible.empty())
  645. {
  646. targetUnit = *vstd::maxElementByFun(unitsInOppositeQueueExceptInaccessible, estimateAttack);
  647. }
  648. else
  649. {
  650. auto reachable = exchangeBattle->battleGetUnitsIf([this, &exchangeBattle, &attacker](const battle::Unit * u) -> bool
  651. {
  652. if(u->unitSide() == attacker->unitSide())
  653. return false;
  654. if(!exchangeBattle->getForUpdate(u->unitId())->alive())
  655. return false;
  656. if(!u->getPosition().isValid())
  657. return false; // e.g. tower shooters
  658. const auto & reachableUnits = getOneTurnReachableUnits(0, u->getPosition());
  659. return vstd::contains_if(reachableUnits, [&attacker](const battle::Unit * other) -> bool
  660. {
  661. return attacker->unitId() == other->unitId();
  662. });
  663. });
  664. if(!reachable.empty())
  665. {
  666. targetUnit = *vstd::maxElementByFun(reachable, estimateAttack);
  667. }
  668. else
  669. {
  670. #if BATTLE_TRACE_LEVEL>=1
  671. logAi->trace("Battle queue is empty and no reachable enemy.");
  672. #endif
  673. continue;
  674. }
  675. }
  676. }
  677. auto defender = exchangeBattle->getForUpdate(targetUnit->unitId());
  678. const int totalAttacks = attacker->getTotalAttacks(shooting);
  679. if(canUseAp && activeUnit->unitId() == ap.attack.attacker->unitId()
  680. && targetUnit->unitId() == ap.attack.defender->unitId())
  681. {
  682. v.trackAttack(ap, exchangeBattle, damageCache);
  683. }
  684. else
  685. {
  686. for(int i = 0; i < totalAttacks; i++)
  687. {
  688. v.trackAttack(attacker, defender, shooting, isOur, damageCache, exchangeBattle);
  689. if(!attacker->alive() || !defender->alive())
  690. break;
  691. }
  692. }
  693. if(!shooting)
  694. blockedShooters.insert(defender->unitId());
  695. canUseAp = false;
  696. vstd::erase_if(attackerQueue, [&](const battle::Unit * u) -> bool
  697. {
  698. return !exchangeBattle->battleGetUnitByID(u->unitId())->alive();
  699. });
  700. vstd::erase_if(oppositeQueue, [&](const battle::Unit * u) -> bool
  701. {
  702. return !exchangeBattle->battleGetUnitByID(u->unitId())->alive();
  703. });
  704. }
  705. exchangeBattle->nextRound();
  706. }
  707. auto score = v.getScore();
  708. if(simulationTurnsCount < totalTurnsCount)
  709. {
  710. float scalingRatio = simulationTurnsCount / static_cast<float>(totalTurnsCount);
  711. score.enemyDamageReduce *= scalingRatio;
  712. score.ourDamageReduce *= scalingRatio;
  713. }
  714. if(turn > 0)
  715. {
  716. auto turnMultiplier = 1 - std::min(0.2, 0.05 * turn);
  717. score.enemyDamageReduce *= turnMultiplier;
  718. }
  719. #if BATTLE_TRACE_LEVEL>=1
  720. logAi->trace("Exchange score: enemy: %2f, our -%2f", score.enemyDamageReduce, score.ourDamageReduce);
  721. #endif
  722. return score;
  723. }
  724. bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap)
  725. {
  726. for(auto pos : ap.attack.attacker->getSurroundingHexes())
  727. {
  728. for(auto u : getOneTurnReachableUnits(0, pos))
  729. {
  730. if(u->unitSide() != ap.attack.attacker->unitSide())
  731. {
  732. return true;
  733. }
  734. }
  735. }
  736. return false;
  737. }
  738. void ReachabilityMapCache::update(const std::vector<battle::Units> & turnOrder, std::shared_ptr<HypotheticBattle> hb)
  739. {
  740. for(auto turn : turnOrder)
  741. {
  742. for(auto u : turn)
  743. {
  744. if(!vstd::contains(unitReachabilityMap, u->unitId()))
  745. {
  746. unitReachabilityMap[u->unitId()] = hb->getReachability(u);
  747. }
  748. }
  749. }
  750. hexReachabilityPerTurn.clear();
  751. }
  752. void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb)
  753. {
  754. const int TURN_DEPTH = 2;
  755. turnOrder.clear();
  756. hb->battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH);
  757. reachabilityMap.update(turnOrder, hb);
  758. }
  759. const battle::Units & ReachabilityMapCache::getOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, const BattleHex & hex)
  760. {
  761. auto & turnData = hexReachabilityPerTurn[turn];
  762. if (!turnData.isValid[hex.toInt()])
  763. {
  764. turnData.hexes[hex.toInt()] = computeOneTurnReachableUnits(cb, env, turnOrder, turn, hex);
  765. turnData.isValid.set(hex.toInt());
  766. }
  767. return turnData.hexes[hex.toInt()];
  768. }
  769. battle::Units ReachabilityMapCache::computeOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, const BattleHex & hex)
  770. {
  771. battle::Units result;
  772. for(int i = 0; i < turnOrder.size(); i++, turn++)
  773. {
  774. auto & turnQueue = turnOrder[i];
  775. HypotheticBattle turnBattle(env.get(), cb);
  776. for(const battle::Unit * unit : turnQueue)
  777. {
  778. if(unit->isTurret())
  779. continue;
  780. if(turnBattle.battleCanShoot(unit))
  781. {
  782. result.push_back(unit);
  783. continue;
  784. }
  785. auto unitSpeed = unit->getMovementRange(turn);
  786. auto radius = unitSpeed * (turn + 1);
  787. auto reachabilityIter = unitReachabilityMap.find(unit->unitId());
  788. assert(reachabilityIter != unitReachabilityMap.end()); // missing updateReachabilityMap call?
  789. ReachabilityInfo unitReachability = reachabilityIter != unitReachabilityMap.end() ? reachabilityIter->second : turnBattle.getReachability(unit);
  790. bool reachable = unitReachability.distances.at(hex.toInt()) <= radius;
  791. if(!reachable && unitReachability.accessibility[hex.toInt()] == EAccessibility::ALIVE_STACK)
  792. {
  793. const battle::Unit * hexStack = cb->battleGetUnitByPos(hex);
  794. if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
  795. {
  796. for(const BattleHex & neighbour : hex.getNeighbouringTiles())
  797. {
  798. reachable = unitReachability.distances.at(neighbour.toInt()) <= radius;
  799. if(reachable) break;
  800. }
  801. }
  802. }
  803. if(reachable)
  804. {
  805. result.push_back(unit);
  806. }
  807. }
  808. }
  809. return result;
  810. }
  811. const battle::Units & BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, const BattleHex & hex) const
  812. {
  813. return reachabilityMap.getOneTurnReachableUnits(cb, env, turnOrder, turn, hex);
  814. }
  815. // avoid blocking path for stronger stack by weaker stack
  816. bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(const HypotheticBattle & hb, const battle::Unit * activeUnit, const BattleHex & position)
  817. {
  818. const int BLOCKING_THRESHOLD = 70;
  819. const int BLOCKING_OWN_ATTACK_PENALTY = 100;
  820. const int BLOCKING_OWN_MOVE_PENALTY = 1;
  821. float blockingScore = 0;
  822. auto activeUnitDamage = activeUnit->getMinDamage(hb.battleCanShoot(activeUnit)) * activeUnit->getCount();
  823. for(int turn = 0; turn < turnOrder.size(); turn++)
  824. {
  825. auto & turnQueue = turnOrder[turn];
  826. HypotheticBattle turnBattle(env.get(), cb);
  827. auto unitToUpdate = turnBattle.getForUpdate(activeUnit->unitId());
  828. unitToUpdate->setPosition(position);
  829. for(const battle::Unit * unit : turnQueue)
  830. {
  831. if(unit->unitId() == unitToUpdate->unitId() || cb->battleMatchOwner(unit, activeUnit, false))
  832. continue;
  833. auto blockedUnitDamage = unit->getMinDamage(hb.battleCanShoot(unit)) * unit->getCount();
  834. float ratio = blockedUnitDamage / (float)(blockedUnitDamage + activeUnitDamage + 0.01);
  835. auto unitReachability = turnBattle.getReachability(unit);
  836. auto unitSpeed = unit->getMovementRange(turn); // Cached value, to avoid performance hit
  837. for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); ++hex)
  838. {
  839. bool enemyUnit = false;
  840. bool reachable = unitReachability.distances.at(hex.toInt()) <= unitSpeed;
  841. if(!reachable && unitReachability.accessibility[hex.toInt()] == EAccessibility::ALIVE_STACK)
  842. {
  843. const battle::Unit * hexStack = turnBattle.battleGetUnitByPos(hex);
  844. if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
  845. {
  846. enemyUnit = true;
  847. for(const BattleHex & neighbour : hex.getNeighbouringTiles())
  848. {
  849. reachable = unitReachability.distances.at(neighbour.toInt()) <= unitSpeed;
  850. if(reachable) break;
  851. }
  852. }
  853. }
  854. if(!reachable)
  855. {
  856. auto reachableUnits = getOneTurnReachableUnits(0, hex);
  857. if (std::count(reachableUnits.begin(), reachableUnits.end(), unit) > 1)
  858. blockingScore += ratio * (enemyUnit ? BLOCKING_OWN_ATTACK_PENALTY : BLOCKING_OWN_MOVE_PENALTY);
  859. }
  860. }
  861. }
  862. }
  863. #if BATTLE_TRACE_LEVEL>=1
  864. logAi->trace("Position %d, blocking score %f", position.toInt(), blockingScore);
  865. #endif
  866. return blockingScore > BLOCKING_THRESHOLD;
  867. }