/*
 * NNModel.cpp, part of VCMI engine
 *
 * Authors: listed in file AUTHORS in main folder
 *
 * License: GNU General Public License v2.0 or later
 * Full text of license available in license.txt file, in main folder
 *
 */
#include "StdInc.h"
#include "BAI/model/util/bucketing.h"
#include "BAI/model/util/common.h"
#include "BAI/model/util/sampling.h"
#include "NNModel.h"
#include "filesystem/Filesystem.h"
#include "vstd/CLoggerBase.h"
#include "json/JsonNode.h"
#include <algorithm>
#include <chrono>
#include <cstring>
#include <onnxruntime_c_api.h>
#include <onnxruntime_cxx_api.h>

namespace MMAI::BAI
{

namespace
{

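// Logs the elapsed wall-clock time when it goes out of scope.
// `name` may be updated before destruction to annotate the log line.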
struct ScopedTimer
{
	std::string name;
	std::chrono::steady_clock::time_point t0;

	explicit ScopedTimer(const std::string & n) : name(n), t0(std::chrono::steady_clock::now()) {}

	ScopedTimer(const ScopedTimer &) = delete;
	ScopedTimer & operator=(const ScopedTimer &) = delete;
	ScopedTimer(ScopedTimer &&) = delete;
	ScopedTimer & operator=(ScopedTimer &&) = delete;

	~ScopedTimer()
	{
		auto dt = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - t0).count();
		logAi->info("%s: %lld ms", name, dt);
	}
};

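// Groups edge indices by destination hex: res[v] holds the positions in `dst`
// of all edges pointing into node v (165 = number of battlefield hexes).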
std::array<std::vector<int32_t>, 165> buildNeighbourhoods_unpadded(const std::vector<int64_t> & dst)
{
	// Validate and count degrees per node
	std::array<int, 165> deg{};
	for(auto e : dst)
	{
		auto v = static_cast<int>(e);
		if(v < 0 || v >= 165)
			throwf("dst contains node id out of range: %d", v);
		++deg[v];
	}

	std::array<std::vector<int32_t>, 165> res{};
	for(int v = 0; v < 165; ++v)
		res[v].reserve(deg[v]);

	for(size_t e = 0; e < dst.size(); ++e)
	{
		auto v = static_cast<int>(dst[e]);
		res[v].push_back(static_cast<int32_t>(e));
	}

	return res;
}

} // anonymous namespace

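// Reads the .onnx blob through VCMI's resource handler and creates an in-memory
// ONNX Runtime session. The Ort::Env is a function-local static, so a single
// environment is shared by (and outlives) all sessions.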
std::unique_ptr<Ort::Session> NNModel::loadModel(const std::string & path, const Ort::SessionOptions & opts)
{
	static const auto env = Ort::Env{ORT_LOGGING_LEVEL_WARNING, "vcmi"};

	const auto rpath = ResourcePath(path, EResType::AI_MODEL);
	const auto * rhandler = CResourceHandler::get();

	if(!rhandler->existsResource(rpath))
		throwf("resource does not exist: %s", rpath.getName());

	const auto & [data, length] = rhandler->load(rpath)->readAll();
	return std::make_unique<Ort::Session>(env, data.get(), length, opts);
}

int NNModel::readVersion(const Ort::ModelMetadata & md) const
{
	/*
	 * version
	 * dtype=int
	 * shape=scalar
	 *
	 * Version of the model (current implementation is at version 13).
	 * If needed, NNModel may be extended to support other versions as well.
	 */

	int res = -1;
	Ort::AllocatedStringPtr v = md.LookupCustomMetadataMapAllocated("version", allocator);

	if(!v)
		throwf("readVersion: no such key");

	std::string vs(v.get());

	try
	{
		res = std::stoi(vs);
	}
	catch(...)
	{
		throwf("readVersion: not an int: %s", vs);
	}

	if(res != 13)
		throwf("readVersion: want: 13, have: %d (%s)", res, vs);

	return res;
}

Schema::Side NNModel::readSide(const Ort::ModelMetadata & md) const
{
	/*
	 * side
	 * dtype=int
	 * shape=scalar
	 *
	 * Battlefield side the model was trained on (see Schema::Side enum).
	 */

	Schema::Side res;
	Ort::AllocatedStringPtr v = md.LookupCustomMetadataMapAllocated("side", allocator);

	if(!v)
		throw std::runtime_error("metadata error: side: no such key");

	std::string vs(v.get());

	try
	{
		res = static_cast<Schema::Side>(std::stoi(vs));
	}
	catch(...)
	{
		throw std::runtime_error("metadata error: side: not an int");
	}

	return res;
}

Vec3D<int32_t> NNModel::readBucketSizes(const Ort::ModelMetadata & md) const
{
	/*
	 * all_sizes
	 * dtype=int
	 * shape=[5, 7, 2]:
	 *   d1: bucket size (S, M, L, XL, XXL)
	 *   d2: edge type (see Schema::V13::LinkType enum)
	 *   d3: pairs of [Emax, Kmax]:
	 *     Emax = max number of outbound node edges
	 *     Kmax = max number of inbound node edges
	 *
	 * Stats (10K steps):
	 *
	 * Outbound edges (E)    avg    max    p99    p90    p75    p50    p25
	 * -------------------------------------------------------------------
	 * ADJACENT              888    888    888    888    888    888    888
	 * REACH                 355    988    820    614    478    329    209
	 * RANGED_MOD            408   2403   1285    646    483    322    162
	 * ACTS_BEFORE            51    268    203    118     75     35     15
	 * MELEE_DMG_REL          43    198    160    103     60     31     14
	 * RETAL_DMG_REL          27    165    113     67     38     18      8
	 * RANGED_DMG_REL         12    133     60     29     18      9      4
	 *
	 * Inbound edges (K)     avg    max    p99    p90    p75    p50    p25
	 * -------------------------------------------------------------------
	 * ADJACENT              5.4      6      6      6      6      6      6
	 * REACH                 2.2     13     10      8      6      4      3
	 * RANGED_MOD            2.5     15      8      4      3      2      1
	 * ACTS_BEFORE           0.3     23     19     15     12      8      5
	 * MELEE_DMG_REL         0.3     10      9      8      7      5      3
	 * RETAL_DMG_REL         0.2     10      9      8      6      5      3
	 * RANGED_DMG_REL        0.1      8      6      3      2      2      1
	 *
	 * Approx. sizes are S=p50 / M=p90 / L=p99 / XL=max / XXL=2*max
	 * Exact values are defined in the vcmi-gym project and are subject to change.
	 * NOTE: bucketed inputs are deprecated and will soon be removed.
	 */

	Vec3D<int32_t> res = {};
	Ort::AllocatedStringPtr ab = md.LookupCustomMetadataMapAllocated("all_sizes", allocator);

	if(!ab)
		throw std::runtime_error("metadata key 'all_sizes' missing");

	const std::string jsonstr(ab.get());

	try
	{
		auto jn = JsonNode(jsonstr.data(), jsonstr.size(), "<ONNX metadata: all_sizes>");

		if(!jn.isVector())
			throwf("readBucketSizes: bad JsonType: want: %d, have: %d", EI(JsonNode::JsonType::DATA_VECTOR), EI(jn.getType()));

		for(auto & jv0 : jn.Vector())
		{
			auto vec1 = std::vector<std::vector<int32_t>>{};
			for(auto & jv1 : jv0.Vector())
			{
				auto vec2 = std::vector<int32_t>{};
				for(auto & jv2 : jv1.Vector())
				{
					if(!jv2.isNumber())
					{
						throwf("readBucketSizes: invalid data type: want: %d, got: %d", EI(JsonNode::JsonType::DATA_INTEGER), EI(jv2.getType()));
					}
					vec2.push_back(static_cast<int32_t>(jv2.Integer()));
				}
				vec1.emplace_back(vec2);
			}
			res.emplace_back(vec1);
		}
	}
	catch(const std::exception & e)
	{
		throw std::runtime_error(std::string("readBucketSizes: failed to parse JSON: ") + e.what());
	}

	if(res.size() != 5)
		throwf("readBucketSizes: bad size for d1: want: 5, have: %zu", res.size());
	if(res[0].size() != 7)
		throwf("readBucketSizes: bad size for d2: want: 7, have: %zu", res[0].size());
	if(res[0][0].size() != 2)
		throwf("readBucketSizes: bad size for d3: want: 2, have: %zu", res[0][0].size());

	return res;
}

Vec3D<int32_t> NNModel::readActionTable(const Ort::ModelMetadata & md) const
{
	/*
	 * action_table
	 * dtype=int
	 * shape=[4, 165, 165]:
	 *   d1: action (WAIT, MOVE, AMOVE, SHOOT)
	 *   d2: target hex for MOVE, AMOVE (hex to move to) or SHOOT
	 *   d3: target hex for AMOVE (hex to melee-attack at after moving)
	 */

	Vec3D<int32_t> res = {};
	Ort::AllocatedStringPtr ab = md.LookupCustomMetadataMapAllocated("action_table", allocator);

	if(!ab)
		throwf("readActionTable: metadata key 'action_table' missing");

	const std::string jsonstr(ab.get());

	try
	{
		auto jn = JsonNode(jsonstr.data(), jsonstr.size(), "<ONNX metadata: action_table>");

		for(auto & jv0 : jn.Vector())
		{
			auto vec1 = std::vector<std::vector<int32_t>>{};
			for(auto & jv1 : jv0.Vector())
			{
				auto vec2 = std::vector<int32_t>{};
				for(auto & jv2 : jv1.Vector())
				{
					if(!jv2.isNumber())
					{
						throwf("invalid data type: want: %d, got: %d", EI(JsonNode::JsonType::DATA_INTEGER), EI(jv2.getType()));
					}
					vec2.push_back(static_cast<int32_t>(jv2.Integer()));
				}
				vec1.emplace_back(vec2);
			}
			res.emplace_back(vec1);
		}
	}
	catch(const std::exception & e)
	{
		throwf("readActionTable: failed to parse 'action_table' JSON: %s", e.what());
	}

	if(res.size() != 4)
		throwf("readActionTable: bad size for d1: want: 4, have: %zu", res.size());
	if(res[0].size() != 165)
		throwf("readActionTable: bad size for d2: want: 165, have: %zu", res[0].size());
	if(res[0][0].size() != 165)
		throwf("readActionTable: bad size for d3: want: 165, have: %zu", res[0][0].size());

	return res;
}

bool NNModel::readIsDynamic(const Ort::ModelMetadata & md) const
{
	/*
	 * is_dynamic
	 * dtype=int
	 * shape=scalar
	 *
	 * Might not be present on older models (return false in this case).
	 */

	Ort::AllocatedStringPtr v = md.LookupCustomMetadataMapAllocated("is_dynamic", allocator);
	return v && std::string(v.get()) == "1";
}

std::vector<const char *> NNModel::readInputNames(int want)
{
	/*
	 * Model inputs (4, or 5 if the model is dynamic):
	 * [0] battlefield state
	 *     dtype=float
	 *     shape=[S] where S=Schema::V13::BATTLEFIELD_STATE_SIZE
	 * [1] edge index
	 *     dtype=int32
	 *     shape=[2, E*] where E is the number of edges
	 * [2] edge attributes
	 *     dtype=float
	 *     shape=[E*, 1] where E is the number of edges
	 * [3] node neighbourhoods
	 *     dtype=int
	 *     shape=[165, K*] where K is the max number of inbound edges per hex
	 * [4] size (dynamic models only)
	 *     dtype=int
	 *     shape=[7, 2]
	 */

	std::vector<const char *> res;
	auto count = model->GetInputCount();

	if(count != want)
		throwf("wrong input count: want: %d, have: %lld", want, count);

	inputNamePtrs.reserve(count);
	res.reserve(count);

	for(size_t i = 0; i < count; ++i)
	{
		inputNamePtrs.emplace_back(model->GetInputNameAllocated(i, allocator));
		res.push_back(inputNamePtrs.back().get());
	}

	return res;
}

std::vector<const char *> NNModel::readOutputNames()
{
	/*
	 * Model outputs (10):
	 * [0] greedy action
	 *     dtype=int
	 *     shape=[1]
	 * [1] main action logits (see readActionTable, d1)
	 *     dtype=float
	 *     shape=[4]
	 * [2] hex#1 logits (see readActionTable, d2)
	 *     dtype=float
	 *     shape=[165]
	 * [3] hex#2 logits (see readActionTable, d3)
	 *     dtype=float
	 *     shape=[165]
	 * [4] main action mask
	 *     dtype=int
	 *     shape=[4]
	 * [5] hex#1 mask
	 *     dtype=int
	 *     shape=[165]
	 * [6] hex#2 mask
	 *     dtype=int
	 *     shape=[165]
	 * [7] greedy main action
	 *     dtype=int
	 *     shape=[1]
	 * [8] greedy hex1
	 *     dtype=int
	 *     shape=[1]
	 * [9] greedy hex2
	 *     dtype=int
	 *     shape=[1]
	 *
	 * The greedy output values are unused since their stochastic counterparts
	 * are sampled here instead (see sampling::sample_triplet).
	 */

	std::vector<const char *> res;
	auto count = model->GetOutputCount();

	if(count != 10)
		throwf("wrong output count: want: 10, have: %lld", count);

	outputNamePtrs.reserve(count);
	res.reserve(count);

	for(size_t i = 0; i < count; ++i)
	{
		outputNamePtrs.emplace_back(model->GetOutputNameAllocated(i, allocator));
		res.push_back(outputNamePtrs.back().get());
	}

	return res;
}

/*
 * XXX:
 * hex1_logits and hex2_logits are based on a greedy act0.
 * However, if temp > 0 and a non-greedy act0 is chosen,
 * the hex logits become inconsistent with the chosen action.
 * As a temporary workaround, force greedy actions with temperature = 0.
 * Proper fix would require:
 * 1) re-exporting the model, changing its output dimensions to
 *    [4, 165] and [4, 165, 165] for hex1_logits and hex2_logits respectively
 * 2) changing the logic here to pick the proper hex logits after sampling
 */
NNModel::NNModel(const std::string & path, float _temperature, uint64_t seed)
	: path(path)
	// XXX: temperature is forced to 0 (greedy) as a workaround; see the note above. _temperature is intentionally ignored.
	, temperature(0)
	, meminfo(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault))
{
	logAi->info("MMAI: NNModel params: seed=%1%, temperature=%2%, model=%3%", seed, temperature, path);

	if(seed == 0)
	{
		seed = std::chrono::high_resolution_clock::now().time_since_epoch().count();
		logAi->info("Generated new seed: %1%", seed);
	}

	rng = std::mt19937(seed);

	/*
	 * IMPORTANT:
	 * There seems to be UB in the model unless either (or both):
	 * a) DisableMemPattern
	 * b) GraphOptimizationLevel::ORT_DISABLE_ALL
	 *
	 * Mem pattern does not impact performance => disable.
	 * Graph optimization causes < 30% speedup => not worth the risk, disable.
	 */
	auto opts = Ort::SessionOptions();
	opts.DisableMemPattern();
	opts.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_DISABLE_ALL);
	opts.SetExecutionMode(ORT_SEQUENTIAL); // ORT_SEQUENTIAL = no inter-op parallelism
	opts.SetInterOpNumThreads(1); // Inter-op threads matter only in ORT_PARALLEL mode
	opts.SetIntraOpNumThreads(4); // Parallelism inside kernels/operators

	model = loadModel(path, opts);
	auto md = model->GetModelMetadata();

	version = readVersion(md);
	side = readSide(md);
	actionTable = readActionTable(md);
	bucketSizes = readBucketSizes(md);
	isDynamic = readIsDynamic(md);
	inputNames = readInputNames(isDynamic ? 5 : 4);
	outputNames = readOutputNames();

	logAi->info("MMAI version %d initialized on side=%d (dynamic=%d)", version, EI(side), isDynamic);
}

Schema::ModelType NNModel::getType()
{
	return Schema::ModelType::NN;
}

std::string NNModel::getName()
{
	return "MMAI_MODEL";
}

int NNModel::getVersion()
{
	return version;
}

Schema::Side NNModel::getSide()
{
	return side;
}

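// Runs one inference step: validates the IState version, extracts the V13
// supplementary data, feeds the prepared input tensors to the ONNX session
// and returns either the greedy action or (if temperature > 0) a sampled one.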
int NNModel::getAction(const MMAI::Schema::IState * s)
{
	auto timer = ScopedTimer("getAction");
	auto any = s->getSupplementaryData();

	if(s->version() != version)
		throwf("getAction: unsupported IState version: want: %d, have: %d", version, s->version());

	if(!any.has_value())
		throw std::runtime_error("getAction: supplementary data is empty");

	auto err = MMAI::Schema::AnyCastError(any, typeid(const MMAI::Schema::V13::ISupplementaryData *));
	if(!err.empty())
		throwf("getAction: anycast failed: %s", err);

	const auto * sup = std::any_cast<const MMAI::Schema::V13::ISupplementaryData *>(any);

	if(sup->getIsBattleEnded())
	{
		timer.name = boost::str(boost::format("MMAI action: %d (battle ended)") % MMAI::Schema::ACTION_RESET);
		return MMAI::Schema::ACTION_RESET;
	}

	auto inputs = prepareInputsV13(s, sup);
	auto outputs = model->Run(Ort::RunOptions(), inputNames.data(), inputs.data(), inputs.size(), outputNames.data(), outputNames.size());

	if(outputs.size() != 10)
		throwf("getAction: bad output size: want: 10, have: %d", outputs.size());

	// Deterministic (greedy) action
	auto action = toVector<int32_t>("getAction: t_action", outputs[0], 1).at(0);
	timer.name = "MMAI action: " + std::to_string(action);

	// Stochastic action (used instead of the greedy action if temperature > 0)
	if(temperature > 1e-8)
	{
		auto sample = sampling::sample_triplet(
			MaskedLogits{.logits = outputs[1], .mask = outputs[4]}, // act0 [4]
			MaskedLogits{.logits = outputs[2], .mask = outputs[5]}, // hex1 [165]
			MaskedLogits{.logits = outputs[3], .mask = outputs[6]}, // hex2 [165]
			temperature,
			rng
		);

		auto s_action = actionTable.at(sample.act0).at(sample.hex1).at(sample.hex2);

		if(s_action != action)
			logAi->debug("Sampled a non-greedy action: %d with confidence=%.2f", s_action, sample.confidence);

		timer.name = boost::str(boost::format("MMAI action: %d (confidence=%.2f)") % s_action % sample.confidence);
		action = s_action;
	}

	return static_cast<MMAI::Schema::Action>(action);
}

double NNModel::getValue(const MMAI::Schema::IState * s)
{
	// Quantifies how good the current state is, as perceived by the model
	// (not used, not implemented)
	return 0;
}

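// Builds the model input tensors from the V13 supplementary data: the flat
// battlefield state, the per-edge-type graph links flattened into a single
// edge-index / edge-attribute pair, the per-hex neighbourhood index tensor
// and, for dynamic models, the [LT_COUNT, 2] tensor of actual Emax/Kmax sizes.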
std::vector<Ort::Value> NNModel::prepareInputsV13(const MMAI::Schema::IState * s, const MMAI::Schema::V13::ISupplementaryData * sup)
{
	auto containers = std::array<IndexContainer, LT_COUNT>{};
	int count = 0;

	for(const auto & [type, links] : sup->getAllLinks())
	{
		// assert order
		if(EI(type) != count)
			throwf("unexpected link type: want: %d, have: %d", count, EI(type));

		auto & c = containers.at(count);
		const auto srcinds = links->getSrcIndex();
		const auto dstinds = links->getDstIndex();
		const auto attrs = links->getAttributes();
		auto nlinks = srcinds.size();

		if(dstinds.size() != nlinks)
			throwf("unexpected dstinds.size() for LinkType(%d): want: %d, have: %d", EI(type), nlinks, dstinds.size());
		if(attrs.size() != nlinks)
			throwf("unexpected attrs.size() for LinkType(%d): want: %d, have: %d", EI(type), nlinks, attrs.size());

		c.edgeIndex.at(0).reserve(nlinks);
		c.edgeIndex.at(1).reserve(nlinks);
		c.edgeIndex.at(0).insert(c.edgeIndex.at(0).end(), srcinds.begin(), srcinds.end());
		c.edgeIndex.at(1).insert(c.edgeIndex.at(1).end(), dstinds.begin(), dstinds.end());
		c.edgeAttrs.reserve(nlinks);
		c.edgeAttrs.insert(c.edgeAttrs.end(), attrs.begin(), attrs.end());
		c.neighbourhoods = buildNeighbourhoods_unpadded(dstinds);

		++count;
	}

	if(count != LT_COUNT)
		throwf("unexpected links count: want: %d, have: %d", LT_COUNT, count);

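	// Flatten the per-type edge data into single arrays via the bucketing
	// helper; bucket sizes come from the model metadata (see readBucketSizes).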
	auto bdata = bucketing::BucketBuilder(containers, bucketSizes).build_bucket_data(isDynamic);

	const auto * state = s->getBattlefieldState();
	auto estate = std::vector<float>(state->size());
	std::ranges::copy(*state, estate.begin());

	int sum_e = bdata.edgeIndex_flat.at(0).size();
	int sum_k = bdata.neighbourhoods_flat.at(0).size();

	if(bdata.edgeIndex_flat.at(0).size() != sum_e)
		throwf("unexpected bdata.edgeIndex_flat.at(0).size(): want: %d, have: %d", sum_e, bdata.edgeIndex_flat.at(0).size());
	if(bdata.edgeIndex_flat.at(1).size() != sum_e)
		throwf("unexpected bdata.edgeIndex_flat.at(1).size(): want: %d, have: %d", sum_e, bdata.edgeIndex_flat.at(1).size());
	if(bdata.edgeAttrs_flat.size() != sum_e)
		throwf("unexpected bdata.edgeAttrs_flat.size(): want: %d, have: %d", sum_e, bdata.edgeAttrs_flat.size());

	for(int i = 0; i < 165; ++i)
	{
		if(bdata.neighbourhoods_flat.at(i).size() != sum_k)
			throwf("unexpected bdata.neighbourhoods_flat.at(%d).size(): want: %d, have: %d", i, sum_k, bdata.neighbourhoods_flat.at(i).size());
	}

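	// Concatenate the two edge-index rows into a row-major [2, sum_e] buffer
	// and the 165 neighbourhood rows into a row-major [165, sum_k] buffer.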
	auto edgeIndex_flat = std::vector<int32_t>{};
	edgeIndex_flat.reserve(2 * sum_e);
	for(auto & ei : bdata.edgeIndex_flat)
		edgeIndex_flat.insert(edgeIndex_flat.end(), ei.begin(), ei.end());

	auto neighbourhoods = std::vector<int32_t>{};
	neighbourhoods.reserve(165 * sum_k);
	for(auto & nbr : bdata.neighbourhoods_flat)
		neighbourhoods.insert(neighbourhoods.end(), nbr.begin(), nbr.end());

	auto tensors = std::vector<Ort::Value>{};
	tensors.push_back(toTensor("state", estate, {static_cast<int64_t>(estate.size())}));
	tensors.push_back(toTensor("edgeIndex_flat", edgeIndex_flat, {2, sum_e}));
	tensors.push_back(toTensor("edgeAttrs_flat", bdata.edgeAttrs_flat, {sum_e, 1}));
	tensors.push_back(toTensor("nbr_flat", neighbourhoods, {165, sum_k}));

	if(isDynamic)
	{
		auto size = std::vector<int64_t>{};
		size.reserve(EI(LT_COUNT) * 2);

		for(int i = 0; i < EI(LT_COUNT); ++i)
		{
			size.push_back(bdata.size.emax.at(i));
			size.push_back(bdata.size.kmax.at(i));
		}

		tensors.push_back(toTensor("size", size, {EI(LT_COUNT), 2}));
	}

	logAi->debug("Model input shapes: state={%d} edgeIndex={2, %d} edgeAttrs={%d, 1} nbr={165, %d}", estate.size(), sum_e, sum_e, sum_k);
	return tensors;
}

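// Copies `vec` into a newly allocated, memory-owning Ort::Value with the given
// shape (the product of the shape dimensions must equal vec.size()).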
template<typename T>
Ort::Value NNModel::toTensor(const std::string & name, std::vector<T> & vec, const std::vector<int64_t> & shape)
{
	// Sanity check
	int64_t numel = 1;
	for(int64_t d : shape)
		numel *= d;

	if(numel != vec.size())
		throwf("toTensor: %s: numel check failed: want: %d, have: %d", name, numel, vec.size());

	// Create a memory-owning tensor, then copy the data into it
	auto res = Ort::Value::CreateTensor<T>(allocator, shape.data(), shape.size());
	T * dst = res.template GetTensorMutableData<T>();
	std::memcpy(dst, vec.data(), vec.size() * sizeof(T));
	return res;
}

} // namespace MMAI::BAI