| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315 |
- /*
- * router.cpp, part of VCMI engine
- *
- * Authors: listed in file AUTHORS in main folder
- *
- * License: GNU General Public License v2.0 or later
- * Full text of license available in license.txt file, in main folder
- *
- */
- #include "StdInc.h"
- #include "callback/CBattleCallback.h"
- #include "callback/CDynLibHandler.h"
- #include "callback/IGameInfoCallback.h"
- #include "filesystem/Filesystem.h"
- #include "json/JsonUtils.h"
- #include "BAI/base.h"
- #include "BAI/model/NNModel.h"
- #include "BAI/model/ScriptedModel.h"
- #include "BAI/router.h"
- #include "common.h"
- #include <utility>
- namespace MMAI::BAI
- {
- using ConfigStorage = std::map<std::string, std::string>;
- using ModelStorage = std::map<std::string, std::unique_ptr<NNModel>>;
- namespace
- {
- struct ModelRepository
- {
- ModelStorage models;
- float temperature = 1.0;
- uint64_t seed = 0;
- std::unique_ptr<ScriptedModel> fallbackModel;
- std::string fallbackName;
- };
- std::unique_ptr<ModelRepository> InitModelRepository()
- {
- auto repo = std::make_unique<ModelRepository>();
- auto json = JsonUtils::assembleFromFiles("MMAI/CONFIG/mmai-settings.json");
- if(!json.isStruct())
- {
- logAi->error("Could not load MMAI config. Is MMAI mod enabled?");
- return repo;
- }
- JsonUtils::validate(json, "vcmi:mmaiSettings", "mmai");
- repo->temperature = static_cast<float>(json["temperature"].Float());
- repo->seed = json["seed"].Integer();
- for(const std::string key : {"attacker", "defender"})
- {
- std::string path = "MMAI/models/" + json["models"][key].String();
- // Try loading dynamic models with priority
- // (temporary code for a smooth migration path)
- const auto pos = path.rfind(".onnx");
- if(pos != std::string::npos)
- {
- std::string dynpath = path;
- dynpath.insert(pos, "-dynamic"); // insert right before ".onnx"
- const auto rpath = ResourcePath(dynpath, EResType::AI_MODEL);
- const auto * rhandler = CResourceHandler::get();
- if(rhandler->existsResource(rpath))
- path = dynpath;
- }
- logAi->debug("MMAI: Loading NN %s model from: %s", key, path);
- try
- {
- repo->models.try_emplace(key, std::make_unique<NNModel>(path, repo->temperature, repo->seed));
- }
- catch(std::exception & e)
- {
- logAi->error("MMAI: error loading " + key + ": " + std::string(e.what()));
- }
- }
- auto fallback = json["fallback"].String();
- logAi->debug("MMAI: preparing fallback model: %s", fallback);
- repo->fallbackModel = std::make_unique<ScriptedModel>(fallback);
- repo->fallbackName = fallback;
- return repo;
- }
- Schema::IModel * GetModel(const std::string & key)
- {
- static const auto MODEL_REPO = InitModelRepository();
- auto it = MODEL_REPO->models.find(key);
- if(it == MODEL_REPO->models.end())
- {
- logAi->error("MMAI: no %s model loaded, trying fallback: %s", key, MODEL_REPO->fallbackName);
- ASSERT(MODEL_REPO->fallbackModel, "fallback failed: model is null");
- return MODEL_REPO->fallbackModel.get();
- }
- return it->second.get();
- }
- }
- Router::Router()
- {
- std::ostringstream oss;
- // Store the memory address and include it in logging
- const auto * ptr = static_cast<const void *>(this);
- oss << ptr;
- addrstr = oss.str();
- info("+++ constructor +++"); // log after addrstr is set
- }
- Router::~Router()
- {
- info("--- destructor ---");
- cb->waitTillRealize = wasWaitingForRealize;
- }
- void Router::initBattleInterface(std::shared_ptr<Environment> ENV, std::shared_ptr<CBattleCallback> CB)
- {
- info("*** initBattleInterface ***");
- env = ENV;
- cb = CB;
- colorname = cb->getPlayerID()->toString();
- wasWaitingForRealize = cb->waitTillRealize;
- cb->waitTillRealize = false;
- bai.reset();
- }
- void Router::initBattleInterface(std::shared_ptr<Environment> ENV, std::shared_ptr<CBattleCallback> CB, AutocombatPreferences prefs)
- {
- autocombatPreferences = prefs;
- initBattleInterface(ENV, CB);
- }
- /*
- * Delegated methods
- */
- void Router::actionFinished(const BattleID & bid, const BattleAction & action)
- {
- bai->actionFinished(bid, action);
- }
- void Router::actionStarted(const BattleID & bid, const BattleAction & action)
- {
- bai->actionStarted(bid, action);
- }
- void Router::activeStack(const BattleID & bid, const CStack * astack)
- {
- bai->activeStack(bid, astack);
- }
- void Router::battleAttack(const BattleID & bid, const BattleAttack * ba)
- {
- bai->battleAttack(bid, ba);
- }
- void Router::battleCatapultAttacked(const BattleID & bid, const CatapultAttack & ca)
- {
- bai->battleCatapultAttacked(bid, ca);
- }
- void Router::battleEnd(const BattleID & bid, const BattleResult * br, QueryID queryID)
- {
- bai->battleEnd(bid, br, queryID);
- }
- void Router::battleGateStateChanged(const BattleID & bid, const EGateState state)
- {
- bai->battleGateStateChanged(bid, state);
- };
- void Router::battleLogMessage(const BattleID & bid, const std::vector<MetaString> & lines)
- {
- bai->battleLogMessage(bid, lines);
- };
- void Router::battleNewRound(const BattleID & bid)
- {
- bai->battleNewRound(bid);
- }
- void Router::battleNewRoundFirst(const BattleID & bid)
- {
- bai->battleNewRoundFirst(bid);
- }
- void Router::battleObstaclesChanged(const BattleID & bid, const std::vector<ObstacleChanges> & obstacles)
- {
- bai->battleObstaclesChanged(bid, obstacles);
- };
- void Router::battleSpellCast(const BattleID & bid, const BattleSpellCast * sc)
- {
- bai->battleSpellCast(bid, sc);
- }
- void Router::battleStackMoved(const BattleID & bid, const CStack * stack, const BattleHexArray & dest, int distance, bool teleport)
- {
- bai->battleStackMoved(bid, stack, dest, distance, teleport);
- }
- void Router::battleStacksAttacked(const BattleID & bid, const std::vector<BattleStackAttacked> & bsa, bool ranged)
- {
- bai->battleStacksAttacked(bid, bsa, ranged);
- }
- void Router::battleStacksEffectsSet(const BattleID & bid, const SetStackEffect & sse)
- {
- bai->battleStacksEffectsSet(bid, sse);
- }
- void Router::battleStart(
- const BattleID & bid,
- const CCreatureSet * army1,
- const CCreatureSet * army2,
- int3 tile,
- const CGHeroInstance * hero1,
- const CGHeroInstance * hero2,
- BattleSide side,
- bool replayAllowed
- )
- {
- Schema::IModel * model;
- const std::string modelkey = side == BattleSide::ATTACKER ? "attacker" : "defender";
- model = GetModel(modelkey);
- auto modelside = model->getSide();
- auto realside = static_cast<Schema::Side>(EI(side));
- if(modelside != realside && modelside != Schema::Side::BOTH)
- logAi->warn("The loaded '%s' model was not trained to play as %s", modelkey, modelkey);
- switch(model->getType())
- {
- case Schema::ModelType::SCRIPTED:
- if(model->getName() == "StupidAI")
- {
- bai = CDynLibHandler::getNewBattleAI("StupidAI");
- bai->initBattleInterface(env, cb, autocombatPreferences);
- }
- else if(model->getName() == "BattleAI")
- {
- bai = CDynLibHandler::getNewBattleAI("BattleAI");
- bai->initBattleInterface(env, cb, autocombatPreferences);
- }
- else
- {
- THROW_FORMAT("Unexpected scripted model name: %s", model->getName());
- }
- break;
- case Schema::ModelType::NN:
- // XXX: must not call initBattleInterface here
- bai = Base::Create(model, env, cb, autocombatPreferences.enableSpellsUsage);
- break;
- default:
- THROW_FORMAT("Unexpected model type: %d", EI(model->getType()));
- }
- bai->battleStart(bid, army1, army2, tile, hero1, hero2, side, replayAllowed);
- }
- void Router::battleTriggerEffect(const BattleID & bid, const BattleTriggerEffect & bte)
- {
- bai->battleTriggerEffect(bid, bte);
- }
- void Router::battleUnitsChanged(const BattleID & bid, const std::vector<UnitChanges> & changes)
- {
- bai->battleUnitsChanged(bid, changes);
- }
- void Router::yourTacticPhase(const BattleID & bid, int distance)
- {
- bai->yourTacticPhase(bid, distance);
- }
- /*
- * private
- */
- void Router::error(const std::string & text) const
- {
- log(ELogLevel::ERROR, text);
- }
- void Router::warn(const std::string & text) const
- {
- log(ELogLevel::WARN, text);
- }
- void Router::info(const std::string & text) const
- {
- log(ELogLevel::INFO, text);
- }
- void Router::debug(const std::string & text) const
- {
- log(ELogLevel::DEBUG, text);
- }
- void Router::trace(const std::string & text) const
- {
- log(ELogLevel::TRACE, text);
- }
- void Router::log(ELogLevel::ELogLevel level, const std::string & text) const
- {
- if(logAi->getEffectiveLevel() <= level)
- logAi->debug("Router-%s [%s] %s", addrstr, colorname, text);
- }
- }
|