SpellTargetsEvaluator.cpp 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. /*
  2. * SpellTargetsEvaluator.cpp, part of VCMI engine
  3. *
  4. * Authors: listed in file AUTHORS in main folder
  5. *
  6. * License: GNU General Public License v2.0 or later
  7. * Full text of license available in license.txt file, in main folder
  8. *
  9. */
  10. #include "StdInc.h"
  11. #include "../../lib/CStack.h"
  12. #include "../../lib/battle/CBattleInfoCallback.h"
  13. #include "../../lib/spells/Problem.h"
  14. #include "../../lib/CRandomGenerator.h"
  15. #include "SpellTargetsEvaluator.h"
  16. #include <vcmi/spells/Spell.h>
  17. using namespace spells;
  18. std::vector<Target> SpellTargetEvaluator::getViableTargets(const Mechanics * spellMechanics)
  19. {
  20. std::vector<Target> result;
  21. std::vector<AimType> targetTypes = spellMechanics->getTargetTypes();
  22. if(targetTypes.size() != 1) //TODO: support for multi-destination spells
  23. return result;
  24. auto targetType = targetTypes.front();
  25. switch(targetType)
  26. {
  27. case AimType::CREATURE://TODO: support for multi-destination spells
  28. return allTargetableCreatures(spellMechanics);
  29. case AimType::LOCATION:
  30. {
  31. if(spellMechanics->isNeutralSpell())
  32. return defaultLocationSpellHeuristics(
  33. spellMechanics
  34. ); // theoretically anything can be a useful destination, so we balance performance and validity
  35. else
  36. return theBestLocationCasts(spellMechanics);
  37. }
  38. case AimType::NO_TARGET:
  39. return std::vector<Target>(1); //default-constructed target means cast without destination
  40. default:
  41. return result;
  42. }
  43. }
  44. std::vector<Target> SpellTargetEvaluator::defaultLocationSpellHeuristics(const spells::Mechanics * spellMechanics)
  45. {
  46. std::vector<Target> result = allTargetableCreatures(spellMechanics);
  47. auto units = spellMechanics->battle()->battleGetAllUnits(false);
  48. for(const auto * unit : units) //insert a random surrounding hex
  49. {
  50. auto surroundingHexes = unit->getSurroundingHexes();
  51. if(!surroundingHexes.empty())
  52. {
  53. auto randomSurroundingHex = *RandomGeneratorUtil::nextItem(surroundingHexes, CRandomGenerator::getDefault()); // don't think this method bias matter with such small numbers
  54. addIfCanBeCast(spellMechanics, randomSurroundingHex, result);
  55. }
  56. }
  57. return result;
  58. }
  59. std::vector<Target> SpellTargetEvaluator::allTargetableCreatures(const spells::Mechanics * spellMechanics)
  60. {
  61. std::vector<Target> result;
  62. auto units = spellMechanics->battle()->battleGetAllUnits(false);
  63. for(const auto * unit : units)
  64. addIfCanBeCast(spellMechanics, unit->getPosition(), result);
  65. return result;
  66. }
  67. std::vector<Target> SpellTargetEvaluator::theBestLocationCasts(const spells::Mechanics * spellMechanics)
  68. {
  69. std::vector<Target> result;
  70. std::map<BattleHex, std::set<const CStack *>> allCasts;
  71. std::map<BattleHex, std::set<const CStack *>> bestCasts;
  72. for(int i = 0; i < GameConstants::BFIELD_SIZE; i++)
  73. {
  74. BattleHex dest(i);
  75. if(canBeCastAt(spellMechanics, dest))
  76. {
  77. Target target;
  78. target.emplace_back(dest);
  79. auto temp = spellMechanics->getAffectedStacks(target);
  80. std::set<const CStack *> affectedStacks(temp.begin(), temp.end());
  81. allCasts[dest] = affectedStacks;
  82. }
  83. }
  84. for(const auto & cast : allCasts)
  85. {
  86. std::set<BattleHex> worseCasts;
  87. if(isCastHarmful(spellMechanics, cast.second))
  88. continue;
  89. bool isBestCast = true;
  90. for(const auto & bestCast : bestCasts)
  91. {
  92. Compare compare = compareAffectedStacks(spellMechanics, cast.second, bestCast.second);
  93. if(compare == Compare::WORSE || compare == Compare::EQUAL)
  94. {
  95. isBestCast = false;
  96. break;
  97. }
  98. if(compare == Compare::BETTER)
  99. {
  100. worseCasts.insert(bestCast.first);
  101. }
  102. }
  103. if(isBestCast)
  104. {
  105. bestCasts.insert(cast);
  106. for(BattleHex worseCast : worseCasts)
  107. bestCasts.erase(worseCast);
  108. }
  109. }
  110. for(const auto & cast : bestCasts)
  111. {
  112. Destination des(cast.first);
  113. result.push_back({des});
  114. }
  115. return result;
  116. }
  117. bool SpellTargetEvaluator::isCastHarmful(const spells::Mechanics * spellMechanics, const std::set<const CStack *> & affectedStacks)
  118. {
  119. bool isAffectedAlly = false;
  120. bool isAffectedEnemy = false;
  121. for(const CStack * affectedUnit : affectedStacks)
  122. {
  123. if(affectedUnit->unitSide() == spellMechanics->casterSide)
  124. isAffectedAlly = true;
  125. else
  126. isAffectedEnemy = true;
  127. }
  128. return (spellMechanics->isPositiveSpell() && !isAffectedAlly) || (spellMechanics->isNegativeSpell() && !isAffectedEnemy);
  129. }
  130. SpellTargetEvaluator::Compare SpellTargetEvaluator::compareAffectedStacks(
  131. const spells::Mechanics * spellMechanics, const std::set<const CStack *> & newCast, const std::set<const CStack *> & oldCast)
  132. {
  133. if(newCast.size() == oldCast.size())
  134. return newCast == oldCast ? Compare::EQUAL : Compare::DIFFERENT;
  135. auto getAlliedUnits = [&spellMechanics](const std::set<const CStack *> & allUnits) -> std::set<const CStack *>
  136. {
  137. std::set<const CStack *> alliedUnits;
  138. for(auto stack : allUnits)
  139. {
  140. if(stack->unitSide() == spellMechanics->casterSide)
  141. alliedUnits.insert(stack);
  142. }
  143. return alliedUnits;
  144. };
  145. auto getEnemyUnits = [&spellMechanics](const std::set<const CStack *> & allUnits) -> std::set<const CStack *>
  146. {
  147. std::set<const CStack *> enemyUnits;
  148. for(auto stack : allUnits)
  149. {
  150. if(stack->unitSide() != spellMechanics->casterSide)
  151. enemyUnits.insert(stack);
  152. }
  153. return enemyUnits;
  154. };
  155. Compare alliedSubsetComparison = compareAffectedStacksSubset(spellMechanics, getAlliedUnits(newCast), getAlliedUnits(oldCast));
  156. Compare enemySubsetComparison = compareAffectedStacksSubset(spellMechanics, getEnemyUnits(newCast), getEnemyUnits(oldCast));
  157. if(spellMechanics->isPositiveSpell())
  158. enemySubsetComparison = reverse(enemySubsetComparison);
  159. else if(spellMechanics->isNegativeSpell())
  160. alliedSubsetComparison = reverse(alliedSubsetComparison);
  161. std::set<Compare> comparisonResults = {alliedSubsetComparison, enemySubsetComparison};
  162. std::set<std::set<Compare>> possibleBetterResults = {
  163. {Compare::BETTER, Compare::BETTER},
  164. {Compare::BETTER, Compare::EQUAL }
  165. };
  166. std::set<std::set<Compare>> possibleWorstResults = {
  167. {Compare::WORSE, Compare::WORSE},
  168. {Compare::WORSE, Compare::EQUAL}
  169. };
  170. if(possibleBetterResults.find(comparisonResults) != possibleBetterResults.end())
  171. return Compare::BETTER;
  172. if(possibleWorstResults.find(comparisonResults) != possibleWorstResults.end())
  173. return Compare::WORSE;
  174. return Compare::DIFFERENT;
  175. }
  176. SpellTargetEvaluator::Compare SpellTargetEvaluator::compareAffectedStacksSubset(
  177. const spells::Mechanics * spellMechanics, const std::set<const CStack *> & newSubset, const std::set<const CStack *> & oldSubset)
  178. {
  179. if(newSubset.size() == oldSubset.size())
  180. return newSubset == oldSubset ? Compare::EQUAL : Compare::DIFFERENT;
  181. if(oldSubset.size() > newSubset.size())
  182. return reverse(compareAffectedStacksSubset(spellMechanics, oldSubset, newSubset));
  183. const std::set<const CStack *> & biggerSet = newSubset;
  184. const std::set<const CStack *> & smallerSet = oldSubset;
  185. if(std::includes(biggerSet.begin(), biggerSet.end(), smallerSet.begin(), smallerSet.end()))
  186. return Compare::BETTER;
  187. else
  188. return Compare::DIFFERENT;
  189. }
  190. SpellTargetEvaluator::Compare SpellTargetEvaluator::reverse(SpellTargetEvaluator::Compare compare)
  191. {
  192. switch(compare)
  193. {
  194. case Compare::BETTER:
  195. return Compare::WORSE;
  196. case Compare::WORSE:
  197. return Compare::BETTER;
  198. default:
  199. return compare;
  200. }
  201. }
  202. bool SpellTargetEvaluator::canBeCastAt(const spells::Mechanics * spellMechanics, BattleHex hex)
  203. {
  204. detail::ProblemImpl ignored;
  205. Destination des(hex);
  206. return spellMechanics->canBeCastAt({des}, ignored);
  207. }
  208. void SpellTargetEvaluator::addIfCanBeCast(const spells::Mechanics * spellMechanics, BattleHex hex, std::vector<Target> & targets)
  209. {
  210. detail::ProblemImpl ignored;
  211. Destination des(hex);
  212. if(spellMechanics->canBeCastAt({des}, ignored))
  213. targets.push_back({des});
  214. }