sampling.h 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. /*
  2. * sampling.h, part of VCMI engine
  3. *
  4. * Authors: listed in file AUTHORS in main folder
  5. *
  6. * License: GNU General Public License v2.0 or later
  7. * Full text of license available in license.txt file, in main folder
  8. *
  9. */
  10. #pragma once
  11. #include <onnxruntime_cxx_api.h>
  12. #include "BAI/model/util/common.h"
  13. namespace MMAI::BAI::sampling
  14. {
  // Result of drawing a single index from a (masked) categorical distribution.
  //
  // Every member carries a default initializer so that a plain `SampleResult r;`
  // never exposes indeterminate values. The defaults equal what value
  // initialization (`SampleResult{}`) already produced, and the type remains an
  // aggregate, so `SampleResult{index, prob, fallback}` keeps working.
  struct SampleResult
  {
  	// Presumably the position of the chosen entry within the sampled vector
  	// -- confirm against the implementation.
  	int index = 0;

  	// Presumably the probability assigned to `index` -- confirm.
  	double prob = 0.0;

  	// NOTE(review): looks like a flag for "a fallback path produced this
  	// result"; the exact condition is not visible in this header -- confirm.
  	bool fallback = false;
  };
  // One sampled {act0, hex1, hex2} triplet together with a confidence value.
  //
  // Members are zero-initialized by default so that a plain `TripletSample t;`
  // never exposes indeterminate values. The type remains an aggregate, so
  // `TripletSample{act0, hex1, hex2, confidence}` keeps working.
  struct TripletSample
  {
  	// Presumably the index of the sampled primary action -- confirm.
  	int act0 = 0;

  	// Presumably the index of the first sampled hex -- confirm.
  	int hex1 = 0;

  	// Presumably the index of the second sampled hex -- confirm.
  	int hex2 = 0;

  	// NOTE(review): how this value is derived (e.g. from the individual
  	// sample probabilities) is not visible in this header -- confirm.
  	double confidence = 0.0;
  };
  28. std::vector<int64_t> shape_of(const Ort::Value & v);
  29. template<typename T>
  30. std::vector<T> to_vector(const Ort::Value & v);
  // Maps a vector of logits to a vector of doubles.
  // NOTE(review): presumably the softmax probabilities of `logits`; handling of
  // empty input and numerical stabilisation are decided by the definition,
  // which is not visible in this header -- confirm.
  std::vector<double> softmax(const std::vector<double> & logits);
  // Reduces a vector of doubles to a single int.
  // NOTE(review): presumably the index of the largest element of `xs`; the
  // result for an empty vector and the tie-breaking rule are not visible in
  // this header -- confirm.
  int argmax(const std::vector<double> & xs);
  // Reduces a one-dimensional int32 mask to a single count.
  // NOTE(review): presumably the number of entries marked as valid (non-zero);
  // the exact validity criterion is not visible in this header -- confirm.
  int count_valid(const std::vector<int32_t> & mask_1d);
  // Combines float logits with an int32 mask into a vector of doubles.
  // NOTE(review): presumably masked-out positions receive a value that removes
  // them from later sampling, and both inputs are expected to have the same
  // length; neither is visible in this header -- confirm.
  std::vector<double> make_masked_logits(const std::vector<float> & logits_1d, const std::vector<int32_t> & mask_1d);
  35. SampleResult sample_uniform_over_mask(const std::vector<int32_t> & mask_1d, int n_valid, std::mt19937 & rng);
  36. SampleResult sample_softmax_over_mask(const std::vector<double> & masked_logits, const std::vector<int32_t> & mask_1d, double temperature, std::mt19937 & rng);
  37. // Masked categorical sampling given a logits vector
  38. SampleResult
  39. sample_masked_logits(const std::vector<float> & logits_1d, const std::vector<int32_t> & mask_1d, bool throw_if_empty, double temperature, std::mt19937 & rng);
  40. //
  41. // Samples a {action, hex1, hex2} triplet given output logits and masks
  42. //
  43. // Expected shapes:
  44. // act0_logits: [1, 4] float32
  45. // hex1_logits: [1, 165] float32
  46. // hex2_logits: [1, 165] float32
  47. // mask_act0: [1, 4] int32
  48. // mask_hex1: [1, 4, 165] int32
  49. // mask_hex2: [1, 4, 165, 165] int32
  50. //
  51. TripletSample
  52. sample_triplet(const MaskedLogits & act0_logits, const MaskedLogits & hex1_logits, const MaskedLogits & hex2_logits, double temperature, std::mt19937 & rng);
  53. }