router.cpp 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. /*
  2. * router.cpp, part of VCMI engine
  3. *
  4. * Authors: listed in file AUTHORS in main folder
  5. *
  6. * License: GNU General Public License v2.0 or later
  7. * Full text of license available in license.txt file, in main folder
  8. *
  9. */
  10. #include "StdInc.h"
  11. #include "CRandomGenerator.h"
  12. #include "callback/CBattleCallback.h"
  13. #include "callback/CDynLibHandler.h"
  14. #include "callback/IGameInfoCallback.h"
  15. #include "filesystem/Filesystem.h"
  16. #include "json/JsonUtils.h"
  17. #include "BAI/base.h"
  18. #include "BAI/model/NNModel.h"
  19. #include "BAI/model/NNModelStochastic.h"
  20. #include "BAI/model/ScriptedModel.h"
  21. #include "BAI/router.h"
  22. #include "common.h"
  23. #include <utility>
  24. namespace MMAI::BAI
  25. {
  26. using ModelStorage = std::map<std::string, std::unique_ptr<Schema::IModel>>;
  27. namespace
  28. {
  29. struct ModelRepository
  30. {
  31. ModelStorage models;
  32. float temperature = 1.0;
  33. uint64_t seed = 0;
  34. std::unique_ptr<ScriptedModel> fallbackModel;
  35. std::string fallbackName;
  36. };
  37. std::unique_ptr<ModelRepository> InitModelRepository()
  38. {
  39. auto repo = std::make_unique<ModelRepository>();
  40. auto json = JsonUtils::assembleFromFiles("MMAI/CONFIG/mmai-settings.json");
  41. if(!json.isStruct())
  42. {
  43. logAi->error("Could not load MMAI config. Is MMAI mod enabled?");
  44. return repo;
  45. }
  46. JsonUtils::validate(json, "vcmi:mmaiSettings", "mmai");
  47. repo->temperature = static_cast<float>(json["temperature"].Float());
  48. repo->seed = json["seed"].Integer();
  49. if(repo->seed == 0)
  50. repo->seed = CRandomGenerator::getDefault().nextInt();
  51. for(const std::string key : {"attacker", "defender"})
  52. {
  53. std::string path = "MMAI/models/" + json["models"][key].String();
  54. // Try loading stochastic and dynamic models with priority
  55. // (temporary code for a smooth migration path)
  56. std::string suffix;
  57. const auto pos = path.rfind(".onnx");
  58. if(pos != std::string::npos)
  59. {
  60. for(const std::string s : {"stochastic", "dynamic"})
  61. {
  62. std::string altpath = path;
  63. altpath.insert(pos, "-" + s); // insert right before ".onnx"
  64. const auto rpath = ResourcePath(altpath, EResType::AI_MODEL);
  65. const auto * rhandler = CResourceHandler::get();
  66. if(rhandler->existsResource(rpath))
  67. {
  68. path = altpath;
  69. suffix = s;
  70. break;
  71. }
  72. }
  73. }
  74. logAi->debug("MMAI: Loading NN %s model from: %s", key, path);
  75. try
  76. {
  77. // Only stochastic models use a separate class
  78. if(suffix == "stochastic")
  79. repo->models.try_emplace(key, std::make_unique<NNModelStochastic>(path, repo->temperature, repo->seed));
  80. else
  81. repo->models.try_emplace(key, std::make_unique<NNModel>(path, repo->temperature, repo->seed));
  82. }
  83. catch(std::exception & e)
  84. {
  85. logAi->error("MMAI: error loading " + key + ": " + std::string(e.what()));
  86. }
  87. }
  88. auto fallback = json["fallback"].String();
  89. logAi->debug("MMAI: preparing fallback model: %s", fallback);
  90. repo->fallbackModel = std::make_unique<ScriptedModel>(fallback);
  91. repo->fallbackName = fallback;
  92. return repo;
  93. }
  94. Schema::IModel * GetModel(const std::string & key)
  95. {
  96. static const auto MODEL_REPO = InitModelRepository();
  97. auto it = MODEL_REPO->models.find(key);
  98. if(it == MODEL_REPO->models.end())
  99. {
  100. logAi->error("MMAI: no %s model loaded, trying fallback: %s", key, MODEL_REPO->fallbackName);
  101. ASSERT(MODEL_REPO->fallbackModel, "fallback failed: model is null");
  102. return MODEL_REPO->fallbackModel.get();
  103. }
  104. return it->second.get();
  105. }
  106. }
  107. Router::Router()
  108. {
  109. std::ostringstream oss;
  110. // Store the memory address and include it in logging
  111. const auto * ptr = static_cast<const void *>(this);
  112. oss << ptr;
  113. addrstr = oss.str();
  114. info("+++ constructor +++"); // log after addrstr is set
  115. }
  116. Router::~Router()
  117. {
  118. info("--- destructor ---");
  119. cb->waitTillRealize = wasWaitingForRealize;
  120. }
  121. void Router::initBattleInterface(std::shared_ptr<Environment> ENV, std::shared_ptr<CBattleCallback> CB)
  122. {
  123. info("*** initBattleInterface ***");
  124. env = ENV;
  125. cb = CB;
  126. colorname = cb->getPlayerID()->toString();
  127. wasWaitingForRealize = cb->waitTillRealize;
  128. cb->waitTillRealize = false;
  129. bai.reset();
  130. }
  131. void Router::initBattleInterface(std::shared_ptr<Environment> ENV, std::shared_ptr<CBattleCallback> CB, AutocombatPreferences prefs)
  132. {
  133. autocombatPreferences = prefs;
  134. initBattleInterface(ENV, CB);
  135. }
  136. /*
  137. * Delegated methods
  138. */
  139. void Router::actionFinished(const BattleID & bid, const BattleAction & action)
  140. {
  141. bai->actionFinished(bid, action);
  142. }
  143. void Router::actionStarted(const BattleID & bid, const BattleAction & action)
  144. {
  145. bai->actionStarted(bid, action);
  146. }
  147. void Router::activeStack(const BattleID & bid, const CStack * astack)
  148. {
  149. bai->activeStack(bid, astack);
  150. }
  151. void Router::battleAttack(const BattleID & bid, const BattleAttack * ba)
  152. {
  153. bai->battleAttack(bid, ba);
  154. }
  155. void Router::battleCatapultAttacked(const BattleID & bid, const CatapultAttack & ca)
  156. {
  157. bai->battleCatapultAttacked(bid, ca);
  158. }
  159. void Router::battleEnd(const BattleID & bid, const BattleResult * br, QueryID queryID)
  160. {
  161. bai->battleEnd(bid, br, queryID);
  162. }
  163. void Router::battleGateStateChanged(const BattleID & bid, const EGateState state)
  164. {
  165. bai->battleGateStateChanged(bid, state);
  166. };
  167. void Router::battleLogMessage(const BattleID & bid, const std::vector<MetaString> & lines)
  168. {
  169. bai->battleLogMessage(bid, lines);
  170. };
  171. void Router::battleNewRound(const BattleID & bid)
  172. {
  173. bai->battleNewRound(bid);
  174. }
  175. void Router::battleNewRoundFirst(const BattleID & bid)
  176. {
  177. bai->battleNewRoundFirst(bid);
  178. }
  179. void Router::battleObstaclesChanged(const BattleID & bid, const std::vector<ObstacleChanges> & obstacles)
  180. {
  181. bai->battleObstaclesChanged(bid, obstacles);
  182. };
  183. void Router::battleSpellCast(const BattleID & bid, const BattleSpellCast * sc)
  184. {
  185. bai->battleSpellCast(bid, sc);
  186. }
  187. void Router::battleStackMoved(const BattleID & bid, const CStack * stack, const BattleHexArray & dest, int distance, bool teleport)
  188. {
  189. bai->battleStackMoved(bid, stack, dest, distance, teleport);
  190. }
  191. void Router::battleStacksAttacked(const BattleID & bid, const std::vector<BattleStackAttacked> & bsa, bool ranged)
  192. {
  193. bai->battleStacksAttacked(bid, bsa, ranged);
  194. }
  195. void Router::battleStacksEffectsSet(const BattleID & bid, const SetStackEffect & sse)
  196. {
  197. bai->battleStacksEffectsSet(bid, sse);
  198. }
  199. void Router::battleStart(
  200. const BattleID & bid,
  201. const CCreatureSet * army1,
  202. const CCreatureSet * army2,
  203. int3 tile,
  204. const CGHeroInstance * hero1,
  205. const CGHeroInstance * hero2,
  206. BattleSide side,
  207. bool replayAllowed
  208. )
  209. {
  210. Schema::IModel * model;
  211. const std::string modelkey = side == BattleSide::ATTACKER ? "attacker" : "defender";
  212. model = GetModel(modelkey);
  213. auto modelside = model->getSide();
  214. auto realside = static_cast<Schema::Side>(EI(side));
  215. if(modelside != realside && modelside != Schema::Side::BOTH)
  216. logAi->warn("The loaded '%s' model was not trained to play as %s", modelkey, modelkey);
  217. switch(model->getType())
  218. {
  219. case Schema::ModelType::SCRIPTED:
  220. if(model->getName() == "StupidAI")
  221. {
  222. bai = CDynLibHandler::getNewBattleAI("StupidAI");
  223. bai->initBattleInterface(env, cb, autocombatPreferences);
  224. }
  225. else if(model->getName() == "BattleAI")
  226. {
  227. bai = CDynLibHandler::getNewBattleAI("BattleAI");
  228. bai->initBattleInterface(env, cb, autocombatPreferences);
  229. }
  230. else
  231. {
  232. THROW_FORMAT("Unexpected scripted model name: %s", model->getName());
  233. }
  234. break;
  235. case Schema::ModelType::NN:
  236. // XXX: must not call initBattleInterface here
  237. bai = Base::Create(model, env, cb, autocombatPreferences.enableSpellsUsage);
  238. break;
  239. default:
  240. THROW_FORMAT("Unexpected model type: %d", EI(model->getType()));
  241. }
  242. bai->battleStart(bid, army1, army2, tile, hero1, hero2, side, replayAllowed);
  243. }
  244. void Router::battleTriggerEffect(const BattleID & bid, const BattleTriggerEffect & bte)
  245. {
  246. bai->battleTriggerEffect(bid, bte);
  247. }
  248. void Router::battleUnitsChanged(const BattleID & bid, const std::vector<UnitChanges> & changes)
  249. {
  250. bai->battleUnitsChanged(bid, changes);
  251. }
  252. void Router::yourTacticPhase(const BattleID & bid, int distance)
  253. {
  254. bai->yourTacticPhase(bid, distance);
  255. }
  256. /*
  257. * private
  258. */
  259. void Router::error(const std::string & text) const
  260. {
  261. log(ELogLevel::ERROR, text);
  262. }
  263. void Router::warn(const std::string & text) const
  264. {
  265. log(ELogLevel::WARN, text);
  266. }
  267. void Router::info(const std::string & text) const
  268. {
  269. log(ELogLevel::INFO, text);
  270. }
  271. void Router::debug(const std::string & text) const
  272. {
  273. log(ELogLevel::DEBUG, text);
  274. }
  275. void Router::trace(const std::string & text) const
  276. {
  277. log(ELogLevel::TRACE, text);
  278. }
  279. void Router::log(ELogLevel::ELogLevel level, const std::string & text) const
  280. {
  281. if(logAi->getEffectiveLevel() <= level)
  282. logAi->debug("Router-%s [%s] %s", addrstr, colorname, text);
  283. }
  284. }