NNModelStochastic.h 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. /*
  2. * NNModelStochastic.h, part of VCMI engine
  3. *
  4. * Authors: listed in file AUTHORS in main folder
  5. *
  6. * License: GNU General Public License v2.0 or later
  7. * Full text of license available in license.txt file, in main folder
  8. *
  9. */
  10. #pragma once
  11. #include <onnxruntime_cxx_api.h>
  12. #include "BAI/model/util/common.h"
  13. #include "schema/base.h"
  14. #include "schema/v13/types.h"
  15. namespace MMAI::BAI
  16. {
  /// Two-level nested std::vector: a vector whose elements are vectors of Elem.
  template<typename Elem>
  using Vec2D = std::vector<std::vector<Elem>>;
  /// Three-level nested std::vector: a vector of vectors of vectors of Elem.
  template<typename Elem>
  using Vec3D = std::vector<std::vector<std::vector<Elem>>>;
  21. class NNModelStochastic : public MMAI::Schema::IModel
  22. {
  23. public:
  24. explicit NNModelStochastic(const std::string & path, float _temperature, uint64_t seed);
  25. Schema::ModelType getType() override;
  26. std::string getName() override;
  27. int getVersion() override;
  28. Schema::Side getSide() override;
  29. int getAction(const MMAI::Schema::IState * s) override;
  30. double getValue(const MMAI::Schema::IState * s) override;
  31. private:
  32. std::string path;
  33. float temperature;
  34. std::string name;
  35. int version;
  36. Schema::Side side;
  37. std::mt19937 rng;
  38. Vec3D<int32_t> actionTable;
  39. // AllocatedStringPtrs manage the string lifetime
  40. // but names passed to model.Run must be const char*
  41. std::vector<Ort::AllocatedStringPtr> inputNamePtrs;
  42. std::vector<Ort::AllocatedStringPtr> outputNamePtrs;
  43. std::vector<const char *> inputNames;
  44. std::vector<const char *> outputNames;
  45. std::unique_ptr<Ort::Session> model = nullptr;
  46. Ort::AllocatorWithDefaultOptions allocator;
  47. Ort::MemoryInfo meminfo;
  48. std::vector<Ort::Value> prepareInputsV13(const MMAI::Schema::IState * state, const MMAI::Schema::V13::ISupplementaryData * sup);
  49. template<typename T>
  50. Ort::Value toTensor(const std::string & name, std::vector<T> & vec, const std::vector<int64_t> & shape);
  51. std::unique_ptr<Ort::Session> loadModel(const std::string & path, const Ort::SessionOptions & opts);
  52. int readVersion(const Ort::ModelMetadata & md) const;
  53. Schema::Side readSide(const Ort::ModelMetadata & md) const;
  54. Vec3D<int32_t> readActionTable(const Ort::ModelMetadata & md) const;
  55. std::vector<const char *> readInputNames();
  56. std::vector<const char *> readOutputNames();
  57. };
  58. }