factory.cpp 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. /*
  2. * factory.cpp, part of VCMI engine
  3. *
  4. * Authors: listed in file AUTHORS in main folder
  5. *
  6. * License: GNU General Public License v2.0 or later
  7. * Full text of license available in license.txt file, in main folder
  8. *
  9. */
  10. #include "factory.h"
  11. #include "callback/CBattleGameInterface.h"
  12. #include "filesystem/Filesystem.h"
  13. #include <onnxruntime_c_api.h>
  14. #include "BAI/v13/BAI.h"
  15. #include "BAI/v13/nn_model.h"
  16. namespace MMAI::BAI
  17. {
  18. namespace
  19. {
  20. std::unique_ptr<Ort::Session> load(const std::string & path)
  21. {
  22. /*
  23. * IMPORTANT:
  24. * There seems to be an UB in the model unless either of the below is set:
  25. * a) GraphOptimizationLevel::ORT_DISABLE_ALL
  26. * b) DisableMemPattern
  27. *
  28. * Mem pattern does not impact performance => disable.
  29. * Graph optimization causes < 30% speedup => not worth the risk, disable.
  30. *
  31. */
  32. auto opts = Ort::SessionOptions();
  33. opts.DisableMemPattern();
  34. opts.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_DISABLE_ALL);
  35. opts.SetExecutionMode(ORT_SEQUENTIAL); // ORT_SEQUENTIAL = no inter-op parallelism
  36. opts.SetInterOpNumThreads(1); // Inter-op threads matter in ORT_PARALLEL
  37. opts.SetIntraOpNumThreads(4); // Parallelism inside kernels/operators
  38. static const auto env = Ort::Env{ORT_LOGGING_LEVEL_WARNING, "vcmi"};
  39. const auto rpath = ResourcePath(path, EResType::AI_MODEL);
  40. const auto * rhandler = CResourceHandler::get();
  41. if(!rhandler->existsResource(rpath))
  42. throw std::runtime_error("NNBase: resource does not exist: " + rpath.getName());
  43. const auto & [data, length] = rhandler->load(rpath)->readAll();
  44. return std::make_unique<Ort::Session>(env, data.get(), length, opts);
  45. }
  46. int readVersion(Ort::Session * session, OrtAllocator * allocator, const Ort::ModelMetadata & md)
  47. {
  48. /*
  49. * version
  50. * dtype=int
  51. * shape=scalar
  52. *
  53. * Version of the model (current implementation is at version 13).
  54. * If needed, NNModel may be extended to support other versions as well.
  55. *
  56. */
  57. int res = -1;
  58. Ort::AllocatedStringPtr v = md.LookupCustomMetadataMapAllocated("version", allocator);
  59. if(!v)
  60. throw std::runtime_error("NNBase: readVersion: no such key");
  61. std::string vs(v.get());
  62. try
  63. {
  64. res = std::stoi(vs);
  65. }
  66. catch(...)
  67. {
  68. throw std::runtime_error("NNBase: readVersion: not an int: " + vs);
  69. }
  70. return res;
  71. }
  72. std::shared_ptr<NNContainer> CreateNNContainer(const std::string & path)
  73. {
  74. auto session = load(path);
  75. auto meminfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  76. auto metadata = session->GetModelMetadata();
  77. auto allocator = Ort::AllocatorWithDefaultOptions();
  78. auto version = readVersion(session.get(), allocator, metadata);
  79. return std::make_shared<NNContainer>(path, std::move(session), std::move(meminfo), std::move(metadata), std::move(allocator), version);
  80. }
  81. }
  82. // Factory method for versioned derived NNModel (e.g. NNModel::V1)
  83. std::shared_ptr<MMAI::Schema::IModel> CreateNNModel(const std::string & path, float temperature, uint64_t seed)
  84. {
  85. auto container = CreateNNContainer(path);
  86. if(container->version == 13)
  87. return std::make_shared<V13::NNModel>(container, temperature, seed);
  88. else
  89. throw std::runtime_error("CreateNNModel: unsupported schema version: " + std::to_string(container->version));
  90. }
  91. // Factory method for versioned derived BAI (e.g. BAI::V1)
  92. std::shared_ptr<CBattleGameInterface>
  93. CreateBAI(Schema::IModel * model, const std::shared_ptr<Environment> & env, const std::shared_ptr<CBattleCallback> & cb, bool enableSpellsUsage)
  94. {
  95. std::shared_ptr<CBattleGameInterface> res;
  96. auto version = model->getVersion();
  97. if(version == 13)
  98. return std::make_shared<V13::BAI>(model, version, env, cb, enableSpellsUsage);
  99. else
  100. throw std::runtime_error("CreateBAI: unsupported schema version: " + std::to_string(version));
  101. return res;
  102. }
  103. }