瀏覽代碼

[programming challenge, SSN] Framework for storing and generating learning examples.

Michał W. Urbańczyk 13 年之前
父節點
當前提交
3d16f0a081
共有 2 個文件被更改,包括 363 次插入和 35 次删除
  1. 3 0
      Odpalarka/StdInc.h
  2. 360 35
      Odpalarka/main.cpp

+ 3 - 0
Odpalarka/StdInc.h

@@ -32,6 +32,9 @@
 #include <boost/thread.hpp>
 #include <boost/bind.hpp>
 #include <boost/program_options.hpp>
+#include <boost/filesystem.hpp>
+#include <boost/algorithm/string/predicate.hpp>
+#include <string>
 
 using boost::format;
 using boost::str;

+ 360 - 35
Odpalarka/main.cpp

@@ -1,8 +1,9 @@
 //#include "../global.h"
 #include "StdInc.h"
 #include "../lib/VCMI_Lib.h"
-#include "boost/tuple/tuple.hpp"
 namespace po = boost::program_options;
+namespace fs = boost::filesystem;
+using namespace std;
 
 
 //FANN
@@ -15,6 +16,221 @@ std::string servername;
 std::string runnername;
 extern DLL_EXPORT LibClasses * VLC;
 
+// One learning/testing example: a duel setup and an artifact (the network input)
+// together with the measured value of that artifact (the expected network output).
+struct Example
+{
+	//ANN input
+	DuelParameters dp;
+	CArtifactInstance *art;
+	//ANN expected output
+	double value;
+
+	//other
+	std::string description;
+
+	int i, j, k; //helper values for identification; operator< and operator== look only at these
+
+	// Both constructors give every scalar member a defined value. Previously art, value
+	// and i/j/k were left indeterminate, so comparing or printing a freshly built Example
+	// read garbage. -1 in i/j/k marks "not identified yet".
+	Example()
+		: art(NULL), value(0.0), i(-1), j(-1), k(-1)
+	{}
+	Example(const DuelParameters &Dp, CArtifactInstance *Art, double Value)
+		: dp(Dp), art(Art), value(Value), i(-1), j(-1), k(-1)
+	{}
+
+	// Strict weak ordering on the identification triple: k first, then j, then i.
+	inline bool operator<(const Example & rhs) const
+	{
+		if (k != rhs.k)
+			return k < rhs.k;
+		if (j != rhs.j)
+			return j < rhs.j;
+		return i < rhs.i;
+	}
+
+	// Two examples are equal when their identification triples match; the payload
+	// (dp, art, value, description) is not compared.
+	bool operator==(const Example &rhs) const
+	{
+		return rhs.i == i && rhs.j == j && rhs.k == k;
+	}
+
+	// Serialization hook: passes every member, in declaration order, through the handler.
+	template <typename Handler> void serialize(Handler &h, const int version)
+	{
+		h & dp & art & value & description & i & j & k;
+	}
+};
+
+// Collects the paths of all regular files directly inside `dirname` whose file name
+// ends with `ext`. A missing directory is reported on tlog1 and created before scanning.
+vector<string> getFileNames(const string &dirname = "./examples/", const std::string &ext = "example")
+{
+	if(!fs::exists(dirname))
+	{
+		tlog1 << "Cannot find " << dirname << " directory! Will attempt creating it.\n";
+		fs::create_directory(dirname);
+	}
+
+	vector<string> matching;
+	fs::path directory(dirname);
+	for(fs::directory_iterator entry(directory), finish; entry != finish; ++entry)
+	{
+		//anything that is not a plain file (e.g. a subdirectory) is ignored
+		if(!fs::is_regular_file(entry->status()))
+			continue;
+		if(!boost::ends_with(entry->path().filename(), ext))
+			continue;
+
+		matching.push_back(entry->path().string());
+	}
+
+	return matching;
+}
+
+// Deserializes every example file found by getFileNames("./examples/", "example").
+// When printInfo is set, the number of examples and one summary line per example
+// are written to tlog0.
+vector<Example> loadExamples(bool printInfo = true)
+{
+	std::vector<Example> loaded;
+	BOOST_FOREACH(const std::string &path, getFileNames("./examples/", "example"))
+	{
+		CLoadFile source(path);
+		Example fromDisk;
+		source >> fromDisk;
+		loaded.push_back(fromDisk);
+	}
+
+	if(!printInfo)
+		return loaded;
+
+	tlog0 << "Found " << loaded.size() << " examples.\n";
+	BOOST_FOREACH(const Example &entry, loaded)
+	{
+		tlog0 << format("Battle on army %d for bonus %d of value %d has resultdiff %lf\n") % entry.i % entry.j % entry.k % entry.value;
+	}
+
+	return loaded;
+}
+
+// Tells whether the example carries exactly the given identification triple.
+bool matchExample(const Example &ex,  int i, int j, int k)
+{
+	if(ex.i != i)
+		return false;
+	if(ex.j != j)
+		return false;
+	return ex.k == k;
+}
+
+//Builds a simple duel (battlefield type 1, terrain type 1) in which both sides field the given army.
+DuelParameters generateDuel(const ArmyDescriptor &ad)
+{
+	DuelParameters duel;
+	duel.bfieldType = 1;
+	duel.terType = 1;
+
+	//describe the first side: hero 0, four primary skills all at zero, one stack per army slot
+	auto &first = duel.sides[0];
+	first.heroId = 0;
+	first.heroPrimSkills.resize(4,0);
+	BOOST_FOREACH(auto &slot, ad)
+	{
+		auto &creature = slot.second;
+		first.stacks[slot.first] = DuelParameters::SideSettings::StackSettings(creature.type->idNumber, creature.count);
+	}
+
+	//the second side is a copy of the first one, led by hero 1
+	auto &second = duel.sides[1];
+	second = first;
+	second.heroId = 1;
+	return duel;
+}
+
+// Builds the fixed list of armies the learning examples are generated for.
+// Result order: low-HP creatures, powerful creatures, weekly growth of town 0,
+// upgraded weekly growth of town 1, shooters only.
+std::vector<ArmyDescriptor> learningArmies()
+{
+	std::vector<ArmyDescriptor> armies;
+
+	//army made of creatures with low hit points
+	ArmyDescriptor weakCreatures;
+	weakCreatures[0] = CStackBasicDescriptor(1, 9); //halberdier
+	weakCreatures[1] = CStackBasicDescriptor(14, 20); //centaur
+	weakCreatures[2] = CStackBasicDescriptor(139, 123); //peasant
+	weakCreatures[3] = CStackBasicDescriptor(70, 30); //troglodyte
+	weakCreatures[4] = CStackBasicDescriptor(42, 50); //imp
+	armies.push_back(weakCreatures);
+
+	//army made of powerful creatures
+	ArmyDescriptor strongCreatures;
+	strongCreatures[0] = CStackBasicDescriptor(13, 17); //archangel
+	strongCreatures[1] = CStackBasicDescriptor(132, 8); //azure dragon
+	strongCreatures[2] = CStackBasicDescriptor(133, 10); //crystal dragon
+	strongCreatures[3] = CStackBasicDescriptor(83, 22); //black dragon
+	armies.push_back(strongCreatures);
+
+	//army made of one week of creature growth in the Castle town
+	auto &castle = VLC->townh->towns[0];
+	ArmyDescriptor castleWeekly;
+	for(int slot = 0; slot < 7; ++slot)
+	{
+		auto &creature = VLC->creh->creatures[castle.basicCreatures[slot]];
+		castleWeekly[slot] = CStackBasicDescriptor(creature.get(), creature->growth);
+	}
+	castleWeekly[5].type = VLC->creh->creatures[52]; //replace cavaliers with Efreeti -> stupid ai sometimes blocks with two-hex walkers
+	armies.push_back(castleWeekly);
+
+	//army made of one week of upgraded creature growth in the Rampart town
+	auto &rampart = VLC->townh->towns[1];
+	ArmyDescriptor rampartWeekly;
+	for(int slot = 0; slot < 7; ++slot)
+	{
+		auto &creature = VLC->creh->creatures[rampart.upgradedCreatures[slot]];
+		rampartWeekly[slot] = CStackBasicDescriptor(creature.get(), creature->growth);
+	}
+	rampartWeekly[5].type = VLC->creh->creatures[52]; //replace unicorn with Efreeti -> stupid ai sometimes blocks with two-hex walkers
+	armies.push_back(rampartWeekly);
+
+	//army made of shooters only
+	ArmyDescriptor rangedOnly;
+	rangedOnly[0] = CStackBasicDescriptor(35, 17); //arch mage
+	rangedOnly[1] = CStackBasicDescriptor(41, 1); //titan
+	rangedOnly[2] = CStackBasicDescriptor(3, 70); //crossbowman
+	rangedOnly[3] = CStackBasicDescriptor(89, 50); //upgraded orc
+	armies.push_back(rangedOnly);
+
+	return armies;
+}
+
+// Returns the ten bonus prototypes the learning examples are generated for.
+// Only type and subtype are set here; all other fields keep whatever Bonus's
+// default constructor gives them (the caller fills in val).
+std::vector<Bonus> learningBonuses()
+{
+	std::vector<Bonus> prototypes;
+	Bonus proto;
+
+	//primary skill bonuses: attack, then defense
+	proto.type = Bonus::PRIMARY_SKILL;
+	proto.subtype = PrimarySkill::ATTACK;
+	prototypes.push_back(proto);
+	proto.subtype = PrimarySkill::DEFENSE;
+	prototypes.push_back(proto);
+
+	//every bonus from here on uses subtype 0
+	proto.subtype = 0;
+	proto.type = Bonus::STACK_HEALTH;
+	prototypes.push_back(proto);
+
+	proto.type = Bonus::STACKS_SPEED;
+	prototypes.push_back(proto);
+
+	//retaliation and extra attack
+	proto.type = Bonus::BLOCKS_RETALIATION;
+	prototypes.push_back(proto);
+	proto.type = Bonus::ADDITIONAL_RETALIATION;
+	prototypes.push_back(proto);
+	proto.type = Bonus::ADDITIONAL_ATTACK;
+	prototypes.push_back(proto);
+
+	//damage related
+	proto.type = Bonus::CREATURE_DAMAGE;
+	prototypes.push_back(proto);
+	proto.type = Bonus::ALWAYS_MAXIMUM_DAMAGE;
+	prototypes.push_back(proto);
+
+	proto.type = Bonus::NO_DISTANCE_PENALTY;
+	prototypes.push_back(proto);
+
+	return prototypes;
+}
+
 std::string addQuotesIfNeeded(const std::string &s)
 {
 	if(s.find_first_of(' ') != std::string::npos)
@@ -30,7 +246,7 @@ void prog_help()
 
 void runCommand(const std::string &command, const std::string &name, const std::string &logsDir = "")
 {
-	static std::string commands[100];
+	static std::string commands[100000];
 	static int i = 0;
 	std::string &cmd = commands[i++];
 	if(logsDir.size() && name.size())
@@ -46,13 +262,14 @@ void runCommand(const std::string &command, const std::string &name, const std::
 
 double playBattle(const DuelParameters &dp)
 {
+	string battleFileName = "pliczek.ssnb";
 	{
-		CSaveFile out("pliczek.ssnb");
+		CSaveFile out(battleFileName);
 		out << dp;
 	}
 
 
-	std::string serverCommand = servername + " " + addQuotesIfNeeded(battle) + " " + addQuotesIfNeeded(leftAI) + " " + addQuotesIfNeeded(rightAI) + " " + addQuotesIfNeeded(results) + " " + addQuotesIfNeeded(logsDir) + " " + (withVisualization ? " v" : "");
+	std::string serverCommand = servername + " " + addQuotesIfNeeded(battleFileName) + " " + addQuotesIfNeeded(leftAI) + " " + addQuotesIfNeeded(rightAI) + " " + addQuotesIfNeeded(results) + " " + addQuotesIfNeeded(logsDir) + " " + (withVisualization ? " v" : "");
 	std::string runnerCommand = runnername + " " + addQuotesIfNeeded(logsDir);
 	std::cout <<"Server command: " << serverCommand << std::endl << "Runner command: " << runnerCommand << std::endl;
 
@@ -81,10 +298,6 @@ typedef std::map<int, CArtifactInstance*> TArtSet;
 
 double cmpArtSets(DuelParameters dp, TArtSet setL, TArtSet setR)
 {
-	//lewa strona z art 0.9
-	//bez artefaktow -0.41
-	//prawa strona z art. -0.926
-
 	dp.sides[0].artifacts = setL;
 	dp.sides[1].artifacts = setR;
 
@@ -92,23 +305,32 @@ double cmpArtSets(DuelParameters dp, TArtSet setL, TArtSet setR)
 	return battleOutcome;
 }
 
-std::vector<CArtifactInstance*> genArts(const std::vector<Bonus> & bonusesToGive)
+CArtifactInstance *generateArtWithBonus(const Bonus &b)
 {
 	std::vector<CArtifactInstance*> ret;
 
-	CArtifact *nowy = new CArtifact();
-	nowy->description = "Cudowny miecz Towa gwarantuje zwyciestwo";
-	nowy->name = "Cudowny miecz";
-	nowy->constituentOf = nowy->constituents = NULL;
-	nowy->possibleSlots.push_back(Arts::LEFT_HAND);
+	static CArtifact *nowy = NULL;
+	
+	if(!nowy)
+	{
+		nowy = new CArtifact();
+		nowy->description = "Cudowny miecz Towa gwarantuje zwyciestwo";
+		nowy->name = "Cudowny miecz";
+		nowy->constituentOf = nowy->constituents = NULL;
+		nowy->possibleSlots.push_back(Arts::LEFT_HAND);
+	}
 
+	CArtifactInstance *artinst = new CArtifactInstance(nowy);
+	artinst->addNewBonus(new Bonus(b));
+	return artinst;
+}
 
+std::vector<CArtifactInstance*> genArts(const std::vector<Bonus> & bonusesToGive)
+{
+	std::vector<CArtifactInstance*> ret;
 	BOOST_FOREACH(auto b, bonusesToGive)
 	{
-		CArtifactInstance *artinst = new CArtifactInstance(nowy);
-		auto &arts = VLC->arth->artifacts;
-		artinst->addNewBonus(new Bonus(b));
-		ret.push_back(artinst);
+		ret.push_back(generateArtWithBonus(b));
 	}
 
 // 	auto bonuses = artinst->getBonuses([](const Bonus *){ return true; });
@@ -130,8 +352,14 @@ double rateArt(const DuelParameters dp, CArtifactInstance * inst)
 		resultRL = cmpArtSets(dp, setR, setL),
 		resultsBase = cmpArtSets(dp, TArtSet(), TArtSet());
 
+
+
+	//lewa strona z art 0.9
+	//bez artefaktow -0.41
+	//prawa strona z art. -0.926
+
 	double LRgain = resultLR - resultsBase,
-		RLgain = resultRL - resultsBase;
+		RLgain = resultsBase - resultRL;
 	return LRgain+RLgain;
 }
 
@@ -214,7 +442,7 @@ int ANNCallback(FANN::neural_net &net, FANN::training_data &train,
 	return 0;
 }
 
-void learnSSN(FANN::neural_net & net, const std::vector<boost::tuple<DuelParameters, CArtifactInstance *, double> > & input)
+void learnSSN(FANN::neural_net & net, const std::vector<Example> & input)
 {
 	FANN::training_data td;
 
@@ -222,9 +450,9 @@ void learnSSN(FANN::neural_net & net, const std::vector<boost::tuple<DuelParamet
 	double ** outputs = new double *[input.size()];
 	for(int i=0; i<input.size(); ++i)
 	{
-		inputs[i] = genSSNinput(input[i].get<0>(), input[i].get<1>());
+		inputs[i] = genSSNinput(input[i].dp, input[i].art);
 		outputs[i] = new double;
-		*(outputs[i]) = input[i].get<2>();
+		*(outputs[i]) = input[i].value;
 	}
 	td.set_train_data(input.size(), num_input, inputs, 1, outputs);
 	net.set_callback(ANNCallback, NULL);
@@ -245,7 +473,7 @@ void initNet(FANN::neural_net & ret)
 
 	ret.set_learning_rate(learning_rate);
 
-	ret.set_activation_steepness_hidden(1.0);
+	ret.set_activation_steepness_hidden(0.9);
 	ret.set_activation_steepness_output(1.0);
 
 	ret.set_activation_function_hidden(FANN::SIGMOID_SYMMETRIC_STEPWISE);
@@ -286,20 +514,14 @@ void SSNRun()
 // 	}
 
 
+
+
 	//duels to test on
 	std::vector<DuelParameters> dps;
 	for(int k = 0; k<10; ++k)
 	{
 		DuelParameters dp;
-		dp.bfieldType = 1;
-		dp.terType = 1;
-
-		auto &side = dp.sides[0];
-		side.heroId = 0;
-		side.heroPrimSkills.resize(4,0);
-		side.stacks[0] = DuelParameters::SideSettings::StackSettings(10+k*3, rand()%30);
-		dp.sides[1] = side;
-		dp.sides[1].heroId = 1;
+
 		dps.push_back(dp);
 	}
 
@@ -307,6 +529,14 @@ void SSNRun()
 	for(int i=0; i<5; ++i)
 	{
 		Bonus b;
+		b.additionalInfo = -1;
+		b.duration = Bonus::PERMANENT;
+		b.source = Bonus::ARTIFACT;
+		b.sid = 0;
+		b.turnsRemain = 0xda;
+		b.valType = Bonus::ADDITIVE_VALUE;
+		b.effectRange = Bonus::NO_LIMIT;
+
 		b.type = Bonus::PRIMARY_SKILL;
 		b.subtype = PrimarySkill::ATTACK;
 		b.val = 5 * i + 1;
@@ -327,7 +557,7 @@ void SSNRun()
 	auto arts = genArts(btt);
 
 	//evaluate
-	std::vector<boost::tuple<DuelParameters, CArtifactInstance *, double> > setups;
+	std::vector<Example> setups;
 
 	std::ofstream desOuts("desiredOuts.dat");
 
@@ -335,8 +565,8 @@ void SSNRun()
 	{
 		for(int j=0; j<arts.size(); ++j)
 		{
-			setups.push_back(boost::make_tuple(dps[i], arts[j], rateArt(dps[i], arts[i])));
-			desOuts << (*setups.rbegin()).get<2>() << " ";
+			setups.push_back(Example(dps[i], arts[j], rateArt(dps[i], arts[i])));
+			desOuts << (*setups.rbegin()).value << " ";
 		}
 		desOuts << std::endl;
 	}
@@ -345,6 +575,98 @@ void SSNRun()
 	network.save("network_config_file.net");
 }
 
+// Decimal text form of an integer, produced by boost::lexical_cast.
+string toString(int i)
+{
+	const string text = boost::lexical_cast<string>(i);
+	return text;
+}
+
+// Short label for a bonus, in the form "+<val>_to_<type name>_sub<subtype>".
+string describeBonus(const Bonus &b)
+{
+	string label = "+";
+	label += toString(b.val);
+	label += "_to_";
+	label += bonusTypeToString(b.type);
+	label += "_sub";
+	label += toString(b.subtype);
+	return label;
+}
+
+// Returns the highest number used as an example file name (the name without its
+// extension, e.g. 17 for "17.example"), or -1 when no numbered example file exists.
+// One pass over the names replaces the former sort, which re-parsed names inside the
+// comparator and let a single non-numeric file name throw bad_lexical_cast.
+int theLastN()
+{
+	const vector<string> fnames = getFileNames();
+
+	bool anyNumbered = false;
+	int highest = -1;
+	BOOST_FOREACH(const string &fname, fnames)
+	{
+		try
+		{
+			const int number = boost::lexical_cast<int>(fs::basename(fname));
+			if(!anyNumbered || number > highest)
+			{
+				highest = number;
+				anyNumbered = true;
+			}
+		}
+		catch(const boost::bad_lexical_cast &)
+		{
+			//a stray file must not abort the whole run; report it and carry on
+			tlog1 << "Ignoring example file with non-numeric name: " << fname << "\n";
+		}
+	}
+
+	return highest;
+}
+
+// Generates every learning example that is not on disk yet: one per combination of
+// army (i), bonus prototype (j) and bonus value k in 0..9. Each new example is rated
+// with rateArt, saved as ./examples/<n>.example and reported in log.txt and on tlog0.
+// Combinations already loaded from disk (same i/j/k) are skipped.
+void buildLearningSet()
+{
+	vector<Example> examples = loadExamples();
+	range::sort(examples);
+
+	ofstream learningLog("log.txt", std::ios::app);
+
+	int n = theLastN()+1; //file numbers continue after the highest one already present
+
+	auto armies = learningArmies();
+	auto bonuese = learningBonuses();
+	const int armiesCount = static_cast<int>(armies.size());
+	const int bonusesCount = static_cast<int>(bonuese.size());
+
+	for(int i = 0; i < armiesCount; i++)
+	{
+		string army = "army" + toString(i);
+		for(int j = 0; j < bonusesCount; j++)
+		{
+			Bonus b = bonuese[j];
+			for(int k = 0; k < 10; k++)
+			{
+				int nHere = n++;
+
+				tlog2 << "n="<<nHere<<std::endl;
+				b.val = k;
+				const string bonusDesc = describeBonus(b);
+
+				Example ex;
+				ex.i = i;
+				ex.j = j;
+				ex.k = k;
+
+				//Look the combination up before building anything: equality uses only i/j/k,
+				//so no artifact or duel has to be allocated for an example that gets skipped.
+				const Example *known = NULL;
+				BOOST_FOREACH(const Example &stored, examples)
+				{
+					if(stored == ex)
+					{
+						known = &stored;
+						break;
+					}
+				}
+
+				if(known)
+				{
+					//report the value stored on disk; ex.value has not been computed here
+					string msg = str(format("n=%d \tarmy %d \tbonus %d \tresult %lf \t Bonus#%s#") % nHere % i %j % known->value % bonusDesc);
+					tlog0 << "Already present example, skipping " << msg << "\n";
+					continue;
+				}
+
+				ex.art = generateArtWithBonus(b);
+				ex.dp = generateDuel(armies[i]);
+				ex.description = army + "\t" + bonusDesc + "\t";
+				ex.value = rateArt(ex.dp, ex.art);
+
+				{
+					//scoped so the file object is destroyed before the example is reported
+					CSaveFile output("./examples/" + toString(nHere) + ".example");
+					output << ex;
+				}
+
+				time_t rawtime;
+				time ( &rawtime );
+				struct tm * timeinfo = localtime ( &rawtime );
+				//asctime's text already ends with a newline; the fallback keeps that shape
+				const char *stamp = timeinfo ? asctime(timeinfo) : "unknown\n";
+				string msg = str(format("n=%d \tarmy %d \tbonus %d \tresult %lf \t Bonus#%s# \tdate: %s") % nHere % i %j % ex.value % bonusDesc % stamp);
+				learningLog << msg << flush;
+				tlog0 << msg;
+			}
+		}
+	}
+
+	tlog0 << "Set of learning/testing examples is complete and ready!\n";
+}
+
+
+
 int main(int argc, char **argv)
 {
 	std::cout << "VCMI Odpalarka\nMy path: " << argv[0] << std::endl;
@@ -415,6 +737,9 @@ int main(int argc, char **argv)
 	VLC = new LibClasses();
 	VLC->init();
 
+
+	buildLearningSet();
+
 	SSNRun();
 
 	return EXIT_SUCCESS;