state.cpp 17 KB


  1. /*
  2. * state.cpp, part of VCMI engine
  3. *
  4. * Authors: listed in file AUTHORS in main folder
  5. *
  6. * License: GNU General Public License v2.0 or later
  7. * Full text of license available in license.txt file, in main folder
  8. *
  9. */
  10. #include "StdInc.h"
  11. #include "battle/CPlayerBattleCallback.h"
  12. #include "networkPacks/PacksForClientBattle.h"
  13. #include "BAI/v13/encoder.h"
  14. #include "BAI/v13/hexaction.h"
  15. #include "BAI/v13/state.h"
  16. #include "BAI/v13/supplementary_data.h"
  17. #include "common.h"
  18. #include "schema/v13/constants.h"
  19. #include <algorithm>
  20. #include <memory>
  21. namespace MMAI::BAI::V13
  22. {
  23. namespace S13 = Schema::V13;
  24. using GA = Schema::V13::GlobalAttribute;
  25. using PA = Schema::V13::PlayerAttribute;
  26. using HA = Schema::V13::HexAttribute;
  27. using SA = Schema::V13::StackAttribute;
  28. //
  29. // Prevent human errors caused by the Stack / Hex attr overlap
  30. //
  31. static_assert(EI(HA::STACK_SIDE) == EI(SA::SIDE) + S13::STACK_ATTR_OFFSET);
  32. static_assert(EI(HA::STACK_SLOT) == EI(SA::SLOT) + S13::STACK_ATTR_OFFSET);
  33. static_assert(EI(HA::STACK_QUANTITY) == EI(SA::QUANTITY) + S13::STACK_ATTR_OFFSET);
  34. static_assert(EI(HA::STACK_ATTACK) == EI(SA::ATTACK) + S13::STACK_ATTR_OFFSET);
  35. static_assert(EI(HA::STACK_DEFENSE) == EI(SA::DEFENSE) + S13::STACK_ATTR_OFFSET);
  36. static_assert(EI(HA::STACK_SHOTS) == EI(SA::SHOTS) + S13::STACK_ATTR_OFFSET);
  37. static_assert(EI(HA::STACK_DMG_MIN) == EI(SA::DMG_MIN) + S13::STACK_ATTR_OFFSET);
  38. static_assert(EI(HA::STACK_DMG_MAX) == EI(SA::DMG_MAX) + S13::STACK_ATTR_OFFSET);
  39. static_assert(EI(HA::STACK_HP) == EI(SA::HP) + S13::STACK_ATTR_OFFSET);
  40. static_assert(EI(HA::STACK_HP_LEFT) == EI(SA::HP_LEFT) + S13::STACK_ATTR_OFFSET);
  41. static_assert(EI(HA::STACK_SPEED) == EI(SA::SPEED) + S13::STACK_ATTR_OFFSET);
  42. static_assert(EI(HA::STACK_QUEUE) == EI(SA::QUEUE) + S13::STACK_ATTR_OFFSET);
  43. static_assert(EI(HA::STACK_VALUE_ONE) == EI(SA::VALUE_ONE) + S13::STACK_ATTR_OFFSET);
  44. static_assert(EI(HA::STACK_FLAGS1) == EI(SA::FLAGS1) + S13::STACK_ATTR_OFFSET);
  45. static_assert(EI(HA::STACK_FLAGS2) == EI(SA::FLAGS2) + S13::STACK_ATTR_OFFSET);
  46. static_assert(EI(HA::STACK_VALUE_REL) == EI(SA::VALUE_REL) + S13::STACK_ATTR_OFFSET);
  47. static_assert(EI(HA::STACK_VALUE_REL0) == EI(SA::VALUE_REL0) + S13::STACK_ATTR_OFFSET);
  48. static_assert(EI(HA::STACK_VALUE_KILLED_REL) == EI(SA::VALUE_KILLED_REL) + S13::STACK_ATTR_OFFSET);
  49. static_assert(EI(HA::STACK_VALUE_KILLED_ACC_REL0) == EI(SA::VALUE_KILLED_ACC_REL0) + S13::STACK_ATTR_OFFSET);
  50. static_assert(EI(HA::STACK_VALUE_LOST_REL) == EI(SA::VALUE_LOST_REL) + S13::STACK_ATTR_OFFSET);
  51. static_assert(EI(HA::STACK_VALUE_LOST_ACC_REL0) == EI(SA::VALUE_LOST_ACC_REL0) + S13::STACK_ATTR_OFFSET);
  52. static_assert(EI(HA::STACK_DMG_DEALT_REL) == EI(SA::DMG_DEALT_REL) + S13::STACK_ATTR_OFFSET);
  53. static_assert(EI(HA::STACK_DMG_DEALT_ACC_REL0) == EI(SA::DMG_DEALT_ACC_REL0) + S13::STACK_ATTR_OFFSET);
  54. static_assert(EI(HA::STACK_DMG_RECEIVED_REL) == EI(SA::DMG_RECEIVED_REL) + S13::STACK_ATTR_OFFSET);
  55. static_assert(EI(HA::STACK_DMG_RECEIVED_ACC_REL0) == EI(SA::DMG_RECEIVED_ACC_REL0) + S13::STACK_ATTR_OFFSET);
  56. static_assert(EI(StackAttribute::_count) == 25, "whistleblower in case attributes change");
  57. // static
  58. std::vector<float> State::InitNullStack()
  59. {
  60. auto res = std::vector<float>{};
  61. for(int i = 0; i < EI(StackAttribute::_count); ++i)
  62. Encoder::Encode(static_cast<HA>(S13::STACK_ATTR_OFFSET + i), S13::NULL_VALUE_UNENCODED, res);
  63. return res;
  64. };
  65. namespace
  66. {
  67. std::tuple<int, int, int, int> CalcGlobalStats(const CPlayerBattleCallback * battle)
  68. {
  69. int lv = 0;
  70. int lh = 0;
  71. int rv = 0;
  72. int rh = 0;
  73. for(auto & stack : battle->battleGetStacks())
  74. {
  75. auto v = stack->getCount() * Stack::GetValue(stack->unitType());
  76. auto h = stack->getAvailableHealth();
  77. if(stack->unitSide() == BattleSide::ATTACKER)
  78. {
  79. lv += v;
  80. lh += h;
  81. }
  82. else
  83. {
  84. rv += v;
  85. rh += h;
  86. }
  87. }
  88. return {lv, lh, rv, rh};
  89. }
  // Per-side totals gathered from a batch of attack logs.
  // NOTE: member order matters - the struct is unpacked positionally
  // (structured binding) by its consumer.
  struct AttackLogAggregateData
  {
      // left side
      int ldd{0}; // damage dealt
      int ldr{0}; // damage received
      int lvk{0}; // value killed
      int lvl{0}; // value lost

      // right side
      int rdd{0}; // damage dealt
      int rdr{0}; // damage received
      int rvk{0}; // value killed
      int rvl{0}; // value lost
  };
  101. AttackLogAggregateData ProcessAttackLogs(const std::vector<std::shared_ptr<AttackLog>> & attackLogs, std::map<const CStack *, Stack::Stats> sstats)
  102. {
  103. auto res = AttackLogAggregateData{};
  104. for(auto & [cstack, ss] : sstats)
  105. {
  106. ss.dmgDealtNow = 0;
  107. ss.dmgReceivedNow = 0;
  108. ss.valueKilledNow = 0;
  109. ss.valueLostNow = 0;
  110. }
  111. for(const auto & al : attackLogs)
  112. {
  113. const auto & ald = al->data;
  114. if(ald.cattacker)
  115. {
  116. sstats[ald.cattacker].dmgDealtNow += ald.dmg;
  117. sstats[ald.cattacker].dmgDealtTotal += ald.dmg;
  118. sstats[ald.cattacker].valueKilledNow += ald.value;
  119. sstats[ald.cattacker].valueKilledTotal += ald.value;
  120. if(ald.cattacker->unitSide() == BattleSide::LEFT_SIDE)
  121. {
  122. res.ldd += ald.dmg;
  123. res.lvk += ald.value;
  124. }
  125. else
  126. {
  127. res.rdd += ald.dmg;
  128. res.rvk += ald.value;
  129. }
  130. }
  131. ASSERT(ald.cdefender, "AttackLog cdefender is nullptr!");
  132. sstats[ald.cdefender].dmgReceivedNow += ald.dmg;
  133. sstats[ald.cdefender].dmgReceivedTotal += ald.dmg;
  134. sstats[ald.cdefender].valueLostNow += ald.value;
  135. sstats[ald.cdefender].valueLostTotal += ald.value;
  136. if(ald.cdefender->unitSide() == BattleSide::LEFT_SIDE)
  137. {
  138. res.ldr += ald.dmg;
  139. res.lvl += ald.value;
  140. }
  141. else
  142. {
  143. res.rdr += ald.dmg;
  144. res.rvl += ald.value;
  145. }
  146. }
  147. return res;
  148. }
  149. }
  150. State::State(int version_, const std::string & colorname, const CPlayerBattleCallback * battle, bool enableTransitions)
  151. : version_(version_)
  152. , battle(battle)
  153. , enableTransitions(enableTransitions)
  154. , colorname(colorname)
  155. , side(battle->battleGetMySide())
  156. , nullstack(InitNullStack())
  157. {
  158. auto [lv, lh, rv, rh] = CalcGlobalStats(battle);
  159. gstats = std::make_unique<GlobalStats>(battle->battleGetMySide(), lv + rv, lh + rh);
  160. lpstats = std::make_unique<PlayerStats>(BattleSide::LEFT_SIDE, lv, lh);
  161. rpstats = std::make_unique<PlayerStats>(BattleSide::RIGHT_SIDE, rv, rh);
  162. battlefield = Battlefield::Create(battle, nullptr, gstats.get(), gstats.get(), sstats, false);
  163. bfstate.reserve(S13::BATTLEFIELD_STATE_SIZE);
  164. actmask.reserve(S13::N_ACTIONS);
  165. }
  166. void State::onActiveStack(const CStack * astack, CombatResult result, bool recording, bool fastpath)
  167. {
  168. logAi->debug("onActiveStack: result=%d, recording=%d, fastpath=%d", EI(result), recording, fastpath);
  169. const auto & [lv, lh, rv, rh] = CalcGlobalStats(battle);
  170. const auto & [ldd, ldr, lvk, lvl, rdd, rdr, rvk, rvl] = ProcessAttackLogs(attackLogs, sstats);
  171. auto ogstats = *gstats; // a copy of the "old" gstats
  172. (result == CombatResult::NONE) ? gstats->update(astack->unitSide(), result, lv + rv, lh + rh, !astack->waitedThisTurn)
  173. : gstats->update(BattleSide::NONE, result, lv + rv, lh + rh, false);
  174. lpstats->update(&ogstats, lv, lh, ldd, ldr, lvk, lvl);
  175. rpstats->update(&ogstats, rv, rh, rdd, rdr, rvk, rvl);
  176. if(fastpath)
  177. {
  178. // means we are done with onActiveStack, and we can safely clear transitions now
  179. transitions.clear();
  180. persistentAttackLogs.clear();
  181. }
  182. else
  183. {
  184. if(enableTransitions)
  185. persistentAttackLogs.insert(persistentAttackLogs.end(), attackLogs.begin(), attackLogs.end());
  186. battlefield = Battlefield::Create(battle, astack, &ogstats, gstats.get(), sstats, isMorale);
  187. bfstate.clear();
  188. actmask.clear();
  189. for(int i = 0; i < EI(GlobalAction::_count); i++)
  190. {
  191. switch(static_cast<GlobalAction>(i))
  192. {
  193. case GlobalAction::RETREAT:
  194. actmask.push_back(battle->battleCanFlee());
  195. break;
  196. case GlobalAction::WAIT:
  197. actmask.push_back(battlefield->astack && !battlefield->astack->cstack->waitedThisTurn);
  198. break;
  199. default:
  200. THROW_FORMAT("Unexpected GlobalAction: %d", i);
  201. }
  202. }
  203. encodeGlobal(result);
  204. encodePlayer(lpstats.get());
  205. encodePlayer(rpstats.get());
  206. for(auto & hexrow : *battlefield->hexes)
  207. for(auto & hex : hexrow)
  208. encodeHex(hex.get());
  209. // Links are not part of the state
  210. // They are handled separately by the connector
  211. // for (auto &link : battlefield->links)
  212. // encodeLink(link);
  213. verify();
  214. }
  215. isMorale = false;
  216. supdata = std::make_unique<SupplementaryData>(
  217. colorname,
  218. static_cast<Side>(side),
  219. gstats.get(),
  220. lpstats.get(),
  221. rpstats.get(),
  222. battlefield.get(),
  223. enableTransitions ? persistentAttackLogs : attackLogs, // store the logs since OUR last turn
  224. transitions, // store the states since last turn
  225. result
  226. );
  227. if(recording)
  228. {
  229. ASSERT(startedAction >= 0, "unexpected startedAction: " + std::to_string(startedAction));
  230. // NOTE: this creates a copy of bfstate (which is what we want)
  231. transitions.emplace_back(startedAction, std::make_shared<Schema::ActionMask>(actmask), std::make_shared<Schema::BattlefieldState>(bfstate));
  232. }
  233. else
  234. {
  235. actingStack = astack; // for fastpath, see onActionStarted
  236. startedAction = -1;
  237. // XXX: must NOT clear transitions here (can do it only after BAI's activeStack completes)
  238. // transitions.clear();
  239. }
  240. attackLogs.clear(); // accumulate new logs until next turn
  241. }
  242. void State::_onActionStarted(const BattleAction & ba)
  243. {
  244. if(!ba.isUnitAction())
  245. {
  246. logAi->warn("Got non-unit action of type: %d", EI(ba.actionType));
  247. return;
  248. }
  249. auto stacks = battle->battleGetStacks();
  250. // Case A: << ENEMY TURN >>
  251. // 1. StupidAI makes action; vcmi calls ->
  252. // 2. State::onActionStart() calls -> // actingStack is nullptr
  253. // 3. onActiveStack(recording=true) builds bf and returns to ->
  254. // 4. State::onActionStart() clears actingStack
  255. //
  256. // Case B: << OUR TURN >>
  257. // 1. BAI::activeStack() calls ->
  258. // 2. State::onActiveStack(recording=false) builds bf, sets actingStack and returns to ->
  259. // 3. BAI::activeStack() makes action; vcmi calls ->
  260. // 4. State::onActionStart() sets fastpath=true and calls -> // actingStack already present
  261. // 5. onActiveStack(recording=true) **skips building bf** and returns to ->
  262. // 6. State::onActionStart() clears actingStack
  263. //
  264. // no need to create battlefield in 5, as it's the same as in 2.
  265. bool fastpath = false;
  266. bool found = false;
  267. for(const auto * cstack : battle->battleGetAllStacks(true))
  268. {
  269. if(cstack->unitId() == ba.stackNumber)
  270. {
  271. if(actingStack)
  272. {
  273. // XXX: actingStack is already set here only if it was set in onActiveStack() i.e. on our turn
  274. // We could check only the unit's side, but since there are
  275. // auto-acting units, comparing the exact unit seems safer.
  276. fastpath = true;
  277. if(cstack != actingStack)
  278. {
  279. THROW_FORMAT(
  280. "actingStack was already set to %s, but does not match the real acting stack %s",
  281. actingStack->getDescription() % cstack->getDescription()
  282. );
  283. }
  284. }
  285. actingStack = cstack;
  286. found = true;
  287. break;
  288. }
  289. }
  290. ASSERT(found, "could not find cstack with unitId: " + std::to_string(ba.stackNumber));
  291. if(actingStack->creatureId() == CreatureID::FIRST_AID_TENT || actingStack->creatureId() == CreatureID::CATAPULT
  292. || actingStack->creatureId() == CreatureID::ARROW_TOWERS)
  293. {
  294. // These are auto-acting for BAI
  295. // Cannot build state in this case
  296. return;
  297. }
  298. switch(ba.actionType)
  299. {
  300. case EActionType::WAIT:
  301. startedAction = S13::ACTION_WAIT;
  302. break;
  303. case EActionType::SHOOT:
  304. {
  305. auto bh = ba.target.at(0).hexValue;
  306. auto id = Hex::CalcId(bh);
  307. startedAction = S13::N_NONHEX_ACTIONS + id * EI(HexAction::_count) + EI(HexAction::SHOOT);
  308. }
  309. break;
  310. case EActionType::DEFEND:
  311. {
  312. auto bh = actingStack->getPosition();
  313. auto id = Hex::CalcId(bh);
  314. startedAction = S13::N_NONHEX_ACTIONS + id * EI(HexAction::_count) + EI(HexAction::MOVE);
  315. }
  316. break;
  317. case EActionType::WALK:
  318. {
  319. auto bh = ba.target.at(0).hexValue;
  320. auto id = Hex::CalcId(bh);
  321. startedAction = S13::N_NONHEX_ACTIONS + id * EI(HexAction::_count) + EI(HexAction::MOVE);
  322. }
  323. break;
  324. case EActionType::WALK_AND_ATTACK:
  325. {
  326. auto bhMove = ba.target.at(0).hexValue;
  327. auto bhTarget = ba.target.at(1).hexValue;
  328. auto idMove = Hex::CalcId(bhMove);
  329. // Can't use `battlefield` (old state)
  330. auto it = std::ranges::find_if(
  331. stacks,
  332. [&bhTarget](const CStack * cstack)
  333. {
  334. return cstack->coversPos(bhTarget);
  335. }
  336. );
  337. if(it == stacks.end())
  338. {
  339. THROW_FORMAT("Could not find stack for target bhex: %d", bhTarget.toInt());
  340. }
  341. const auto * targetStack = *it;
  342. if(!CStack::isMeleeAttackPossible(actingStack, targetStack, bhMove))
  343. {
  344. THROW_FORMAT("Melee attack not possible from bh=%d to bh=%d (to %s)", bhMove.toInt() % bhTarget.toInt() % targetStack->getDescription());
  345. }
  346. const auto & nbhexes = Hex::NearbyBattleHexes(bhMove);
  347. for(int i = 0; i < nbhexes.size(); ++i)
  348. {
  349. const auto & n_bhex = nbhexes.at(i);
  350. if(n_bhex == bhTarget)
  351. {
  352. startedAction = S13::N_NONHEX_ACTIONS + idMove * EI(HexAction::_count) + i;
  353. break;
  354. }
  355. }
  356. ASSERT(
  357. startedAction >= 0, "failed to determine startedAction"
  358. );
  359. }
  360. break;
  361. case EActionType::MONSTER_SPELL:
  362. logAi->warn("Got MONSTER_SPELL action (use cursed ground to prevent this)");
  363. return;
  364. break;
  365. default:
  366. // Don't record a state diff for the other actions
  367. // (most are irrelevant or should never occur during training,
  368. // except for MONSTER_SPELL, which can be fixed via cursed ground)
  369. logAi->debug("Not recording actionType=%d", EI(ba.actionType));
  370. return;
  371. }
  372. logAi->debug("Recording actionType=%d", EI(ba.actionType));
  373. onActiveStack(actingStack, CombatResult::NONE, true, fastpath);
  374. }
  375. void State::encodeGlobal(CombatResult result)
  376. {
  377. for(int i = 0; i < EI(GA::_count); ++i)
  378. {
  379. Encoder::Encode(static_cast<GA>(i), gstats->attrs.at(i), bfstate);
  380. }
  381. }
  382. void State::encodePlayer(const PlayerStats * pstats)
  383. {
  384. for(int i = 0; i < EI(PA::_count); ++i)
  385. {
  386. Encoder::Encode(static_cast<PA>(i), pstats->attrs.at(i), bfstate);
  387. }
  388. }
  389. void State::encodeHex(const Hex * hex)
  390. {
  391. // Battlefield state
  392. for(int i = 0; i < EI(HA::_count); ++i)
  393. Encoder::Encode(static_cast<HA>(i), hex->attrs.at(i), bfstate);
  394. // Action mask
  395. for(int m = 0; m < hex->actmask.size(); ++m)
  396. actmask.push_back(hex->actmask.test(m));
  397. }
  398. void State::verify() const
  399. {
  400. ASSERT(bfstate.size() == S13::BATTLEFIELD_STATE_SIZE, "unexpected bfstate.size(): " + std::to_string(bfstate.size()));
  401. ASSERT(actmask.size() == N_ACTIONS, "unexpected actmask.size(): " + std::to_string(actmask.size()));
  402. }
  403. void State::onBattleStacksAttacked(const std::vector<BattleStackAttacked> & bsa)
  404. {
  405. auto stacks = battlefield->stacks;
  406. for(const auto & elem : bsa)
  407. {
  408. const auto * cdefender = battle->battleGetStackByID(elem.stackAttacked, false);
  409. const auto * cattacker = battle->battleGetStackByID(elem.attackerID, false);
  410. ASSERT(cdefender, "defender cannot be NULL");
  411. const auto defender = std::ranges::find_if(
  412. stacks,
  413. [&cdefender](const std::shared_ptr<Stack> & stack)
  414. {
  415. return cdefender == stack->cstack;
  416. }
  417. );
  418. if(defender == stacks.end())
  419. {
  420. logAi->info("defender cstack '%s' not found in stacks. Maybe it was just summoned/resurrected?", cdefender->getDescription());
  421. }
  422. const auto attacker = std::ranges::find_if(
  423. stacks,
  424. [&cattacker](const std::shared_ptr<Stack> & stack)
  425. {
  426. return cattacker == stack->cstack;
  427. }
  428. );
  429. auto bf_valueNow = gstats->attr(GA::BFIELD_VALUE_NOW_ABS);
  430. auto bf_hpNow = gstats->attr(GA::BFIELD_HP_NOW_ABS);
  431. auto value = elem.killedAmount * Stack::GetValue(cdefender->unitType());
  432. // XXX: attacker can be NULL when an effect does dmg (eg. Acid)
  433. // XXX: attacker or defender can be NULL if it did not exist
  434. // when `stacks` was built (e.g. during our last turn),
  435. // Can happen if the enemy has now summonned/resurrected it.
  436. auto ald = AttackLogData{
  437. .attacker = (attacker != stacks.end() ? *attacker : nullptr),
  438. .defender = (defender != stacks.end() ? *defender : nullptr),
  439. .cattacker = cattacker,
  440. .cdefender = cdefender,
  441. .dmg = static_cast<int>(elem.damageAmount),
  442. .dmgPermille = static_cast<int>(1000 * elem.damageAmount / bf_hpNow),
  443. .units = static_cast<int>(elem.killedAmount),
  444. .value = static_cast<int>(value),
  445. .valuePermille = static_cast<int>(1000 * value / bf_valueNow)
  446. };
  447. attackLogs.push_back(std::make_shared<AttackLog>(std::move(ald)));
  448. }
  449. }
  450. void State::onBattleTriggerEffect(const BattleTriggerEffect & bte)
  451. {
  452. if(bte.effect != BonusType::MORALE)
  453. return;
  454. isMorale = true;
  455. }
  456. void State::onActionFinished(const BattleAction & ba)
  457. {
  458. // XXX: assuming action was OK (no server error about failed/fishy action)
  459. }
  460. /*
  461. * !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
  462. * !!!!!! IMPORTANT: `battlefield` must not be used here (old state) !!!!!!
  463. * !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
  464. */
  465. void State::onActionStarted(const BattleAction & ba)
  466. {
  467. if(!enableTransitions)
  468. return;
  469. _onActionStarted(ba);
  470. actingStack = nullptr;
  471. }
  472. void State::onBattleEnd(const BattleResult * br)
  473. {
  474. switch(br->winner)
  475. {
  476. case BattleSide::LEFT_SIDE:
  477. onActiveStack(nullptr, CombatResult::LEFT_WINS);
  478. break;
  479. case BattleSide::RIGHT_SIDE:
  480. onActiveStack(nullptr, CombatResult::RIGHT_WINS);
  481. break;
  482. default:
  483. onActiveStack(nullptr, CombatResult::DRAW);
  484. }
  485. }
  486. };