NNModel.h 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. /*
  2. * NNModel.h, part of VCMI engine
  3. *
  4. * Authors: listed in file AUTHORS in main folder
  5. *
  6. * License: GNU General Public License v2.0 or later
  7. * Full text of license available in license.txt file, in main folder
  8. *
  9. */
  10. #pragma once
#include <cstdint>
#include <memory>
#include <random>
#include <string>
#include <vector>

#include <onnxruntime_cxx_api.h>

#include "BAI/model/util/common.h"
#include "schema/base.h"
#include "schema/v13/types.h"
  15. namespace MMAI::BAI
  16. {
  17. class NNModel : public MMAI::Schema::IModel
  18. {
  19. public:
  20. explicit NNModel(const std::string & path, float _temperature, uint64_t seed);
  21. Schema::ModelType getType() override;
  22. std::string getName() override;
  23. int getVersion() override;
  24. Schema::Side getSide() override;
  25. int getAction(const MMAI::Schema::IState * s) override;
  26. double getValue(const MMAI::Schema::IState * s) override;
  27. private:
  28. std::string path;
  29. float temperature;
  30. std::string name;
  31. int version;
  32. Schema::Side side;
  33. std::mt19937 rng;
  34. Vec3D<int32_t> actionTable;
  35. // AllocatedStringPtrs manage the string lifetime
  36. // but names passed to model.Run must be const char*
  37. std::vector<Ort::AllocatedStringPtr> inputNamePtrs;
  38. std::vector<Ort::AllocatedStringPtr> outputNamePtrs;
  39. Vec3D<int32_t> bucketSizes;
  40. bool isDynamic;
  41. std::vector<const char *> inputNames;
  42. std::vector<const char *> outputNames;
  43. std::unique_ptr<Ort::Session> model = nullptr;
  44. Ort::AllocatorWithDefaultOptions allocator;
  45. Ort::MemoryInfo meminfo;
  46. std::vector<Ort::Value> prepareInputsV13(const MMAI::Schema::IState * state, const MMAI::Schema::V13::ISupplementaryData * sup);
  47. template<typename T>
  48. Ort::Value toTensor(const std::string & name, std::vector<T> & vec, const std::vector<int64_t> & shape);
  49. std::unique_ptr<Ort::Session> loadModel(const std::string & path, const Ort::SessionOptions & opts);
  50. int readVersion(const Ort::ModelMetadata & md) const;
  51. Schema::Side readSide(const Ort::ModelMetadata & md) const;
  52. Vec3D<int32_t> readBucketSizes(const Ort::ModelMetadata & md) const;
  53. Vec3D<int32_t> readActionTable(const Ort::ModelMetadata & md) const;
  54. bool readIsDynamic(const Ort::ModelMetadata & md) const;
  55. std::vector<const char *> readInputNames(int want);
  56. std::vector<const char *> readOutputNames();
  57. };
  58. }