/*
 * NNModelStochastic.cpp, part of VCMI engine
 *
 * Authors: listed in file AUTHORS in main folder
 *
 * License: GNU General Public License v2.0 or later
 * Full text of license available in license.txt file, in main folder
 *
 */
#include "StdInc.h"

#include "BAI/model/util/bucketing.h"
#include "BAI/model/util/common.h"
#include "BAI/model/util/sampling.h"
#include "NNModelStochastic.h"

#include "filesystem/Filesystem.h"
#include "vstd/CLoggerBase.h"
#include "json/JsonNode.h"

#include <algorithm>
#include <onnxruntime_c_api.h>
#include <onnxruntime_cxx_api.h>

namespace MMAI::BAI
{

namespace
{

template<typename T>
void assertValidTensor(const std::string & name, const Ort::Value & tensor, int ndim)
{
    auto type_info = tensor.GetTensorTypeAndShapeInfo();
    auto shape = type_info.GetShape();
    auto dtype = type_info.GetElementType();

    if(shape.size() != ndim)
        throwf("assertValidTensor: %s: bad ndim: want: %d, have: %d", name, ndim, shape.size());

    if constexpr(std::is_same_v<T, float>)
    {
        if(dtype != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
            throwf("assertValidTensor: %s: bad dtype: want: %d, have: %d", name, EI(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT), EI(dtype));
    }
    else if constexpr(std::is_same_v<T, int>)
    {
        if(dtype != ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
            throwf("assertValidTensor: %s: bad dtype: want: %d, have: %d", name, EI(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32), EI(dtype));
    }
    else if constexpr(std::is_same_v<T, int64_t>)
    {
        if(dtype != ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
            throwf("assertValidTensor: %s: bad dtype: want: %d, have: %d", name, EI(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64), EI(dtype));
    }
    else if constexpr(std::is_same_v<T, bool>)
    {
        if(dtype != ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
            throwf("assertValidTensor: %s: bad dtype: want: %d, have: %d", name, EI(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL), EI(dtype));
    }
    else
    {
        throwf("assertValidTensor: %s: can only work with bool, int and float", name);
    }
}
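
// toVec1D/toVec2D validate the tensor's shape and element type, then copy its
// payload into owning std::vector containers so the data outlives the Ort::Value.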
template<typename T>
std::vector<T> toVec1D(const std::string & name, const Ort::Value & tensor, int numel)
{
    assertValidTensor<T>(name, tensor, 1);

    auto type_info = tensor.GetTensorTypeAndShapeInfo();
    auto shape = type_info.GetShape();

    if(shape.at(0) != numel)
        throwf("toVec1D: %s: bad numel: want: %d, have: %d", name, numel, shape.at(0));

    const T * data = tensor.GetTensorData<T>();
    auto res = std::vector<T>{};
    res.reserve(numel);
    res.assign(data, data + numel); // res now owns a copy of the tensor data
    return res;
}
template<typename T>
Vec2D<T> toVec2D(const std::string & name, const Ort::Value & tensor, const std::pair<int64_t, int64_t> & dims)
{
    assertValidTensor<T>(name, tensor, 2);

    const auto & [d0, d1] = dims;
    auto type_info = tensor.GetTensorTypeAndShapeInfo();
    auto shape = type_info.GetShape();

    if(shape.at(0) != d0)
        throwf("toVec2D: %s: bad dim0: want: %d, have: %d", name, d0, shape.at(0));
    if(shape.at(1) != d1)
        throwf("toVec2D: %s: bad dim1: want: %d, have: %d", name, d1, shape.at(1));

    const T * data = tensor.GetTensorData<T>();
    auto res = Vec2D<T>{};
    res.resize(static_cast<size_t>(d0));

    for(auto i = 0; i < d0; ++i)
    {
        auto & row = res[i];
        row.resize(d1);
        std::memcpy(row.data(), data + i * d1, static_cast<size_t>(d1) * sizeof(T));
    }

    return res;
}

struct Sample
{
    int index;
    double confidence;
    double prob; // original (non-tempered) probability
};
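
// Draws one temperature-adjusted stochastic sample and one greedy (argmax) sample
// from the same probability vector. T=1 samples from the original distribution,
// T->0 converges to the greedy choice, T>1 flattens the distribution.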
std::pair<Sample, Sample> categorical(const std::vector<float> & probs, float temperature, std::mt19937 & rng)
{
    auto sample = Sample{};
    auto greedy = Sample{};

    if(temperature < 0.0f)
        throwf("sample: negative temperature");

    // Greedy sample: argmax, first tie.
    {
        int best = 0;
        for(int i = 0; i < probs.size(); ++i)
            if(probs[i] > probs[best])
                best = i; // '>' keeps the first tie

        greedy.index = best;
        greedy.prob = probs[best];
        greedy.confidence = 1.0f;
    }

    if(temperature < 1e-5)
        return {greedy, greedy};

    // Stochastic sample (only if temperature > 0)
    // Sample with weights w_i = exp(log(p_i)/T), and return original probs[idx].
    std::vector<double> logw(probs.size(), -std::numeric_limits<double>::infinity());
    double max_logw = -std::numeric_limits<double>::infinity();
    bool valid = false;

    for(std::size_t i = 0; i < probs.size(); ++i)
    {
        float p = probs[i];
        if(p < 0.0f)
            throwf("sample: negative probabilities");

        if(p > 0.0f)
        {
            valid = true;
            double lw = std::log(p) / temperature;
            logw[i] = lw;
            max_logw = std::max(lw, max_logw);
        }
    }

    if(!valid)
        throwf("sample: all probabilities are 0");

    std::vector<double> weights(probs.size(), 0.0);
    double wsum = 0.0;

    for(std::size_t i = 0; i < probs.size(); ++i)
    {
        if(std::isfinite(logw[i]))
        {
            // shift by max for numerical stability
            double wi = std::exp(logw[i] - max_logw);
            weights[i] = wi;
            wsum += wi;
        }
    }

    if(wsum <= 0.0)
        throwf("sample: non-positive weight sum: %f", wsum);

    std::discrete_distribution<int> dist(weights.begin(), weights.end());
    int idx = dist(rng);

    sample.index = idx;
    sample.prob = probs[idx];
    sample.confidence = weights[idx] / wsum;

    return {sample, greedy};
}
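
// Logs the elapsed wall-clock time (in ms) under `name` on destruction;
// getAction() later overwrites `name` so the final log line describes the chosen action.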
struct ScopedTimer
{
    std::string name;
    std::chrono::steady_clock::time_point t0;

    explicit ScopedTimer(const std::string & n) : name(n), t0(std::chrono::steady_clock::now()) {}
    ScopedTimer(const ScopedTimer &) = delete;
    ScopedTimer & operator=(const ScopedTimer &) = delete;
    ScopedTimer(ScopedTimer &&) = delete;
    ScopedTimer & operator=(ScopedTimer &&) = delete;

    ~ScopedTimer()
    {
        auto dt = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - t0).count();
        logAi->info("%s: %lld ms", name, dt);
    }
};

} // namespace

std::unique_ptr<Ort::Session> NNModelStochastic::loadModel(const std::string & path, const Ort::SessionOptions & opts)
{
    static const auto env = Ort::Env{ORT_LOGGING_LEVEL_WARNING, "vcmi"};

    const auto rpath = ResourcePath(path, EResType::AI_MODEL);
    const auto * rhandler = CResourceHandler::get();

    if(!rhandler->existsResource(rpath))
        throwf("resource does not exist: %s", rpath.getName());

    const auto & [data, length] = rhandler->load(rpath)->readAll();
    return std::make_unique<Ort::Session>(env, data.get(), length, opts);
}

int NNModelStochastic::readVersion(const Ort::ModelMetadata & md) const
{
    /*
     * version
     *   dtype=int
     *   shape=scalar
     *
     * Version of the model (current implementation is at version 13).
     * If needed, NNModel may be extended to support other versions as well.
     *
     */
    int res = -1;
    Ort::AllocatedStringPtr v = md.LookupCustomMetadataMapAllocated("version", allocator);

    if(!v)
        throwf("readVersion: no such key");

    std::string vs(v.get());

    try
    {
        res = std::stoi(vs);
    }
    catch(...)
    {
        throwf("readVersion: not an int: %s", vs);
    }

    if(res != 13)
        throwf("readVersion: want: 13, have: %d (%s)", res, vs);

    return res;
}

Schema::Side NNModelStochastic::readSide(const Ort::ModelMetadata & md) const
{
    /*
     * side
     *   dtype=int
     *   shape=scalar
     *
     * Battlefield side the model was trained on (see Schema::Side enum).
     *
     */
    Schema::Side res;
    Ort::AllocatedStringPtr v = md.LookupCustomMetadataMapAllocated("side", allocator);

    if(!v)
        throw std::runtime_error("metadata error: side: no such key");

    std::string vs(v.get());

    try
    {
        res = static_cast<Schema::Side>(std::stoi(vs));
    }
    catch(...)
    {
        throw std::runtime_error("metadata error: side: not an int");
    }

    return res;
}
Vec3D<int32_t> NNModelStochastic::readActionTable(const Ort::ModelMetadata & md) const
{
    /*
     * action_table
     *   dtype=int
     *   shape=[4, 165, 165]:
     *     d0: action (WAIT, MOVE, AMOVE, SHOOT)
     *     d1: target hex for MOVE, AMOVE (hex to move to) or SHOOT
     *     d2: target hex for AMOVE (hex to melee-attack at after moving)
     *
     */
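    // Example (hypothetical indices): actionTable[2][37][52] holds the action id
    // returned by getAction() for "move to hex 37, then melee-attack hex 52" (AMOVE).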
    Vec3D<int32_t> res = {};
    Ort::AllocatedStringPtr ab = md.LookupCustomMetadataMapAllocated("action_table", allocator);

    if(!ab)
        throwf("readActionTable: metadata key 'action_table' missing");

    const std::string jsonstr(ab.get());

    try
    {
        auto jn = JsonNode(jsonstr.data(), jsonstr.size(), "<ONNX metadata: action_table>");

        for(auto & jv0 : jn.Vector())
        {
            auto vec1 = std::vector<std::vector<int32_t>>{};

            for(auto & jv1 : jv0.Vector())
            {
                auto vec2 = std::vector<int32_t>{};

                for(auto & jv2 : jv1.Vector())
                {
                    if(!jv2.isNumber())
                    {
                        throwf("invalid data type: want: %d, got: %d", EI(JsonNode::JsonType::DATA_INTEGER), EI(jv2.getType()));
                    }
                    vec2.push_back(static_cast<int32_t>(jv2.Integer()));
                }

                vec1.emplace_back(vec2);
            }

            res.emplace_back(vec1);
        }
    }
    catch(const std::exception & e)
    {
        throwf(std::string("failed to parse 'action_table' JSON: ") + e.what());
    }

    if(res.size() != 4)
        throwf("readActionTable: bad size for d0: want: 4, have: %zu", res.size());
    if(res[0].size() != 165)
        throwf("readActionTable: bad size for d1: want: 165, have: %zu", res[0].size());
    if(res[0][0].size() != 165)
        throwf("readActionTable: bad size for d2: want: 165, have: %zu", res[0][0].size());

    return res;
}
std::vector<const char *> NNModelStochastic::readInputNames()
{
    /*
     * Model inputs (4):
     * [0] battlefield state
     *     dtype=float
     *     shape=[S] where S=Schema::V13::BATTLEFIELD_STATE_SIZE
     * [1] edge index
     *     dtype=int64
     *     shape=[2, E*] where E is the number of edges
     * [2] edge attributes
     *     dtype=float
     *     shape=[E*, 1]
     * [3] lengths
     *     dtype=int
     *     shape=[LT_COUNT]
     */
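    // E is not fixed: it is the total number of links across all LT_COUNT link types,
    // computed per battle state in prepareInputsV13().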
    std::vector<const char *> res;
    auto count = model->GetInputCount();

    if(count != 4)
        throwf("wrong input count: want: %d, have: %lld", 4, count);

    inputNamePtrs.reserve(count);
    res.reserve(count);

    for(size_t i = 0; i < count; ++i)
    {
        inputNamePtrs.emplace_back(model->GetInputNameAllocated(i, allocator));
        res.push_back(inputNamePtrs.back().get());
    }

    return res;
}
std::vector<const char *> NNModelStochastic::readOutputNames()
{
    /*
     * Model outputs (6):
     * [0] main action probabilities (see readActionTable, d0)
     *     dtype=float
     *     shape=[4]
     * [1] hex#1 probabilities (see readActionTable, d1)
     *     dtype=float
     *     shape=[4, 165]
     * [2] hex#2 probabilities (see readActionTable, d2)
     *     dtype=float
     *     shape=[165, 165]
     * [3] main action mask
     *     dtype=int
     *     shape=[4]
     * [4] hex#1 mask
     *     dtype=int
     *     shape=[4, 165]
     * [5] hex#2 mask
     *     dtype=int
     *     shape=[165, 165]
     */
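    // The probability heads are consumed hierarchically in getAction():
    // P(action) * P(hex1 | action) * P(hex2 | hex1); the masks mark which entries are legal.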
    std::vector<const char *> res;
    auto count = model->GetOutputCount();

    if(count != 6)
        throwf("wrong output count: want: %d, have: %lld", 6, count);

    outputNamePtrs.reserve(count);
    res.reserve(count);

    for(size_t i = 0; i < count; ++i)
    {
        outputNamePtrs.emplace_back(model->GetOutputNameAllocated(i, allocator));
        res.push_back(outputNamePtrs.back().get());
    }

    return res;
}
NNModelStochastic::NNModelStochastic(const std::string & path, float temperature, uint64_t seed)
    : path(path), temperature(temperature), meminfo(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault))
{
    logAi->info("MMAI: NNModel params: seed=%1%, temperature=%2%, model=%3%", seed, temperature, path);
    rng = std::mt19937(seed);

    /*
     * IMPORTANT:
     * There seems to be undefined behavior (UB) in the model unless either (or both):
     * a) DisableMemPattern
     * b) GraphOptimizationLevel::ORT_DISABLE_ALL
     *
     * Mem pattern does not impact performance => disable.
     * Graph optimization causes < 30% speedup => not worth the risk, disable.
     *
     */
    auto opts = Ort::SessionOptions();
    opts.DisableMemPattern();
    opts.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_DISABLE_ALL);
    opts.SetExecutionMode(ORT_SEQUENTIAL); // ORT_SEQUENTIAL = no inter-op parallelism
    opts.SetInterOpNumThreads(1); // Inter-op threads matter only in ORT_PARALLEL
    opts.SetIntraOpNumThreads(4); // Parallelism inside kernels/operators

    model = loadModel(path, opts);

    auto md = model->GetModelMetadata();
    version = readVersion(md);
    side = readSide(md);
    actionTable = readActionTable(md);
    inputNames = readInputNames();
    outputNames = readOutputNames();

    logAi->info("MMAI: version %d initialized on side=%d (stochastic=1)", version, EI(side));
}

Schema::ModelType NNModelStochastic::getType()
{
    return Schema::ModelType::NN;
}

std::string NNModelStochastic::getName()
{
    return "MMAI_MODEL";
}

int NNModelStochastic::getVersion()
{
    return version;
}

Schema::Side NNModelStochastic::getSide()
{
    return side;
}
int NNModelStochastic::getAction(const MMAI::Schema::IState * s)
{
    auto timer = ScopedTimer("getAction");
    auto any = s->getSupplementaryData();

    if(s->version() != version)
        throwf("getAction: unsupported IState version: want: %d, have: %d", version, s->version());

    if(!any.has_value())
        throw std::runtime_error("extractSupplementaryData: supdata is empty");

    auto err = MMAI::Schema::AnyCastError(any, typeid(const MMAI::Schema::V13::ISupplementaryData *));
    if(!err.empty())
        throwf("getAction: anycast failed: %s", err);

    const auto * sup = std::any_cast<const MMAI::Schema::V13::ISupplementaryData *>(any);

    if(sup->getIsBattleEnded())
    {
        timer.name = boost::str(boost::format("MMAI action: %d (battle ended)") % MMAI::Schema::ACTION_RESET);
        return MMAI::Schema::ACTION_RESET;
    }

    auto inputs = prepareInputsV13(s, sup);
    auto outputs = model->Run(Ort::RunOptions(), inputNames.data(), inputs.data(), inputs.size(), outputNames.data(), outputNames.size());

    if(outputs.size() != 6)
        throwf("getAction: bad output size: want: 6, have: %d", outputs.size());

    const auto act0_probs = toVec1D<float>("act0_probs", outputs[0], 4); // WAIT, MOVE, AMOVE, SHOOT
    const auto hex1_probs = toVec2D<float>("hex1_probs", outputs[1], {4, 165});
    const auto hex2_probs = toVec2D<float>("hex2_probs", outputs[2], {165, 165});
    const auto act0_mask = toVec1D<int>("act0_mask", outputs[3], 4); // WAIT, MOVE, AMOVE, SHOOT
    const auto hex1_mask = toVec2D<int>("hex1_mask", outputs[4], {4, 165});
    const auto hex2_mask = toVec2D<int>("hex2_mask", outputs[5], {165, 165});
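
    // Sample hierarchically: the main action first, then hex#1 conditioned on the
    // sampled action, then hex#2 conditioned on the sampled hex#1.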
    const auto [act0_sample, act0_greedy] = categorical(act0_probs, temperature, rng);
    const auto [hex1_sample, hex1_greedy] = categorical(hex1_probs.at(act0_sample.index), temperature, rng);
    const auto [hex2_sample, hex2_greedy] = categorical(hex2_probs.at(hex1_sample.index), temperature, rng);

    if(act0_sample.prob == 0)
        throwf("getAction: act0_sample has 0 probability");
    else if(act0_mask.at(act0_sample.index) == 0)
        throwf("getAction: act0_sample is masked out");

    // Hex1 is always needed if act0 != 0 (WAIT)
    if(act0_sample.index > 0)
    {
        if(hex1_sample.prob == 0)
            throwf("getAction: hex1_sample has 0 probability");
        else if(hex1_mask.at(act0_sample.index).at(hex1_sample.index) == 0)
            throwf("getAction: hex1_sample is masked out");
    }

    // Hex2 is only needed if act0 == 2 (AMOVE)
    if(act0_sample.index == 2)
    {
        if(hex2_sample.prob == 0)
            throwf("getAction: hex2_sample has 0 probability");
        else if(hex2_mask.at(hex1_sample.index).at(hex2_sample.index) == 0)
            throwf("getAction: hex2_sample is masked out");
    }

    const auto & saction = actionTable.at(act0_sample.index).at(hex1_sample.index).at(hex2_sample.index);
    const auto & gaction = actionTable.at(act0_greedy.index).at(hex1_greedy.index).at(hex2_greedy.index);
    const auto & mask = s->getActionMask();

    if(!mask->at(saction))
        throwf("getAction: sampled action is masked"); // Incorrect mask?

    auto sconf = act0_sample.confidence * hex1_sample.confidence * hex2_sample.confidence;
    auto sprob = act0_sample.prob * hex1_sample.prob * hex2_sample.prob;
    auto gconf = act0_greedy.confidence * hex1_greedy.confidence * hex2_greedy.confidence;
    auto gprob = act0_greedy.prob * hex1_greedy.prob * hex2_greedy.prob;

    auto fmt = boost::format("%s: %d (prob=%.2f conf=%.2f). Detail: [%d %d %d] (prob=[%.2f %.2f %.2f] conf=[%.2f %.2f %.2f])");

    logAi->debug(
        boost::str(
            fmt % "MMAI (greedy)" % gaction % gprob % gconf % act0_greedy.index % hex1_greedy.index % hex2_greedy.index % act0_greedy.prob % hex1_greedy.prob
                % hex2_greedy.prob % act0_greedy.confidence % hex1_greedy.confidence % hex2_greedy.confidence
        )
    );

    logAi->debug(
        boost::str(
            fmt % "MMAI (sample)" % saction % sprob % sconf % act0_sample.index % hex1_sample.index % hex2_sample.index % act0_sample.prob % hex1_sample.prob
                % hex2_sample.prob % act0_sample.confidence % hex1_sample.confidence % hex2_sample.confidence
        )
    );

    timer.name = boost::str(boost::format("MMAI action: %d (confidence=%.2f)") % saction % sconf);
    return saction;
}
double NNModelStochastic::getValue(const MMAI::Schema::IState * s)
{
    // Quantifies how good the current state is as perceived by the model
    // (not used, not implemented)
    return 0;
}
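
// Builds the four model inputs (see readInputNames): the flat battlefield state,
// a [2, E] edge-index tensor (all link types concatenated: sources in row 0,
// destinations in row 1), per-edge attributes of shape [E, 1] and the per-link-type
// edge counts ("lengths").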
std::vector<Ort::Value> NNModelStochastic::prepareInputsV13(const MMAI::Schema::IState * s, const MMAI::Schema::V13::ISupplementaryData * sup)
{
    auto lengths = std::vector<int>{};
    lengths.reserve(LT_COUNT);

    auto ei_flat_src = std::vector<int>{};
    auto ei_flat_dst = std::vector<int>{};
    auto ea_flat = std::vector<float>{};

    std::ostringstream oss;
    int i = 0;

    for(const auto & [type, links] : sup->getAllLinks())
    {
        // assert order
        if(EI(type) != i)
            throwf("unexpected link type: want: %d, have: %d", i, EI(type));

        const auto & srcinds = links->getSrcIndex();
        const auto & dstinds = links->getDstIndex();
        const auto & attrs = links->getAttributes();
        const auto nlinks = srcinds.size();

        if(dstinds.size() != nlinks)
            throwf("unexpected dstinds.size() for LinkType(%d): want: %d, have: %d", EI(type), nlinks, dstinds.size());
        if(attrs.size() != nlinks)
            throwf("unexpected attrs.size() for LinkType(%d): want: %d, have: %d", EI(type), nlinks, attrs.size());

        oss << nlinks << " ";
        lengths.push_back(static_cast<int>(nlinks));
        ei_flat_src.insert(ei_flat_src.end(), srcinds.begin(), srcinds.end());
        ei_flat_dst.insert(ei_flat_dst.end(), dstinds.begin(), dstinds.end());
        ea_flat.insert(ea_flat.end(), attrs.begin(), attrs.end());

        ++i;
    }

    if(i != LT_COUNT)
        throwf("unexpected links count: want: %d, have: %d", LT_COUNT, i);

    auto sum_e = ei_flat_src.size();
    auto ei_flat = std::vector<int64_t>{};
    ei_flat.reserve(2 * sum_e);
    ei_flat.insert(ei_flat.end(), ei_flat_src.begin(), ei_flat_src.end());
    ei_flat.insert(ei_flat.end(), ei_flat_dst.begin(), ei_flat_dst.end());

    const auto * state = s->getBattlefieldState();
    auto estate = std::vector<float>(state->size());
    std::ranges::copy(*state, estate.begin());

    auto tensors = std::vector<Ort::Value>{};
    tensors.push_back(toTensor("obs", estate, {static_cast<int64_t>(estate.size())}));
    tensors.push_back(toTensor("ei_flat", ei_flat, {2, static_cast<int64_t>(sum_e)}));
    tensors.push_back(toTensor("ea_flat", ea_flat, {static_cast<int64_t>(sum_e), 1}));
    tensors.push_back(toTensor("lengths", lengths, {LT_COUNT}));

    logAi->debug("NNModel: Edge lengths: [ " + oss.str() + "]");
    logAi->debug("NNModel: Input shapes: state={%d} edgeIndex={2, %d} edgeAttrs={%d, 1}", estate.size(), sum_e, sum_e);

    return tensors;
}
template<typename T>
Ort::Value NNModelStochastic::toTensor(const std::string & name, std::vector<T> & vec, const std::vector<int64_t> & shape)
{
    // Sanity check
    int64_t numel = 1;
    for(int64_t d : shape)
        numel *= d;

    if(numel != vec.size())
        throwf("toTensor: %s: numel check failed: want: %d, have: %d", name, numel, vec.size());

    // Create a memory-owning tensor then copy data
    auto res = Ort::Value::CreateTensor<T>(allocator, shape.data(), shape.size());
    T * dst = res.template GetTensorMutableData<T>();
    std::memcpy(dst, vec.data(), vec.size() * sizeof(T));

    return res;
}

} // namespace MMAI::BAI