BattleExchangeVariant.cpp 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070
  1. /*
  2. * BattleExchangeVariant.cpp, part of VCMI engine
  3. *
  4. * Authors: listed in file AUTHORS in main folder
  5. *
  6. * License: GNU General Public License v2.0 or later
  7. * Full text of license available in license.txt file, in main folder
  8. *
  9. */
  10. #include "StdInc.h"
  11. #include "BattleExchangeVariant.h"
  12. #include "BattleEvaluator.h"
  13. #include "../../lib/CStack.h"
  14. #include "tbb/parallel_for.h"
  15. AttackerValue::AttackerValue()
  16. : value(0),
  17. isRetaliated(false)
  18. {
  19. }
  20. MoveTarget::MoveTarget()
  21. : positions(), cachedAttack(), score(EvaluationResult::INEFFECTIVE_SCORE)
  22. {
  23. turnsToReach = 1;
  24. }
  25. float BattleExchangeVariant::trackAttack(
  26. const AttackPossibility & ap,
  27. std::shared_ptr<HypotheticBattle> hb,
  28. DamageCache & damageCache)
  29. {
  30. if(!ap.attackerState)
  31. {
  32. logAi->trace("Skipping fake ap attack");
  33. return 0;
  34. }
  35. auto attacker = hb->getForUpdate(ap.attack.attacker->unitId());
  36. float attackValue = ap.attackValue();
  37. auto affectedUnits = ap.affectedUnits;
  38. dpsScore.ourDamageReduce += ap.attackerDamageReduce + ap.collateralDamageReduce;
  39. dpsScore.enemyDamageReduce += ap.defenderDamageReduce + ap.shootersBlockedDmg;
  40. attackerValue[attacker->unitId()].value = attackValue;
  41. affectedUnits.push_back(ap.attackerState);
  42. for(auto affectedUnit : affectedUnits)
  43. {
  44. auto unitToUpdate = hb->getForUpdate(affectedUnit->unitId());
  45. auto damageDealt = unitToUpdate->getAvailableHealth() - affectedUnit->getAvailableHealth();
  46. if(damageDealt > 0)
  47. {
  48. unitToUpdate->damage(damageDealt);
  49. }
  50. if(unitToUpdate->unitSide() == attacker->unitSide())
  51. {
  52. if(unitToUpdate->unitId() == attacker->unitId())
  53. {
  54. unitToUpdate->afterAttack(ap.attack.shooting, false);
  55. #if BATTLE_TRACE_LEVEL>=1
  56. logAi->trace(
  57. "%s -> %s, ap retaliation, %s, dps: %lld",
  58. hb->getForUpdate(ap.attack.defender->unitId())->getDescription(),
  59. ap.attack.attacker->getDescription(),
  60. ap.attack.shooting ? "shot" : "mellee",
  61. damageDealt);
  62. #endif
  63. }
  64. else
  65. {
  66. #if BATTLE_TRACE_LEVEL>=1
  67. logAi->trace(
  68. "%s, ap collateral, dps: %lld",
  69. unitToUpdate->getDescription(),
  70. damageDealt);
  71. #endif
  72. }
  73. }
  74. else
  75. {
  76. if(unitToUpdate->unitId() == ap.attack.defender->unitId())
  77. {
  78. if(unitToUpdate->ableToRetaliate() && !affectedUnit->ableToRetaliate())
  79. {
  80. unitToUpdate->afterAttack(ap.attack.shooting, true);
  81. }
  82. #if BATTLE_TRACE_LEVEL>=1
  83. logAi->trace(
  84. "%s -> %s, ap attack, %s, dps: %lld",
  85. attacker->getDescription(),
  86. ap.attack.defender->getDescription(),
  87. ap.attack.shooting ? "shot" : "mellee",
  88. damageDealt);
  89. #endif
  90. }
  91. else
  92. {
  93. #if BATTLE_TRACE_LEVEL>=1
  94. logAi->trace(
  95. "%s, ap enemy collateral, dps: %lld",
  96. unitToUpdate->getDescription(),
  97. damageDealt);
  98. #endif
  99. }
  100. }
  101. }
  102. #if BATTLE_TRACE_LEVEL >= 1
  103. logAi->trace(
  104. "ap score: our: %2f, enemy: %2f, collateral: %2f, blocked: %2f",
  105. ap.attackerDamageReduce,
  106. ap.defenderDamageReduce,
  107. ap.collateralDamageReduce,
  108. ap.shootersBlockedDmg);
  109. #endif
  110. return attackValue;
  111. }
  112. float BattleExchangeVariant::trackAttack(
  113. std::shared_ptr<StackWithBonuses> attacker,
  114. std::shared_ptr<StackWithBonuses> defender,
  115. bool shooting,
  116. bool isOurAttack,
  117. DamageCache & damageCache,
  118. std::shared_ptr<HypotheticBattle> hb,
  119. bool evaluateOnly)
  120. {
  121. const std::string cachingStringBlocksRetaliation = "type_BLOCKS_RETALIATION";
  122. static const auto selectorBlocksRetaliation = Selector::type()(BonusType::BLOCKS_RETALIATION);
  123. const bool counterAttacksBlocked = attacker->hasBonus(selectorBlocksRetaliation, cachingStringBlocksRetaliation);
  124. int64_t attackDamage = damageCache.getDamage(attacker.get(), defender.get(), hb);
  125. float defenderDamageReduce = AttackPossibility::calculateDamageReduce(attacker.get(), defender.get(), attackDamage, damageCache, hb);
  126. float attackerDamageReduce = 0;
  127. if(!evaluateOnly)
  128. {
  129. #if BATTLE_TRACE_LEVEL>=1
  130. logAi->trace(
  131. "%s -> %s, normal attack, %s, dps: %lld, %2f",
  132. attacker->getDescription(),
  133. defender->getDescription(),
  134. shooting ? "shot" : "mellee",
  135. attackDamage,
  136. defenderDamageReduce);
  137. #endif
  138. if(isOurAttack)
  139. {
  140. dpsScore.enemyDamageReduce += defenderDamageReduce;
  141. attackerValue[attacker->unitId()].value += defenderDamageReduce;
  142. }
  143. else
  144. dpsScore.ourDamageReduce += defenderDamageReduce;
  145. defender->damage(attackDamage);
  146. attacker->afterAttack(shooting, false);
  147. }
  148. if(!evaluateOnly && defender->alive() && defender->ableToRetaliate() && !counterAttacksBlocked && !shooting)
  149. {
  150. auto retaliationDamage = damageCache.getDamage(defender.get(), attacker.get(), hb);
  151. attackerDamageReduce = AttackPossibility::calculateDamageReduce(defender.get(), attacker.get(), retaliationDamage, damageCache, hb);
  152. #if BATTLE_TRACE_LEVEL>=1
  153. logAi->trace(
  154. "%s -> %s, retaliation, dps: %lld, %2f",
  155. defender->getDescription(),
  156. attacker->getDescription(),
  157. retaliationDamage,
  158. attackerDamageReduce);
  159. #endif
  160. if(isOurAttack)
  161. {
  162. dpsScore.ourDamageReduce += attackerDamageReduce;
  163. attackerValue[attacker->unitId()].isRetaliated = true;
  164. }
  165. else
  166. {
  167. dpsScore.enemyDamageReduce += attackerDamageReduce;
  168. attackerValue[defender->unitId()].value += attackerDamageReduce;
  169. }
  170. attacker->damage(retaliationDamage);
  171. defender->afterAttack(false, true);
  172. }
  173. auto score = defenderDamageReduce - attackerDamageReduce;
  174. #if BATTLE_TRACE_LEVEL>=1
  175. if(!score)
  176. {
  177. logAi->trace("Attack has zero score def:%2f att:%2f", defenderDamageReduce, attackerDamageReduce);
  178. }
  179. #endif
  180. return score;
  181. }
  182. float BattleExchangeEvaluator::scoreValue(const BattleScore & score) const
  183. {
  184. return score.enemyDamageReduce * getPositiveEffectMultiplier() - score.ourDamageReduce * getNegativeEffectMultiplier();
  185. }
  186. EvaluationResult BattleExchangeEvaluator::findBestTarget(
  187. const battle::Unit * activeStack,
  188. PotentialTargets & targets,
  189. DamageCache & damageCache,
  190. std::shared_ptr<HypotheticBattle> hb,
  191. bool siegeDefense)
  192. {
  193. EvaluationResult result(targets.bestAction());
  194. if(!activeStack->waited() && !activeStack->acquireState()->hadMorale)
  195. {
  196. #if BATTLE_TRACE_LEVEL>=1
  197. logAi->trace("Evaluating waited attack for %s", activeStack->getDescription());
  198. #endif
  199. auto hbWaited = std::make_shared<HypotheticBattle>(env.get(), hb);
  200. hbWaited->makeWait(activeStack);
  201. updateReachabilityMap(hbWaited);
  202. for(auto & ap : targets.possibleAttacks)
  203. {
  204. if (siegeDefense && !hb->battleIsInsideWalls(ap.from))
  205. continue;
  206. float score = evaluateExchange(ap, 0, targets, damageCache, hbWaited);
  207. if(score > result.score)
  208. {
  209. result.score = score;
  210. result.bestAttack = ap;
  211. result.wait = true;
  212. #if BATTLE_TRACE_LEVEL >= 1
  213. logAi->trace("New high score %2f", result.score);
  214. #endif
  215. }
  216. }
  217. }
  218. #if BATTLE_TRACE_LEVEL>=1
  219. logAi->trace("Evaluating normal attack for %s", activeStack->getDescription());
  220. #endif
  221. updateReachabilityMap(hb);
  222. if(result.bestAttack.attack.shooting
  223. && !result.bestAttack.defenderDead
  224. && !activeStack->waited()
  225. && hb->battleHasShootingPenalty(activeStack, result.bestAttack.dest))
  226. {
  227. if(!canBeHitThisTurn(result.bestAttack))
  228. return result; // lets wait
  229. }
  230. for(auto & ap : targets.possibleAttacks)
  231. {
  232. if (siegeDefense && !hb->battleIsInsideWalls(ap.from))
  233. continue;
  234. float score = evaluateExchange(ap, 0, targets, damageCache, hb);
  235. bool sameScoreButWaited = vstd::isAlmostEqual(score, result.score) && result.wait;
  236. if(score > result.score || sameScoreButWaited)
  237. {
  238. result.score = score;
  239. result.bestAttack = ap;
  240. result.wait = false;
  241. #if BATTLE_TRACE_LEVEL >= 1
  242. logAi->trace("New high score %2f", result.score);
  243. #endif
  244. }
  245. }
  246. return result;
  247. }
  248. ReachabilityInfo getReachabilityWithEnemyBypass(
  249. const battle::Unit * activeStack,
  250. DamageCache & damageCache,
  251. std::shared_ptr<HypotheticBattle> state)
  252. {
  253. ReachabilityInfo::Parameters params(activeStack, activeStack->getPosition());
  254. if(!params.flying)
  255. {
  256. for(const auto * unit : state->battleAliveUnits())
  257. {
  258. if(unit->unitSide() == activeStack->unitSide())
  259. continue;
  260. auto dmg = damageCache.getOriginalDamage(activeStack, unit, state);
  261. auto turnsToKill = unit->getAvailableHealth() / std::max(dmg, (int64_t)1);
  262. vstd::amin(turnsToKill, 100);
  263. for(auto & hex : unit->getHexes())
  264. if(hex.isAvailable()) //towers can have <0 pos; we don't also want to overwrite side columns
  265. params.destructibleEnemyTurns[hex.toInt()] = turnsToKill * unit->getMovementRange();
  266. }
  267. params.bypassEnemyStacks = true;
  268. }
  269. return state->getReachability(params);
  270. }
  271. MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
  272. const battle::Unit * activeStack,
  273. PotentialTargets & targets,
  274. DamageCache & damageCache,
  275. std::shared_ptr<HypotheticBattle> hb)
  276. {
  277. MoveTarget result;
  278. BattleExchangeVariant ev;
  279. logAi->trace("Find move towards unreachable. Enemies count %d", targets.unreachableEnemies.size());
  280. if(targets.unreachableEnemies.empty())
  281. return result;
  282. auto speed = activeStack->getMovementRange();
  283. if(speed == 0)
  284. return result;
  285. updateReachabilityMap(hb);
  286. auto dists = getReachabilityWithEnemyBypass(activeStack, damageCache, hb);
  287. auto flying = activeStack->hasBonusOfType(BonusType::FLYING);
  288. for(const battle::Unit * enemy : targets.unreachableEnemies)
  289. {
  290. logAi->trace(
  291. "Checking movement towards %d of %s",
  292. enemy->getCount(),
  293. LIBRARY->creatures()->getById(enemy->creatureId())->getJsonKey());
  294. auto distance = dists.distToNearestNeighbour(activeStack, enemy);
  295. if(distance >= GameConstants::BFIELD_SIZE)
  296. continue;
  297. if(distance <= speed)
  298. continue;
  299. float penaltyMultiplier = 1.0f; // Default multiplier, no penalty
  300. float closestAllyDistance = std::numeric_limits<float>::max();
  301. for (const battle::Unit* ally : hb->battleAliveUnits())
  302. {
  303. if (ally == activeStack)
  304. continue;
  305. if (ally->unitSide() != activeStack->unitSide())
  306. continue;
  307. float allyDistance = dists.distToNearestNeighbour(ally, enemy);
  308. if (allyDistance < closestAllyDistance)
  309. {
  310. closestAllyDistance = allyDistance;
  311. }
  312. }
  313. // If an ally is closer to the enemy, compute the penaltyMultiplier
  314. if (closestAllyDistance < distance)
  315. {
  316. penaltyMultiplier = closestAllyDistance / distance; // Ratio of distances
  317. }
  318. auto turnsToReach = (distance - 1) / speed + 1;
  319. const BattleHexArray & hexes = enemy->getSurroundingHexes();
  320. auto enemySpeed = enemy->getMovementRange();
  321. auto speedRatio = speed / static_cast<float>(enemySpeed);
  322. auto multiplier = (speedRatio > 1 ? 1 : speedRatio) * penaltyMultiplier;
  323. for(auto & hex : hexes)
  324. {
  325. // FIXME: provide distance info for Jousting bonus
  326. auto bai = BattleAttackInfo(activeStack, enemy, 0, cb->battleCanShoot(activeStack));
  327. auto attack = AttackPossibility::evaluate(bai, hex, damageCache, hb);
  328. attack.shootersBlockedDmg = 0; // we do not want to count on it, it is not for sure
  329. auto score = calculateExchange(attack, turnsToReach, targets, damageCache, hb);
  330. score.enemyDamageReduce *= multiplier;
  331. #if BATTLE_TRACE_LEVEL >= 1
  332. logAi->trace("Multiplier: %f, turns: %d, current score %f, new score %f", multiplier, turnsToReach, result.score, scoreValue(score));
  333. #endif
  334. if(result.score < scoreValue(score)
  335. || (result.turnsToReach > turnsToReach && vstd::isAlmostEqual(result.score, scoreValue(score))))
  336. {
  337. result.score = scoreValue(score);
  338. result.positions.clear();
  339. #if BATTLE_TRACE_LEVEL >= 1
  340. logAi->trace("New high score");
  341. #endif
  342. for(BattleHex enemyHex : enemy->getAttackableHexes(activeStack))
  343. {
  344. while(!flying && dists.distances[enemyHex.toInt()] > speed && dists.predecessors.at(enemyHex.toInt()).isValid())
  345. {
  346. enemyHex = dists.predecessors.at(enemyHex.toInt());
  347. if(dists.accessibility[enemyHex.toInt()] == EAccessibility::ALIVE_STACK)
  348. {
  349. auto defenderToBypass = hb->battleGetUnitByPos(enemyHex);
  350. if(defenderToBypass)
  351. {
  352. #if BATTLE_TRACE_LEVEL >= 1
  353. logAi->trace("Found target to bypass at %d", enemyHex.toInt());
  354. #endif
  355. auto attackHex = dists.predecessors[enemyHex.toInt()];
  356. auto baiBypass = BattleAttackInfo(activeStack, defenderToBypass, 0, cb->battleCanShoot(activeStack));
  357. auto attackBypass = AttackPossibility::evaluate(baiBypass, attackHex, damageCache, hb);
  358. auto adjacentStacks = getAdjacentUnits(enemy);
  359. adjacentStacks.push_back(defenderToBypass);
  360. vstd::removeDuplicates(adjacentStacks);
  361. auto bypassScore = calculateExchange(
  362. attackBypass,
  363. dists.distances[attackHex.toInt()],
  364. targets,
  365. damageCache,
  366. hb,
  367. adjacentStacks);
  368. if(scoreValue(bypassScore) > result.score)
  369. {
  370. result.score = scoreValue(bypassScore);
  371. #if BATTLE_TRACE_LEVEL >= 1
  372. logAi->trace("New high score after bypass %f", scoreValue(bypassScore));
  373. #endif
  374. }
  375. }
  376. }
  377. }
  378. result.positions.insert(enemyHex);
  379. }
  380. result.cachedAttack = attack;
  381. result.turnsToReach = turnsToReach;
  382. }
  383. }
  384. }
  385. return result;
  386. }
  387. battle::Units BattleExchangeEvaluator::getAdjacentUnits(const battle::Unit * blockerUnit) const
  388. {
  389. std::queue<const battle::Unit *> queue;
  390. battle::Units checkedStacks;
  391. queue.push(blockerUnit);
  392. while(!queue.empty())
  393. {
  394. auto stack = queue.front();
  395. queue.pop();
  396. checkedStacks.push_back(stack);
  397. auto const & hexes = stack->getSurroundingHexes();
  398. for(const auto & hex : hexes)
  399. {
  400. auto neighbour = cb->battleGetUnitByPos(hex);
  401. if(neighbour && neighbour->unitSide() == stack->unitSide() && !vstd::contains(checkedStacks, neighbour))
  402. {
  403. queue.push(neighbour);
  404. checkedStacks.push_back(neighbour);
  405. }
  406. }
  407. }
  408. return checkedStacks;
  409. }
  410. ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
  411. const AttackPossibility & ap,
  412. uint8_t turn,
  413. PotentialTargets & targets,
  414. std::shared_ptr<HypotheticBattle> hb,
  415. const battle::Units & additionalUnits) const
  416. {
  417. ReachabilityData result;
  418. auto hexes = ap.attack.defender->getSurroundingHexes();
  419. if(!ap.attack.shooting)
  420. hexes.insert(ap.from);
  421. battle::Units allReachableUnits = additionalUnits;
  422. for(const auto & hex : hexes)
  423. {
  424. vstd::concatenate(allReachableUnits, getOneTurnReachableUnits(turn, hex));
  425. }
  426. if(!ap.attack.attacker->isTurret())
  427. {
  428. for(const auto & hex : ap.attack.attacker->getHexes())
  429. {
  430. auto unitsReachingAttacker = getOneTurnReachableUnits(turn, hex);
  431. for(auto unit : unitsReachingAttacker)
  432. {
  433. if(unit->unitSide() != ap.attack.attacker->unitSide())
  434. {
  435. allReachableUnits.push_back(unit);
  436. result.enemyUnitsReachingAttacker.insert(unit->unitId());
  437. }
  438. }
  439. }
  440. }
  441. vstd::removeDuplicates(allReachableUnits);
  442. auto copy = allReachableUnits;
  443. for(auto unit : copy)
  444. {
  445. for(auto adjacentUnit : getAdjacentUnits(unit))
  446. {
  447. auto unitWithBonuses = hb->battleGetUnitByID(adjacentUnit->unitId());
  448. if(vstd::contains(targets.unreachableEnemies, adjacentUnit)
  449. && !vstd::contains(allReachableUnits, unitWithBonuses))
  450. {
  451. allReachableUnits.push_back(unitWithBonuses);
  452. }
  453. }
  454. }
  455. vstd::removeDuplicates(allReachableUnits);
  456. if(!vstd::contains(allReachableUnits, ap.attack.attacker))
  457. {
  458. allReachableUnits.push_back(ap.attack.attacker);
  459. }
  460. if(allReachableUnits.size() < 2)
  461. {
  462. #if BATTLE_TRACE_LEVEL>=1
  463. logAi->trace("Reachability map contains only %d stacks", allReachableUnits.size());
  464. #endif
  465. return result;
  466. }
  467. for(auto unit : allReachableUnits)
  468. {
  469. auto accessible = !unit->canShoot() || vstd::contains(additionalUnits, unit);
  470. if(!accessible)
  471. {
  472. for(const auto & hex : unit->getSurroundingHexes())
  473. {
  474. if(ap.attack.defender->coversPos(hex))
  475. {
  476. accessible = true;
  477. }
  478. }
  479. }
  480. if(accessible)
  481. result.melleeAccessible.push_back(unit);
  482. else
  483. result.shooters.push_back(unit);
  484. }
  485. for(int turn = 0; turn < turnOrder.size(); turn++)
  486. {
  487. for(auto unit : turnOrder[turn])
  488. {
  489. if(vstd::contains(allReachableUnits, unit))
  490. result.units[turn].push_back(unit);
  491. }
  492. vstd::erase_if(result.units[turn], [&](const battle::Unit * u) -> bool
  493. {
  494. return !hb->battleGetUnitByID(u->unitId())->alive();
  495. });
  496. }
  497. return result;
  498. }
  499. float BattleExchangeEvaluator::evaluateExchange(
  500. const AttackPossibility & ap,
  501. uint8_t turn,
  502. PotentialTargets & targets,
  503. DamageCache & damageCache,
  504. std::shared_ptr<HypotheticBattle> hb) const
  505. {
  506. BattleScore score = calculateExchange(ap, turn, targets, damageCache, hb);
  507. #if BATTLE_TRACE_LEVEL >= 1
  508. logAi->trace(
  509. "calculateExchange score +%2f -%2fx%2f = %2f",
  510. score.enemyDamageReduce,
  511. score.ourDamageReduce,
  512. getNegativeEffectMultiplier(),
  513. scoreValue(score));
  514. #endif
  515. return scoreValue(score);
  516. }
  517. BattleScore BattleExchangeEvaluator::calculateExchange(
  518. const AttackPossibility & ap,
  519. uint8_t turn,
  520. PotentialTargets & targets,
  521. DamageCache & damageCache,
  522. std::shared_ptr<HypotheticBattle> hb,
  523. const battle::Units & additionalUnits) const
  524. {
  525. #if BATTLE_TRACE_LEVEL>=1
  526. logAi->trace("Battle exchange at %d", ap.attack.shooting ? ap.dest.toInt() : ap.from.toInt());
  527. #endif
  528. if(cb->battleGetMySide() == BattleSide::LEFT_SIDE
  529. && cb->battleGetGateState() == EGateState::BLOCKED
  530. && ap.attack.defender->coversPos(BattleHex::GATE_BRIDGE))
  531. {
  532. return BattleScore(EvaluationResult::INEFFECTIVE_SCORE, 0);
  533. }
  534. battle::Units ourStacks;
  535. battle::Units enemyStacks;
  536. if(hb->battleGetUnitByID(ap.attack.defender->unitId())->alive())
  537. enemyStacks.push_back(ap.attack.defender);
  538. ReachabilityData exchangeUnits = getExchangeUnits(ap, turn, targets, hb, additionalUnits);
  539. if(exchangeUnits.units.empty())
  540. {
  541. return BattleScore();
  542. }
  543. auto exchangeBattle = std::make_shared<HypotheticBattle>(env.get(), hb);
  544. BattleExchangeVariant v;
  545. for(int exchangeTurn = 0; exchangeTurn < exchangeUnits.units.size(); exchangeTurn++)
  546. {
  547. for(auto unit : exchangeUnits.units.at(exchangeTurn))
  548. {
  549. if(unit->isTurret())
  550. continue;
  551. bool isOur = exchangeBattle->battleMatchOwner(ap.attack.attacker, unit, true);
  552. auto & attackerQueue = isOur ? ourStacks : enemyStacks;
  553. auto u = exchangeBattle->getForUpdate(unit->unitId());
  554. if(u->alive() && !vstd::contains(attackerQueue, unit))
  555. {
  556. attackerQueue.push_back(unit);
  557. #if BATTLE_TRACE_LEVEL
  558. logAi->trace("Exchanging: %s", u->getDescription());
  559. #endif
  560. }
  561. }
  562. }
  563. auto melleeAttackers = ourStacks;
  564. vstd::removeDuplicates(melleeAttackers);
  565. vstd::erase_if(melleeAttackers, [&](const battle::Unit * u) -> bool
  566. {
  567. return cb->battleCanShoot(u);
  568. });
  569. bool canUseAp = true;
  570. std::set<uint32_t> blockedShooters;
  571. int totalTurnsCount = simulationTurnsCount >= turn + turnOrder.size()
  572. ? simulationTurnsCount
  573. : turn + turnOrder.size();
  574. for(int exchangeTurn = 0; exchangeTurn < simulationTurnsCount; exchangeTurn++)
  575. {
  576. bool isMovingTurm = exchangeTurn < turn;
  577. int queueTurn = exchangeTurn >= exchangeUnits.units.size()
  578. ? exchangeUnits.units.size() - 1
  579. : exchangeTurn;
  580. for(auto activeUnit : exchangeUnits.units.at(queueTurn))
  581. {
  582. bool isOur = exchangeBattle->battleMatchOwner(ap.attack.attacker, activeUnit, true);
  583. battle::Units & attackerQueue = isOur ? ourStacks : enemyStacks;
  584. battle::Units & oppositeQueue = isOur ? enemyStacks : ourStacks;
  585. auto attacker = exchangeBattle->getForUpdate(activeUnit->unitId());
  586. auto shooting = exchangeBattle->battleCanShoot(attacker.get())
  587. && !vstd::contains(blockedShooters, attacker->unitId());
  588. if(!attacker->alive())
  589. {
  590. #if BATTLE_TRACE_LEVEL>=1
  591. logAi->trace("Attacker is dead");
  592. #endif
  593. continue;
  594. }
  595. if(isMovingTurm && !shooting
  596. && !vstd::contains(exchangeUnits.enemyUnitsReachingAttacker, attacker->unitId()))
  597. {
  598. #if BATTLE_TRACE_LEVEL>=1
  599. logAi->trace("Attacker is moving");
  600. #endif
  601. continue;
  602. }
  603. auto targetUnit = ap.attack.defender;
  604. if(!isOur || !exchangeBattle->battleGetUnitByID(targetUnit->unitId())->alive())
  605. {
  606. #if BATTLE_TRACE_LEVEL>=2
  607. logAi->trace("Best target selector for %s", attacker->getDescription());
  608. #endif
  609. auto estimateAttack = [&](const battle::Unit * u) -> float
  610. {
  611. auto stackWithBonuses = exchangeBattle->getForUpdate(u->unitId());
  612. auto score = v.trackAttack(
  613. attacker,
  614. stackWithBonuses,
  615. exchangeBattle->battleCanShoot(stackWithBonuses.get()),
  616. isOur,
  617. damageCache,
  618. hb,
  619. true);
  620. #if BATTLE_TRACE_LEVEL>=2
  621. logAi->trace("Best target selector %s->%s score = %2f", attacker->getDescription(), stackWithBonuses->getDescription(), score);
  622. #endif
  623. return score;
  624. };
  625. auto unitsInOppositeQueueExceptInaccessible = oppositeQueue;
  626. vstd::erase_if(unitsInOppositeQueueExceptInaccessible, [&](const battle::Unit * u)->bool
  627. {
  628. return vstd::contains(exchangeUnits.shooters, u);
  629. });
  630. if(!isOur
  631. && exchangeTurn == 0
  632. && exchangeUnits.units.at(exchangeTurn).at(0)->unitId() != ap.attack.attacker->unitId()
  633. && !vstd::contains(exchangeUnits.enemyUnitsReachingAttacker, attacker->unitId()))
  634. {
  635. vstd::erase_if(unitsInOppositeQueueExceptInaccessible, [&](const battle::Unit * u) -> bool
  636. {
  637. return u->unitId() == ap.attack.attacker->unitId();
  638. });
  639. }
  640. if(!unitsInOppositeQueueExceptInaccessible.empty())
  641. {
  642. targetUnit = *vstd::maxElementByFun(unitsInOppositeQueueExceptInaccessible, estimateAttack);
  643. }
  644. else
  645. {
  646. auto reachable = exchangeBattle->battleGetUnitsIf([this, &exchangeBattle, &attacker](const battle::Unit * u) -> bool
  647. {
  648. if(u->unitSide() == attacker->unitSide())
  649. return false;
  650. if(!exchangeBattle->getForUpdate(u->unitId())->alive())
  651. return false;
  652. if(!u->getPosition().isValid())
  653. return false; // e.g. tower shooters
  654. const auto & reachableUnits = getOneTurnReachableUnits(0, u->getPosition());
  655. return vstd::contains_if(reachableUnits, [&attacker](const battle::Unit * other) -> bool
  656. {
  657. return attacker->unitId() == other->unitId();
  658. });
  659. });
  660. if(!reachable.empty())
  661. {
  662. targetUnit = *vstd::maxElementByFun(reachable, estimateAttack);
  663. }
  664. else
  665. {
  666. #if BATTLE_TRACE_LEVEL>=1
  667. logAi->trace("Battle queue is empty and no reachable enemy.");
  668. #endif
  669. continue;
  670. }
  671. }
  672. }
  673. auto defender = exchangeBattle->getForUpdate(targetUnit->unitId());
  674. const int totalAttacks = attacker->getTotalAttacks(shooting);
  675. if(canUseAp && activeUnit->unitId() == ap.attack.attacker->unitId()
  676. && targetUnit->unitId() == ap.attack.defender->unitId())
  677. {
  678. v.trackAttack(ap, exchangeBattle, damageCache);
  679. }
  680. else
  681. {
  682. for(int i = 0; i < totalAttacks; i++)
  683. {
  684. v.trackAttack(attacker, defender, shooting, isOur, damageCache, exchangeBattle);
  685. if(!attacker->alive() || !defender->alive())
  686. break;
  687. }
  688. }
  689. if(!shooting)
  690. blockedShooters.insert(defender->unitId());
  691. canUseAp = false;
  692. vstd::erase_if(attackerQueue, [&](const battle::Unit * u) -> bool
  693. {
  694. return !exchangeBattle->battleGetUnitByID(u->unitId())->alive();
  695. });
  696. vstd::erase_if(oppositeQueue, [&](const battle::Unit * u) -> bool
  697. {
  698. return !exchangeBattle->battleGetUnitByID(u->unitId())->alive();
  699. });
  700. }
  701. exchangeBattle->nextRound();
  702. }
  703. auto score = v.getScore();
  704. if(simulationTurnsCount < totalTurnsCount)
  705. {
  706. float scalingRatio = simulationTurnsCount / static_cast<float>(totalTurnsCount);
  707. score.enemyDamageReduce *= scalingRatio;
  708. score.ourDamageReduce *= scalingRatio;
  709. }
  710. if(turn > 0)
  711. {
  712. auto turnMultiplier = 1 - std::min(0.2, 0.05 * turn);
  713. score.enemyDamageReduce *= turnMultiplier;
  714. }
  715. #if BATTLE_TRACE_LEVEL>=1
  716. logAi->trace("Exchange score: enemy: %2f, our -%2f", score.enemyDamageReduce, score.ourDamageReduce);
  717. #endif
  718. return score;
  719. }
  720. bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap)
  721. {
  722. for(auto pos : ap.attack.attacker->getSurroundingHexes())
  723. {
  724. for(auto u : getOneTurnReachableUnits(0, pos))
  725. {
  726. if(u->unitSide() != ap.attack.attacker->unitSide())
  727. {
  728. return true;
  729. }
  730. }
  731. }
  732. return false;
  733. }
  734. void ReachabilityMapCache::update(const std::vector<battle::Units> & turnOrder, std::shared_ptr<HypotheticBattle> hb)
  735. {
  736. for(auto turn : turnOrder)
  737. {
  738. for(auto u : turn)
  739. {
  740. if(!vstd::contains(unitReachabilityMap, u->unitId()))
  741. {
  742. unitReachabilityMap[u->unitId()] = hb->getReachability(u);
  743. }
  744. }
  745. }
  746. hexReachabilityPerTurn.clear();
  747. }
  748. void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb)
  749. {
  750. const int TURN_DEPTH = 2;
  751. turnOrder.clear();
  752. hb->battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH);
  753. reachabilityMap.update(turnOrder, hb);
  754. }
  755. const battle::Units & ReachabilityMapCache::getOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, const BattleHex & hex)
  756. {
  757. auto & turnData = hexReachabilityPerTurn[turn];
  758. if (!turnData.isValid[hex.toInt()])
  759. {
  760. turnData.hexes[hex.toInt()] = computeOneTurnReachableUnits(cb, env, turnOrder, turn, hex);
  761. turnData.isValid.set(hex.toInt());
  762. }
  763. return turnData.hexes[hex.toInt()];
  764. }
  765. battle::Units ReachabilityMapCache::computeOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, const BattleHex & hex)
  766. {
  767. battle::Units result;
  768. for(int i = 0; i < turnOrder.size(); i++, turn++)
  769. {
  770. auto & turnQueue = turnOrder[i];
  771. HypotheticBattle turnBattle(env.get(), cb);
  772. for(const battle::Unit * unit : turnQueue)
  773. {
  774. if(unit->isTurret())
  775. continue;
  776. if(turnBattle.battleCanShoot(unit))
  777. {
  778. result.push_back(unit);
  779. continue;
  780. }
  781. auto unitSpeed = unit->getMovementRange(turn);
  782. auto radius = unitSpeed * (turn + 1);
  783. auto reachabilityIter = unitReachabilityMap.find(unit->unitId());
  784. assert(reachabilityIter != unitReachabilityMap.end()); // missing updateReachabilityMap call?
  785. ReachabilityInfo unitReachability = reachabilityIter != unitReachabilityMap.end() ? reachabilityIter->second : turnBattle.getReachability(unit);
  786. bool reachable = unitReachability.distances.at(hex.toInt()) <= radius;
  787. if(!reachable && unitReachability.accessibility[hex.toInt()] == EAccessibility::ALIVE_STACK)
  788. {
  789. const battle::Unit * hexStack = cb->battleGetUnitByPos(hex);
  790. if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
  791. {
  792. for(const BattleHex & neighbour : hex.getNeighbouringTiles())
  793. {
  794. reachable = unitReachability.distances.at(neighbour.toInt()) <= radius;
  795. if(reachable) break;
  796. }
  797. }
  798. }
  799. if(reachable)
  800. {
  801. result.push_back(unit);
  802. }
  803. }
  804. }
  805. return result;
  806. }
  807. const battle::Units & BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, const BattleHex & hex) const
  808. {
  809. return reachabilityMap.getOneTurnReachableUnits(cb, env, turnOrder, turn, hex);
  810. }
  811. // avoid blocking path for stronger stack by weaker stack
  812. bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(const HypotheticBattle & hb, const battle::Unit * activeUnit, const BattleHex & position)
  813. {
  814. const int BLOCKING_THRESHOLD = 70;
  815. const int BLOCKING_OWN_ATTACK_PENALTY = 100;
  816. const int BLOCKING_OWN_MOVE_PENALTY = 1;
  817. float blockingScore = 0;
  818. auto activeUnitDamage = activeUnit->getMinDamage(hb.battleCanShoot(activeUnit)) * activeUnit->getCount();
  819. for(int turn = 0; turn < turnOrder.size(); turn++)
  820. {
  821. auto & turnQueue = turnOrder[turn];
  822. HypotheticBattle turnBattle(env.get(), cb);
  823. auto unitToUpdate = turnBattle.getForUpdate(activeUnit->unitId());
  824. unitToUpdate->setPosition(position);
  825. for(const battle::Unit * unit : turnQueue)
  826. {
  827. if(unit->unitId() == unitToUpdate->unitId() || cb->battleMatchOwner(unit, activeUnit, false))
  828. continue;
  829. auto blockedUnitDamage = unit->getMinDamage(hb.battleCanShoot(unit)) * unit->getCount();
  830. float ratio = blockedUnitDamage / (float)(blockedUnitDamage + activeUnitDamage + 0.01);
  831. auto unitReachability = turnBattle.getReachability(unit);
  832. auto unitSpeed = unit->getMovementRange(turn); // Cached value, to avoid performance hit
  833. for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); ++hex)
  834. {
  835. bool enemyUnit = false;
  836. bool reachable = unitReachability.distances.at(hex.toInt()) <= unitSpeed;
  837. if(!reachable && unitReachability.accessibility[hex.toInt()] == EAccessibility::ALIVE_STACK)
  838. {
  839. const battle::Unit * hexStack = turnBattle.battleGetUnitByPos(hex);
  840. if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
  841. {
  842. enemyUnit = true;
  843. for(const BattleHex & neighbour : hex.getNeighbouringTiles())
  844. {
  845. reachable = unitReachability.distances.at(neighbour.toInt()) <= unitSpeed;
  846. if(reachable) break;
  847. }
  848. }
  849. }
  850. if(!reachable)
  851. {
  852. auto reachableUnits = getOneTurnReachableUnits(0, hex);
  853. if (std::count(reachableUnits.begin(), reachableUnits.end(), unit) > 1)
  854. blockingScore += ratio * (enemyUnit ? BLOCKING_OWN_ATTACK_PENALTY : BLOCKING_OWN_MOVE_PENALTY);
  855. }
  856. }
  857. }
  858. }
  859. #if BATTLE_TRACE_LEVEL>=1
  860. logAi->trace("Position %d, blocking score %f", position.toInt(), blockingScore);
  861. #endif
  862. return blockingScore > BLOCKING_THRESHOLD;
  863. }