LogicalExpression.h 18 KB


  1. /*
  2. * LogicalExpression.h, part of VCMI engine
  3. *
  4. * Authors: listed in file AUTHORS in main folder
  5. *
  6. * License: GNU General Public License v2.0 or later
  7. * Full text of license available in license.txt file, in main folder
  8. *
  9. */
  10. #pragma once
  11. //FIXME: move some of code into .cpp to avoid this include?
  12. #include "JsonNode.h"
  13. VCMI_LIB_NAMESPACE_BEGIN
  14. namespace LogicalExpressionDetail
  15. {
  /// Defines the types shared by all logical-expression helpers: the three
  /// operator node types, the leaf value type and the variant able to hold any of them.
  template<typename ContainedClass>
  class ExpressionBase
  {
  public:
      /// Possible logical operations; used as tags so that every operator
      /// becomes a distinct type inside std::variant
      enum EOperations
      {
          ANY_OF,
          ALL_OF,
          NONE_OF
      };

      template<EOperations tag> class Element;

      using OperatorAny = Element<ANY_OF>;
      using OperatorAll = Element<ALL_OF>;
      using OperatorNone = Element<NONE_OF>;
      using Value = ContainedClass;

      /// Variant that contains all possible elements of a logical expression.
      /// OperatorAll is the first alternative, so a default-constructed Variant
      /// holds an "all of" node without operands.
      using Variant = std::variant<OperatorAll, OperatorAny, OperatorNone, Value>;

      /// Variant element, contains list of expressions to which operation "tag" should be applied
      template<EOperations tag>
      class Element
      {
      public:
          /// Creates an operator without operands
          Element() {}

          /// @param expressions operands of this operator; the argument is taken
          /// by value and moved into the member, so no second copy is made
          Element(std::vector<Variant> expressions):
              expressions(std::move(expressions))
          {}

          /// Operands of this operator, in the order they were supplied
          std::vector<Variant> expressions;

          /// Two operators of the same kind are equal when their operand lists are equal
          bool operator == (const Element & other) const
          {
              return expressions == other.expressions;
          }

          /// Passes the operand list to the handler; "version" is not used here
          template <typename Handler>
          void serialize(Handler & h, const int version)
          {
              h & expressions;
          }
      };
  };
  56. /// Visitor to test result (true/false) of the expression
  57. template<typename ContainedClass>
  58. class TestVisitor
  59. {
  60. using Base = ExpressionBase<ContainedClass>;
  61. std::function<bool(const typename Base::Value &)> classTest;
  62. size_t countPassed(const std::vector<typename Base::Variant> & element) const
  63. {
  64. return boost::range::count_if(element, [&](const typename Base::Variant & expr)
  65. {
  66. return std::visit(*this, expr);
  67. });
  68. }
  69. public:
  70. TestVisitor(std::function<bool (const typename Base::Value &)> classTest):
  71. classTest(classTest)
  72. {}
  73. bool operator()(const typename Base::OperatorAny & element) const
  74. {
  75. return countPassed(element.expressions) != 0;
  76. }
  77. bool operator()(const typename Base::OperatorAll & element) const
  78. {
  79. return countPassed(element.expressions) == element.expressions.size();
  80. }
  81. bool operator()(const typename Base::OperatorNone & element) const
  82. {
  83. return countPassed(element.expressions) == 0;
  84. }
  85. bool operator()(const typename Base::Value & element) const
  86. {
  87. return classTest(element);
  88. }
  89. };
  90. template <typename ContainedClass>
  91. class SatisfiabilityVisitor;
  92. template <typename ContainedClass>
  93. class FalsifiabilityVisitor;
  94. template<typename ContainedClass>
  95. class PossibilityVisitor
  96. {
  97. using Base = ExpressionBase<ContainedClass>;
  98. protected:
  99. std::function<bool(const typename Base::Value &)> satisfiabilityTest;
  100. std::function<bool(const typename Base::Value &)> falsifiabilityTest;
  101. SatisfiabilityVisitor<ContainedClass> *satisfiabilityVisitor;
  102. FalsifiabilityVisitor<ContainedClass> *falsifiabilityVisitor;
  103. size_t countSatisfiable(const std::vector<typename Base::Variant> & element) const
  104. {
  105. return boost::range::count_if(element, [&](const typename Base::Variant & expr)
  106. {
  107. return std::visit(*satisfiabilityVisitor, expr);
  108. });
  109. }
  110. size_t countFalsifiable(const std::vector<typename Base::Variant> & element) const
  111. {
  112. return boost::range::count_if(element, [&](const typename Base::Variant & expr)
  113. {
  114. return std::visit(*falsifiabilityVisitor, expr);
  115. });
  116. }
  117. public:
  118. PossibilityVisitor(std::function<bool (const typename Base::Value &)> satisfiabilityTest,
  119. std::function<bool (const typename Base::Value &)> falsifiabilityTest):
  120. satisfiabilityTest(satisfiabilityTest),
  121. falsifiabilityTest(falsifiabilityTest),
  122. satisfiabilityVisitor(nullptr),
  123. falsifiabilityVisitor(nullptr)
  124. {}
  125. void setSatisfiabilityVisitor(SatisfiabilityVisitor<ContainedClass> *satisfiabilityVisitor)
  126. {
  127. this->satisfiabilityVisitor = satisfiabilityVisitor;
  128. }
  129. void setFalsifiabilityVisitor(FalsifiabilityVisitor<ContainedClass> *falsifiabilityVisitor)
  130. {
  131. this->falsifiabilityVisitor = falsifiabilityVisitor;
  132. }
  133. };
  134. /// Visitor to test whether expression's value can be true
  135. template <typename ContainedClass>
  136. class SatisfiabilityVisitor : public PossibilityVisitor<ContainedClass>
  137. {
  138. using Base = ExpressionBase<ContainedClass>;
  139. public:
  140. SatisfiabilityVisitor(std::function<bool (const typename Base::Value &)> satisfiabilityTest,
  141. std::function<bool (const typename Base::Value &)> falsifiabilityTest):
  142. PossibilityVisitor<ContainedClass>(satisfiabilityTest, falsifiabilityTest)
  143. {
  144. this->setSatisfiabilityVisitor(this);
  145. }
  146. bool operator()(const typename Base::OperatorAny & element) const
  147. {
  148. return this->countSatisfiable(element.expressions) != 0;
  149. }
  150. bool operator()(const typename Base::OperatorAll & element) const
  151. {
  152. return this->countSatisfiable(element.expressions) == element.expressions.size();
  153. }
  154. bool operator()(const typename Base::OperatorNone & element) const
  155. {
  156. return this->countFalsifiable(element.expressions) == element.expressions.size();
  157. }
  158. bool operator()(const typename Base::Value & element) const
  159. {
  160. return this->satisfiabilityTest(element);
  161. }
  162. };
  163. /// Visitor to test whether expression's value can be false
  164. template <typename ContainedClass>
  165. class FalsifiabilityVisitor : public PossibilityVisitor<ContainedClass>
  166. {
  167. using Base = ExpressionBase<ContainedClass>;
  168. public:
  169. FalsifiabilityVisitor(std::function<bool (const typename Base::Value &)> satisfiabilityTest,
  170. std::function<bool (const typename Base::Value &)> falsifiabilityTest):
  171. PossibilityVisitor<ContainedClass>(satisfiabilityTest, falsifiabilityTest)
  172. {
  173. this->setFalsifiabilityVisitor(this);
  174. }
  175. bool operator()(const typename Base::OperatorAny & element) const
  176. {
  177. return this->countFalsifiable(element.expressions) == element.expressions.size();
  178. }
  179. bool operator()(const typename Base::OperatorAll & element) const
  180. {
  181. return this->countFalsifiable(element.expressions) != 0;
  182. }
  183. bool operator()(const typename Base::OperatorNone & element) const
  184. {
  185. return this->countSatisfiable(element.expressions) != 0;
  186. }
  187. bool operator()(const typename Base::Value & element) const
  188. {
  189. return this->falsifiabilityTest(element);
  190. }
  191. };
  192. /// visitor that is trying to generates candidates that must be fulfilled
  193. /// to complete this expression
  194. template<typename ContainedClass>
  195. class CandidatesVisitor
  196. {
  197. using Base = ExpressionBase<ContainedClass>;
  198. using TValueList = std::vector<typename Base::Value>;
  199. TestVisitor<ContainedClass> classTest;
  200. public:
  201. CandidatesVisitor(std::function<bool(const typename Base::Value &)> classTest):
  202. classTest(classTest)
  203. {}
  204. TValueList operator()(const typename Base::OperatorAny & element) const
  205. {
  206. TValueList ret;
  207. if (!classTest(element))
  208. {
  209. for (auto & elem : element.expressions)
  210. boost::range::copy(std::visit(*this, elem), std::back_inserter(ret));
  211. }
  212. return ret;
  213. }
  214. TValueList operator()(const typename Base::OperatorAll & element) const
  215. {
  216. TValueList ret;
  217. if (!classTest(element))
  218. {
  219. for (auto & elem : element.expressions)
  220. boost::range::copy(std::visit(*this, elem), std::back_inserter(ret));
  221. }
  222. return ret;
  223. }
  224. TValueList operator()(const typename Base::OperatorNone & element) const
  225. {
  226. return TValueList(); //TODO. Implementing this one is not straightforward, if ever possible
  227. }
  228. TValueList operator()(const typename Base::Value & element) const
  229. {
  230. if (classTest(element))
  231. return TValueList();
  232. else
  233. return TValueList(1, element);
  234. }
  235. };
  236. /// Simple foreach visitor
  237. template<typename ContainedClass>
  238. class ForEachVisitor
  239. {
  240. using Base = ExpressionBase<ContainedClass>;
  241. std::function<typename Base::Variant(const typename Base::Value &)> visitor;
  242. public:
  243. ForEachVisitor(std::function<typename Base::Variant(const typename Base::Value &)> visitor):
  244. visitor(visitor)
  245. {}
  246. typename Base::Variant operator()(const typename Base::Value & element) const
  247. {
  248. return visitor(element);
  249. }
  250. template <typename Type>
  251. typename Base::Variant operator()(Type element) const
  252. {
  253. for (auto & entry : element.expressions)
  254. entry = std::visit(*this, entry);
  255. return element;
  256. }
  257. };
  258. /// Minimizing visitor that removes all redundant elements from variant (e.g. AllOf inside another AllOf can be merged safely)
  259. template<typename ContainedClass>
  260. class MinimizingVisitor
  261. {
  262. using Base = ExpressionBase<ContainedClass>;
  263. public:
  264. typename Base::Variant operator()(const typename Base::Value & element) const
  265. {
  266. return element;
  267. }
  268. template <typename Type>
  269. typename Base::Variant operator()(const Type & element) const
  270. {
  271. Type ret;
  272. for (auto & entryRO : element.expressions)
  273. {
  274. auto entry = std::visit(*this, entryRO);
  275. try
  276. {
  277. // copy entries from child of this type
  278. auto sublist = std::get<Type>(entry).expressions;
  279. std::move(sublist.begin(), sublist.end(), std::back_inserter(ret.expressions));
  280. }
  281. catch (std::bad_variant_access &)
  282. {
  283. // different type (e.g. allOf vs oneOf) just copy
  284. ret.expressions.push_back(entry);
  285. }
  286. }
  287. for ( auto it = ret.expressions.begin(); it != ret.expressions.end();)
  288. {
  289. if (std::find(ret.expressions.begin(), it, *it) != it)
  290. it = ret.expressions.erase(it); // erase duplicate
  291. else
  292. it++; // goto next
  293. }
  294. return ret;
  295. }
  296. };
  297. /// Json parser for expressions
  298. template <typename ContainedClass>
  299. class Reader
  300. {
  301. using Base = ExpressionBase<ContainedClass>;
  302. std::function<typename Base::Value(const JsonNode &)> classParser;
  303. typename Base::Variant readExpression(const JsonNode & node)
  304. {
  305. assert(!node.Vector().empty());
  306. std::string type = node.Vector()[0].String();
  307. if (type == "anyOf")
  308. return typename Base::OperatorAny(readVector(node));
  309. if (type == "allOf")
  310. return typename Base::OperatorAll(readVector(node));
  311. if (type == "noneOf")
  312. return typename Base::OperatorNone(readVector(node));
  313. return classParser(node);
  314. }
  315. std::vector<typename Base::Variant> readVector(const JsonNode & node)
  316. {
  317. std::vector<typename Base::Variant> ret;
  318. ret.reserve(node.Vector().size()-1);
  319. for (size_t i=1; i < node.Vector().size(); i++)
  320. ret.push_back(readExpression(node.Vector()[i]));
  321. return ret;
  322. }
  323. public:
  324. Reader(std::function<typename Base::Value(const JsonNode &)> classParser):
  325. classParser(classParser)
  326. {}
  327. typename Base::Variant operator ()(const JsonNode & node)
  328. {
  329. return readExpression(node);
  330. }
  331. };
  332. /// Serializes expression in JSON format. Part of map format.
  333. template<typename ContainedClass>
  334. class Writer
  335. {
  336. using Base = ExpressionBase<ContainedClass>;
  337. std::function<JsonNode(const typename Base::Value &)> classPrinter;
  338. JsonNode printExpressionList(std::string name, const std::vector<typename Base::Variant> & element) const
  339. {
  340. JsonNode ret;
  341. ret.Vector().resize(1);
  342. ret.Vector().back().String() = name;
  343. for (auto & expr : element)
  344. ret.Vector().push_back(std::visit(*this, expr));
  345. return ret;
  346. }
  347. public:
  348. Writer(std::function<JsonNode(const typename Base::Value &)> classPrinter):
  349. classPrinter(classPrinter)
  350. {}
  351. JsonNode operator()(const typename Base::OperatorAny & element) const
  352. {
  353. return printExpressionList("anyOf", element.expressions);
  354. }
  355. JsonNode operator()(const typename Base::OperatorAll & element) const
  356. {
  357. return printExpressionList("allOf", element.expressions);
  358. }
  359. JsonNode operator()(const typename Base::OperatorNone & element) const
  360. {
  361. return printExpressionList("noneOf", element.expressions);
  362. }
  363. JsonNode operator()(const typename Base::Value & element) const
  364. {
  365. return classPrinter(element);
  366. }
  367. };
  368. std::string DLL_LINKAGE getTextForOperator(const std::string & operation);
  369. /// Prints expression in human-readable format
  370. template<typename ContainedClass>
  371. class Printer
  372. {
  373. using Base = ExpressionBase<ContainedClass>;
  374. std::function<std::string(const typename Base::Value &)> classPrinter;
  375. std::unique_ptr<TestVisitor<ContainedClass>> statusTest;
  376. mutable std::string prefix;
  377. template<typename Operator>
  378. std::string formatString(std::string toFormat, const Operator & expr) const
  379. {
  380. // highlight not fulfilled expressions, if pretty formatting is on
  381. if (statusTest && !(*statusTest)(expr))
  382. return "{" + toFormat + "}";
  383. return toFormat;
  384. }
  385. std::string printExpressionList(const std::vector<typename Base::Variant> & element) const
  386. {
  387. std::string ret;
  388. prefix.push_back('\t');
  389. for (auto & expr : element)
  390. ret += prefix + std::visit(*this, expr) + "\n";
  391. prefix.pop_back();
  392. return ret;
  393. }
  394. public:
  395. Printer(std::function<std::string(const typename Base::Value &)> classPrinter):
  396. classPrinter(classPrinter)
  397. {}
  398. Printer(std::function<std::string(const typename Base::Value &)> classPrinter, std::function<bool(const typename Base::Value &)> toBool):
  399. classPrinter(classPrinter),
  400. statusTest(new TestVisitor<ContainedClass>(toBool))
  401. {}
  402. std::string operator()(const typename Base::OperatorAny & element) const
  403. {
  404. return formatString(getTextForOperator("anyOf"), element) + "\n"
  405. + printExpressionList(element.expressions);
  406. }
  407. std::string operator()(const typename Base::OperatorAll & element) const
  408. {
  409. return formatString(getTextForOperator("allOf"), element) + "\n"
  410. + printExpressionList(element.expressions);
  411. }
  412. std::string operator()(const typename Base::OperatorNone & element) const
  413. {
  414. return formatString(getTextForOperator("noneOf"), element) + "\n"
  415. + printExpressionList(element.expressions);
  416. }
  417. std::string operator()(const typename Base::Value & element) const
  418. {
  419. return formatString(classPrinter(element), element);
  420. }
  421. };
  422. }
  423. ///
  424. /// Class for evaluation of logical expressions generated in runtime
  425. ///
  426. template<typename ContainedClass>
  427. class LogicalExpression
  428. {
  429. using Base = LogicalExpressionDetail::ExpressionBase<ContainedClass>;
  430. public:
  431. /// Type of values used in expressions, same as ContainedClass
  432. using Value = typename Base::Value;
  433. /// Operators for use in expressions, all include vectors with operands
  434. using OperatorAny = typename Base::OperatorAny;
  435. using OperatorAll = typename Base::OperatorAll;
  436. using OperatorNone = typename Base::OperatorNone;
  437. /// one expression entry
  438. using Variant = typename Base::Variant;
  439. private:
  440. Variant data;
  441. public:
  442. /// Base constructor
  443. LogicalExpression() = default;
  444. /// Constructor from variant or (implicitly) from Operator* types
  445. LogicalExpression(const Variant & data): data(data) {}
  446. /// Constructor that receives JsonNode as input and function that can parse Value instances
  447. LogicalExpression(const JsonNode & input, std::function<Value(const JsonNode &)> parser)
  448. {
  449. LogicalExpressionDetail::Reader<Value> reader(parser);
  450. LogicalExpression expr(reader(input));
  451. std::swap(data, expr.data);
  452. }
  453. Variant get() const
  454. {
  455. return data;
  456. }
  457. /// Simple visitor that visits all entries in expression
  458. Variant morph(std::function<Variant(const Value &)> morpher) const
  459. {
  460. LogicalExpressionDetail::ForEachVisitor<Value> visitor(morpher);
  461. return std::visit(visitor, data);
  462. }
  463. /// Minimizes expression, removing any redundant elements
  464. void minimize()
  465. {
  466. LogicalExpressionDetail::MinimizingVisitor<Value> visitor;
  467. data = std::visit(visitor, data);
  468. }
  469. /// calculates if expression evaluates to "true".
  470. /// Note: empty expressions always return true
  471. bool test(std::function<bool(const Value &)> toBool) const
  472. {
  473. LogicalExpressionDetail::TestVisitor<Value> testVisitor(toBool);
  474. return std::visit(testVisitor, data);
  475. }
  476. /// calculates if expression can evaluate to "true".
  477. bool satisfiable(std::function<bool(const Value &)> satisfiabilityTest, std::function<bool(const Value &)> falsifiabilityTest) const
  478. {
  479. LogicalExpressionDetail::SatisfiabilityVisitor<Value> satisfiabilityVisitor(satisfiabilityTest, falsifiabilityTest);
  480. LogicalExpressionDetail::FalsifiabilityVisitor<Value> falsifiabilityVisitor(satisfiabilityTest, falsifiabilityTest);
  481. satisfiabilityVisitor.setFalsifiabilityVisitor(&falsifiabilityVisitor);
  482. falsifiabilityVisitor.setSatisfiabilityVisitor(&satisfiabilityVisitor);
  483. return std::visit(satisfiabilityVisitor, data);
  484. }
  485. /// calculates if expression can evaluate to "false".
  486. bool falsifiable(std::function<bool(const Value &)> satisfiabilityTest, std::function<bool(const Value &)> falsifiabilityTest) const
  487. {
  488. LogicalExpressionDetail::SatisfiabilityVisitor<Value> satisfiabilityVisitor(satisfiabilityTest);
  489. LogicalExpressionDetail::FalsifiabilityVisitor<Value> falsifiabilityVisitor(falsifiabilityTest);
  490. satisfiabilityVisitor.setFalsifiabilityVisitor(&falsifiabilityVisitor);
  491. falsifiabilityVisitor.setFalsifiabilityVisitor(&satisfiabilityVisitor);
  492. return std::visit(falsifiabilityVisitor, data);
  493. }
  494. /// generates list of candidates that can be fulfilled by caller (like AI)
  495. std::vector<Value> getFulfillmentCandidates(std::function<bool(const Value &)> toBool) const
  496. {
  497. LogicalExpressionDetail::CandidatesVisitor<Value> candidateVisitor(toBool);
  498. return std::visit(candidateVisitor, data);
  499. }
  500. /// Converts expression in human-readable form
  501. /// Second version will try to do some pretty printing using H3 text formatting "{}"
  502. /// to indicate fulfilled components of an expression
  503. std::string toString(std::function<std::string(const Value &)> toStr) const
  504. {
  505. LogicalExpressionDetail::Printer<Value> printVisitor(toStr);
  506. return std::visit(printVisitor, data);
  507. }
  508. std::string toString(std::function<std::string(const Value &)> toStr, std::function<bool(const Value &)> toBool) const
  509. {
  510. LogicalExpressionDetail::Printer<Value> printVisitor(toStr, toBool);
  511. return std::visit(printVisitor, data);
  512. }
  513. JsonNode toJson(std::function<JsonNode(const Value &)> toJson) const
  514. {
  515. LogicalExpressionDetail::Writer<Value> writeVisitor(toJson);
  516. return std::visit(writeVisitor, data);
  517. }
  518. template <typename Handler>
  519. void serialize(Handler & h, const int version)
  520. {
  521. h & data;
  522. }
  523. };
  524. VCMI_LIB_NAMESPACE_END