main.cpp 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753
  1. //#include "../global.h"
  2. #include "StdInc.h"
  3. #include "../lib/VCMI_Lib.h"
  4. namespace po = boost::program_options;
  5. namespace fs = boost::filesystem;
  6. using namespace std;
  7. #include <boost/random.hpp>
  8. //FANN
  9. #include <doublefann.h>
  10. #include <fann_cpp.h>
  11. std::string leftAI, rightAI, battle, results, logsDir;
  12. bool withVisualization = false;
  13. std::string servername;
  14. std::string runnername;
  15. extern DLL_EXPORT LibClasses * VLC;
  16. typedef std::map<int, CArtifactInstance*> TArtSet;
  17. namespace Utilities
  18. {
  19. std::string addQuotesIfNeeded(const std::string &s)
  20. {
  21. if(s.find_first_of(' ') != std::string::npos)
  22. return "\"" + s + "\"";
  23. return s;
  24. }
  25. void prog_help()
  26. {
  27. std::cout << "If run without args, then StupidAI will be run on b1.json.\n";
  28. }
  29. string toString(int i)
  30. {
  31. return boost::lexical_cast<string>(i);
  32. }
  33. string describeBonus(const Bonus &b)
  34. {
  35. return "+" + toString(b.val) + "_to_" + bonusTypeToString(b.type)+"_sub"+toString(b.subtype);
  36. }
  37. }
  38. using namespace Utilities;
  39. struct Example
  40. {
  41. //ANN input
  42. DuelParameters dp;
  43. CArtifactInstance *art;
  44. //ANN expected output
  45. double value;
  46. //other
  47. std::string description;
  48. int i, j, k; //helper values for identification
  49. Example(){}
  50. Example(const DuelParameters &Dp, CArtifactInstance *Art, double Value)
  51. : dp(Dp), art(Art), value(Value)
  52. {}
  53. inline bool operator<(const Example & rhs) const
  54. {
  55. if (k<rhs.k)
  56. return true;
  57. if (k>rhs.k)
  58. return false;
  59. if (j<rhs.j)
  60. return true;
  61. if (j>rhs.j)
  62. return false;
  63. if (i<rhs.i)
  64. return true;
  65. if (i>rhs.i)
  66. return false;
  67. return false;
  68. }
  69. bool operator==(const Example &rhs) const
  70. {
  71. return rhs.i == i && rhs.j == j && rhs.k == k;
  72. }
  73. template <typename Handler> void serialize(Handler &h, const int version)
  74. {
  75. h & dp & art & value & description & i & j & k;
  76. }
  77. };
  78. class Framework
  79. {
  80. static CArtifactInstance *generateArtWithBonus(const Bonus &b);
  81. static DuelParameters generateDuel(const ArmyDescriptor &ad); //generates simple duel where both sides have given army
  82. static void runCommand(const std::string &command, const std::string &name, const std::string &logsDir = "");
  83. static double playBattle(const DuelParameters &dp);
  84. static double cmpArtSets(DuelParameters dp, TArtSet setL, TArtSet setR);
  85. static double rateArt(const DuelParameters dp, CArtifactInstance * inst); //rates given artifact
  86. static int theLastN();
  87. static vector<string> getFileNames(const string &dirname = "./examples/", const std::string &ext = "example");
  88. static vector<ArmyDescriptor> learningArmies();
  89. static vector<Bonus> learningBonuses();
  90. public:
  91. Framework();
  92. ~Framework();
  93. static void buildLearningSet();
  94. static vector<Example> loadExamples(bool printInfo = true);
  95. };
  96. vector<string> Framework::getFileNames(const string &dirname, const std::string &ext)
  97. {
  98. vector<string> ret;
  99. if(!fs::exists(dirname))
  100. {
  101. tlog1 << "Cannot find " << dirname << " directory! Will attempt creating it.\n";
  102. fs::create_directory(dirname);
  103. }
  104. fs::path tie(dirname);
  105. fs::directory_iterator end_iter;
  106. for ( fs::directory_iterator file (tie); file!=end_iter; ++file )
  107. {
  108. if(fs::is_regular_file(file->status())
  109. && boost::ends_with(file->path().filename(), ext))
  110. {
  111. ret.push_back(file->path().string());
  112. }
  113. }
  114. return ret;
  115. }
  116. vector<Example> Framework::loadExamples(bool printInfo)
  117. {
  118. std::vector<Example> examples;
  119. BOOST_FOREACH(auto fname, getFileNames("./examples/", "example"))
  120. {
  121. CLoadFile loadf(fname);
  122. Example ex;
  123. loadf >> ex;
  124. examples.push_back(ex);
  125. }
  126. if(printInfo)
  127. {
  128. tlog0 << "Found " << examples.size() << " examples.\n";
  129. BOOST_FOREACH(auto &ex, examples)
  130. {
  131. tlog0 << format("Battle on army %d for bonus %d of value %d has resultdiff %lf\n") % ex.i % ex.j % ex.k % ex.value;
  132. }
  133. }
  134. return examples;
  135. }
  136. int Framework::theLastN()
  137. {
  138. auto fnames = getFileNames();
  139. if(!fnames.size())
  140. return -1;
  141. range::sort(fnames, [](const std::string &a, const std::string &b)
  142. {
  143. return boost::lexical_cast<int>(fs::basename(a)) < boost::lexical_cast<int>(fs::basename(b));
  144. });
  145. return boost::lexical_cast<int>(fs::basename(fnames.back()));
  146. }
  147. void Framework::buildLearningSet()
  148. {
  149. vector<Example> examples = loadExamples();
  150. range::sort(examples);
  151. int startExamplesFrom = 0;
  152. ofstream learningLog("log.txt", std::ios::app);
  153. int n = theLastN()+1;
  154. auto armies = learningArmies();
  155. auto bonuese = learningBonuses();
  156. for(int i = 0; i < armies.size(); i++)
  157. {
  158. string army = "army" + toString(i);
  159. for(int j = 0; j < bonuese.size(); j++)
  160. {
  161. Bonus b = bonuese[j];
  162. string bonusStr = "bonus" + toString(j) + describeBonus(b);
  163. for(int k = 0; k < 10; k++)
  164. {
  165. int nHere = n++;
  166. // if(nHere < startExamplesFrom)
  167. // continue;
  168. //
  169. tlog2 << "n="<<nHere<<std::endl;
  170. b.val = k;
  171. Example ex;
  172. ex.i = i;
  173. ex.j = j;
  174. ex.k = k;
  175. ex.art = generateArtWithBonus(b);
  176. ex.dp = generateDuel(armies[i]);
  177. ex.description = army + "\t" + describeBonus(b) + "\t";
  178. if(vstd::contains(examples, ex))
  179. {
  180. string msg = str(format("n=%d \tarmy %d \tbonus %d \tresult %lf \t Bonus#%s#") % nHere % i %j % ex.value % describeBonus(b));
  181. tlog0 << "Already present example, skipping " << msg;
  182. continue;
  183. }
  184. ex.value = rateArt(ex.dp, ex.art);
  185. CSaveFile output("./examples/" + toString(nHere) + ".example");
  186. output << ex;
  187. time_t rawtime;
  188. struct tm * timeinfo;
  189. time ( &rawtime );
  190. timeinfo = localtime ( &rawtime );
  191. string msg = str(format("n=%d \tarmy %d \tbonus %d \tresult %lf \t Bonus#%s# \tdate: %s") % nHere % i %j % ex.value % describeBonus(b) % asctime(timeinfo));
  192. learningLog << msg << flush;
  193. tlog0 << msg;
  194. }
  195. }
  196. }
  197. tlog0 << "Set of learning/testing examples is complete and ready!\n";
  198. }
  199. vector<ArmyDescriptor> Framework::learningArmies()
  200. {
  201. vector<ArmyDescriptor> ret;
  202. //armia zlozona ze stworow z malymi HP-kami
  203. ArmyDescriptor lowHP;
  204. lowHP[0] = CStackBasicDescriptor(1, 9); //halabardier
  205. lowHP[1] = CStackBasicDescriptor(14, 20); //centaur
  206. lowHP[2] = CStackBasicDescriptor(139, 123); //chlop
  207. lowHP[3] = CStackBasicDescriptor(70, 30); //troglodyta
  208. lowHP[4] = CStackBasicDescriptor(42, 50); //imp
  209. //armia zlozona z poteznaych stworow
  210. ArmyDescriptor highHP;
  211. highHP[0] = CStackBasicDescriptor(13, 17); //archaniol
  212. highHP[1] = CStackBasicDescriptor(132, 8); //azure dragon
  213. highHP[2] = CStackBasicDescriptor(133, 10); //crystal dragon
  214. highHP[3] = CStackBasicDescriptor(83, 22); //black dragon
  215. //armia zlozona z tygodniowego przyrostu w zamku
  216. auto &castleTown = VLC->townh->towns[0];
  217. ArmyDescriptor castleNormal;
  218. for(int i = 0; i < 7; i++)
  219. {
  220. auto &cre = VLC->creh->creatures[castleTown.basicCreatures[i]];
  221. castleNormal[i] = CStackBasicDescriptor(cre.get(), cre->growth);
  222. }
  223. castleNormal[5].type = VLC->creh->creatures[52]; //replace cavaliers with Efreeti -> stupid ai sometimes blocks with two-hex walkers
  224. //armia zlozona z tygodniowego ulepszonego przyrostu w ramparcie
  225. auto &rampartTown = VLC->townh->towns[1];
  226. ArmyDescriptor rampartUpgraded;
  227. for(int i = 0; i < 7; i++)
  228. {
  229. auto &cre = VLC->creh->creatures[rampartTown.upgradedCreatures[i]];
  230. rampartUpgraded[i] = CStackBasicDescriptor(cre.get(), cre->growth);
  231. }
  232. rampartUpgraded[5].type = VLC->creh->creatures[52]; //replace unicorn with Efreeti -> stupid ai sometimes blocks with two-hex walkers
  233. //armia zlozona z samych strzelcow
  234. ArmyDescriptor shooters;
  235. shooters[0] = CStackBasicDescriptor(35, 17); //arcymag
  236. shooters[1] = CStackBasicDescriptor(41, 1); //titan
  237. shooters[2] = CStackBasicDescriptor(3, 70); //kusznik
  238. shooters[3] = CStackBasicDescriptor(89, 50); //ulepszony ork
  239. ret.push_back(lowHP);
  240. ret.push_back(highHP);
  241. ret.push_back(castleNormal);
  242. ret.push_back(rampartUpgraded);
  243. ret.push_back(shooters);
  244. return ret;
  245. }
  246. vector<Bonus> Framework::learningBonuses()
  247. {
  248. vector<Bonus> ret;
  249. Bonus b;
  250. b.type = Bonus::PRIMARY_SKILL;
  251. b.subtype = PrimarySkill::ATTACK;
  252. ret.push_back(b);
  253. b.subtype = PrimarySkill::DEFENSE;
  254. ret.push_back(b);
  255. b.type = Bonus::STACK_HEALTH;
  256. b.subtype = 0;
  257. ret.push_back(b);
  258. b.type = Bonus::STACKS_SPEED;
  259. ret.push_back(b);
  260. b.type = Bonus::BLOCKS_RETALIATION;
  261. ret.push_back(b);
  262. b.type = Bonus::ADDITIONAL_RETALIATION;
  263. ret.push_back(b);
  264. b.type = Bonus::ADDITIONAL_ATTACK;
  265. ret.push_back(b);
  266. b.type = Bonus::CREATURE_DAMAGE;
  267. ret.push_back(b);
  268. b.type = Bonus::ALWAYS_MAXIMUM_DAMAGE;
  269. ret.push_back(b);
  270. b.type = Bonus::NO_DISTANCE_PENALTY;
  271. ret.push_back(b);
  272. return ret;
  273. }
  274. double Framework::rateArt(const DuelParameters dp, CArtifactInstance * inst)
  275. {
  276. TArtSet setL, setR;
  277. setL[inst->artType->possibleSlots[0]] = inst;
  278. double resultLR = cmpArtSets(dp, setL, setR),
  279. resultRL = cmpArtSets(dp, setR, setL),
  280. resultsBase = cmpArtSets(dp, TArtSet(), TArtSet());
  281. //lewa strona z art 0.9
  282. //bez artefaktow -0.41
  283. //prawa strona z art. -0.926
  284. double LRgain = resultLR - resultsBase,
  285. RLgain = resultsBase - resultRL;
  286. return LRgain+RLgain;
  287. }
  288. double Framework::cmpArtSets(DuelParameters dp, TArtSet setL, TArtSet setR)
  289. {
  290. dp.sides[0].artifacts = setL;
  291. dp.sides[1].artifacts = setR;
  292. auto battleOutcome = playBattle(dp);
  293. return battleOutcome;
  294. }
  295. double Framework::playBattle(const DuelParameters &dp)
  296. {
  297. string battleFileName = "pliczek.ssnb";
  298. {
  299. CSaveFile out(battleFileName);
  300. out << dp;
  301. }
  302. std::string serverCommand = servername + " " + addQuotesIfNeeded(battleFileName) + " " + addQuotesIfNeeded(leftAI) + " " + addQuotesIfNeeded(rightAI) + " " + addQuotesIfNeeded(results) + " " + addQuotesIfNeeded(logsDir) + " " + (withVisualization ? " v" : "");
  303. std::string runnerCommand = runnername + " " + addQuotesIfNeeded(logsDir);
  304. std::cout <<"Server command: " << serverCommand << std::endl << "Runner command: " << runnerCommand << std::endl;
  305. int code = 0;
  306. boost::thread t([&]
  307. {
  308. code = std::system(serverCommand.c_str());
  309. });
  310. runCommand(runnerCommand, "first_runner", logsDir);
  311. runCommand(runnerCommand, "second_runner", logsDir);
  312. runCommand(runnerCommand, "third_runner", logsDir);
  313. if(withVisualization)
  314. {
  315. //boost::this_thread::sleep(boost::posix_time::millisec(500)); //FIXME
  316. boost::thread tttt(boost::bind(std::system, "VCMI_Client.exe -battle"));
  317. }
  318. //boost::this_thread::sleep(boost::posix_time::seconds(5));
  319. t.join();
  320. return code / 1000000.0;
  321. }
  322. void Framework::runCommand(const std::string &command, const std::string &name, const std::string &logsDir /*= ""*/)
  323. {
  324. static std::string commands[100000];
  325. static int i = 0;
  326. std::string &cmd = commands[i++];
  327. if(logsDir.size() && name.size())
  328. {
  329. std::string directionLogs = logsDir + "/" + name + ".txt";
  330. cmd = command + " > " + addQuotesIfNeeded(directionLogs);
  331. }
  332. else
  333. cmd = command;
  334. boost::thread tt(boost::bind(std::system, cmd.c_str()));
  335. }
  336. DuelParameters Framework::generateDuel(const ArmyDescriptor &ad)
  337. {
  338. DuelParameters dp;
  339. dp.bfieldType = 1;
  340. dp.terType = 1;
  341. auto &side = dp.sides[0];
  342. side.heroId = 0;
  343. side.heroPrimSkills.resize(4,0);
  344. BOOST_FOREACH(auto &stack, ad)
  345. {
  346. side.stacks[stack.first] = DuelParameters::SideSettings::StackSettings(stack.second.type->idNumber, stack.second.count);
  347. }
  348. dp.sides[1] = side;
  349. dp.sides[1].heroId = 1;
  350. return dp;
  351. }
  352. CArtifactInstance * Framework::generateArtWithBonus(const Bonus &b)
  353. {
  354. std::vector<CArtifactInstance*> ret;
  355. static CArtifact *nowy = NULL;
  356. if(!nowy)
  357. {
  358. nowy = new CArtifact();
  359. nowy->description = "Cudowny miecz Towa gwarantuje zwyciestwo";
  360. nowy->name = "Cudowny miecz";
  361. nowy->constituentOf = nowy->constituents = NULL;
  362. nowy->possibleSlots.push_back(Arts::LEFT_HAND);
  363. }
  364. CArtifactInstance *artinst = new CArtifactInstance(nowy);
  365. artinst->addNewBonus(new Bonus(b));
  366. return artinst;
  367. }
  368. class SSN
  369. {
  370. FANN::neural_net net;
  371. struct ParameterSet;
  372. void init(const ParameterSet & params);
  373. FANN::training_data * getTrainingData( const std::vector<Example> &input);
  374. static int ANNCallback(FANN::neural_net &net, FANN::training_data &train, unsigned int max_epochs, unsigned int epochs_between_reports, float desired_error, unsigned int epochs, void *user_data);
  375. static double * genSSNinput(const DuelParameters::SideSettings & dp, CArtifactInstance * art, si32 bfieldType, si32 terType);
  376. static const unsigned int num_input = 18;
  377. public:
  378. struct ParameterSet
  379. {
  380. unsigned int neuronsInHidden;
  381. double actSteepHidden, actSteepnessOutput;
  382. FANN::activation_function_enum hiddenActFun, outActFun;
  383. };
  384. SSN();
  385. ~SSN();
  386. //returns mse after learning
  387. double learn(const std::vector<Example> & input, const ParameterSet & params);
  388. double test(const std::vector<Example> & input)
  389. {
  390. auto td = getTrainingData(input);
  391. return net.test_data(*td);
  392. delete td;
  393. }
  394. double run(const DuelParameters &dp, CArtifactInstance * inst);
  395. void save(const std::string &filename);
  396. };
  397. SSN::SSN()
  398. {}
  399. void SSN::init(const ParameterSet & params)
  400. {
  401. const float learning_rate = 0.7f;
  402. const unsigned int num_layers = 3;
  403. const unsigned int num_output = 1;
  404. const float desired_error = 0.01f;
  405. const unsigned int max_iterations = 30000;
  406. const unsigned int iterations_between_reports = 1000;
  407. net.create_standard(num_layers, num_input, params.neuronsInHidden, num_output);
  408. net.set_learning_rate(learning_rate);
  409. net.set_activation_steepness_hidden(params.actSteepHidden);
  410. net.set_activation_steepness_output(params.actSteepnessOutput);
  411. net.set_activation_function_hidden(params.hiddenActFun);
  412. net.set_activation_function_output(params.outActFun);
  413. net.randomize_weights(0.0, 1.0);
  414. }
  415. double SSN::run(const DuelParameters &dp, CArtifactInstance * inst)
  416. {
  417. double * input = genSSNinput(dp.sides[0], inst, dp.bfieldType, dp.terType);
  418. double * out = net.run(input);
  419. double ret = *out;
  420. free(out);
  421. return ret;
  422. }
  423. int SSN::ANNCallback(FANN::neural_net &net, FANN::training_data &train, unsigned int max_epochs, unsigned int epochs_between_reports, float desired_error, unsigned int epochs, void *user_data)
  424. {
  425. //cout << "Epochs " << setw(8) << epochs << ". "
  426. // << "Current Error: " << left << net.get_MSE() << right << endl;
  427. return 0;
  428. }
  429. double SSN::learn(const std::vector<Example> & input, const ParameterSet & params)
  430. {
  431. init(params);
  432. //FIXME - sypie przy destrukcji
  433. //FANN::training_data td;
  434. FANN::training_data *td = getTrainingData(input);
  435. net.set_callback(ANNCallback, NULL);
  436. net.train_on_data(*td, 1000, 1000, 0.01);
  437. return net.test_data(*td);
  438. }
  439. double * SSN::genSSNinput(const DuelParameters::SideSettings & dp, CArtifactInstance * art, si32 bfieldType, si32 terType)
  440. {
  441. double * ret = new double[num_input];
  442. double * cur = ret;
  443. //general description
  444. *(cur++) = bfieldType/30.0;
  445. *(cur++) = terType/12.0;
  446. //creature & hero description
  447. *(cur++) = dp.heroId/200.0;
  448. for(int k=0; k<4; ++k)
  449. *(cur++) = dp.heroPrimSkills[k]/20.0;
  450. //weighted average of statistics
  451. auto avg = [&](std::function<int(CCreature *)> getter) -> double
  452. {
  453. double ret = 0.0;
  454. int div = 0;
  455. for(int i=0; i<7; ++i)
  456. {
  457. auto & cstack = dp.stacks[i];
  458. if(cstack.count > 0)
  459. {
  460. ret += getter(VLC->creh->creatures[cstack.type]) * cstack.count;
  461. div+=cstack.count;
  462. }
  463. }
  464. return ret/div;
  465. };
  466. *(cur++) = avg([](CCreature * c){return c->attack;})/50.0;
  467. *(cur++) = avg([](CCreature * c){return c->defence;})/50.0;
  468. *(cur++) = avg([](CCreature * c){return c->speed;})/15.0;
  469. *(cur++) = avg([](CCreature * c){return c->hitPoints;})/1000.0;
  470. //bonus description
  471. auto & blist = art->getBonusList();
  472. *(cur++) = blist[0]->type/100.0;
  473. *(cur++) = blist[0]->subtype/10.0;
  474. *(cur++) = blist[0]->val/100.0;;
  475. *(cur++) = art->Attack()/10.0;
  476. *(cur++) = art->Defense()/10.0;
  477. *(cur++) = blist.valOfBonuses(Selector::type(Bonus::STACKS_SPEED))/5.0;
  478. *(cur++) = blist.valOfBonuses(Selector::type(Bonus::STACK_HEALTH))/10.0;
  479. return ret;
  480. }
  481. void SSN::save(const std::string &filename)
  482. {
  483. net.save(filename);
  484. }
  485. SSN::~SSN()
  486. {
  487. }
  488. FANN::training_data * SSN::getTrainingData( const std::vector<Example> &input )
  489. {
  490. FANN::training_data * ret = new FANN::training_data;
  491. double ** inputs = new double *[input.size()];
  492. double ** outputs = new double *[input.size()];
  493. for(int i=0; i<input.size(); ++i)
  494. {
  495. const auto & ci = input[i];
  496. inputs[i] = genSSNinput(ci.dp.sides[0], ci.art, ci.dp.bfieldType, ci.dp.terType);
  497. outputs[i] = new double;
  498. *(outputs[i]) = ci.value;
  499. }
  500. ret->set_train_data(input.size(), num_input, inputs, 1, outputs);
  501. return ret;
  502. }
  503. void SSNRun()
  504. {
  505. //buildLearningSet();
  506. double percentToTrain = 0.8;
  507. auto trainingSet = Framework::loadExamples(false);
  508. std::vector<Example> testSet;
  509. for(int i=0, maxi = trainingSet.size()*(1-percentToTrain); i<maxi; ++i)
  510. {
  511. int ind = rand()%trainingSet.size();
  512. testSet.push_back(trainingSet[ind]);
  513. trainingSet.erase(trainingSet.begin() + ind);
  514. }
  515. SSN network;
  516. SSN::ParameterSet bestParams;
  517. double besttMSE = 1e10;
  518. boost::mt19937 rng;
  519. boost::uniform_01<boost::mt19937> zeroone(rng);
  520. FANN::activation_function_enum possibleFuns[] = {FANN::SIGMOID_SYMMETRIC_STEPWISE, FANN::LINEAR,
  521. FANN::SIGMOID, FANN::SIGMOID_STEPWISE, FANN::SIGMOID_SYMMETRIC};
  522. for(int i=0; i<5000; i += 1)
  523. {
  524. SSN::ParameterSet ps;
  525. ps.actSteepHidden = zeroone() + 0.3;
  526. ps.actSteepnessOutput = zeroone() + 0.3;
  527. ps.neuronsInHidden = rand()%40+10;
  528. ps.hiddenActFun = possibleFuns[rand()%ARRAY_COUNT(possibleFuns)];
  529. ps.outActFun = possibleFuns[rand()%ARRAY_COUNT(possibleFuns)];
  530. double lmse = network.learn(trainingSet, ps);
  531. double tmse = network.test(testSet);
  532. if(tmse < besttMSE)
  533. {
  534. besttMSE = tmse;
  535. bestParams = ps;
  536. }
  537. cout << "hid:\t" << i << " lmse:\t" << lmse << " tmse:\t" << tmse << std::endl;
  538. }
  539. //saving of best network
  540. network.learn(trainingSet, bestParams);
  541. network.save("network_config_file.net");
  542. }
  543. int main(int argc, char **argv)
  544. {
  545. std::cout << "VCMI Odpalarka\nMy path: " << argv[0] << std::endl;
  546. po::options_description opts("Allowed options");
  547. opts.add_options()
  548. ("help,h", "Display help and exit")
  549. ("aiLeft,l", po::value<std::string>()->default_value("StupidAI"), "Left AI path")
  550. ("aiRight,r", po::value<std::string>()->default_value("StupidAI"), "Right AI path")
  551. ("battle,b", po::value<std::string>()->default_value("pliczek.ssnb"), "Duel file path")
  552. ("resultsOut,o", po::value<std::string>()->default_value("./results.txt"), "Output file when results will be appended")
  553. ("logsDir,d", po::value<std::string>()->default_value("."), "Directory where log files will be created")
  554. ("visualization,v", "Runs a client to display a visualization of battle");
  555. try
  556. {
  557. po::variables_map vm;
  558. po::store(po::parse_command_line(argc, argv, opts), vm);
  559. po::notify(vm);
  560. if(vm.count("help"))
  561. {
  562. opts.print(std::cout);
  563. prog_help();
  564. return 0;
  565. }
  566. leftAI = vm["aiLeft"].as<std::string>();
  567. rightAI = vm["aiRight"].as<std::string>();
  568. battle = vm["battle"].as<std::string>();
  569. results = vm["resultsOut"].as<std::string>();
  570. logsDir = vm["logsDir"].as<std::string>();
  571. withVisualization = vm.count("visualization");
  572. }
  573. catch(std::exception &e)
  574. {
  575. std::cerr << "Failure during parsing command-line options:\n" << e.what() << std::endl;
  576. exit(1);
  577. }
  578. std::cout << "Config:\n" << leftAI << " vs " << rightAI << " on " << battle << std::endl;
  579. if(leftAI.empty() || rightAI.empty() || battle.empty())
  580. {
  581. std::cerr << "I wasn't able to retreive names of AI or battles. Ending.\n";
  582. return 1;
  583. }
  584. runnername =
  585. #ifdef _WIN32
  586. "VCMI_BattleAiHost.exe"
  587. #else
  588. "./vcmirunner"
  589. #endif
  590. ;
  591. servername =
  592. #ifdef _WIN32
  593. "VCMI_server.exe"
  594. #else
  595. "./vcmiserver"
  596. #endif
  597. ;
  598. VLC = new LibClasses();
  599. VLC->init();
  600. SSNRun();
  601. return EXIT_SUCCESS;
  602. }