Kaynağa Gözat

[programming challenge, SSN] REPL, various "fixes"

Michał W. Urbańczyk 13 yıl önce
ebeveyn
işleme
82a6520feb
1 değiştirilmiş dosya ile 236 ekleme ve 24 silme
  1. 236 24
      Odpalarka/main.cpp

+ 236 - 24
Odpalarka/main.cpp

@@ -94,6 +94,8 @@ struct Example
 	}
 };
 
+struct SSN_Runner;
+
 class Framework
 {
 	static CArtifactInstance *generateArtWithBonus(const Bonus &b);
@@ -113,6 +115,8 @@ public:
 
 	static void buildLearningSet(); 
 	static vector<Example> loadExamples(bool printInfo = true);
+
+	friend SSN_Runner;
 };
 
 vector<string> Framework::getFileNames(const string &dirname, const std::string &ext)
@@ -149,9 +153,9 @@ vector<Example> Framework::loadExamples(bool printInfo)
 		examples.push_back(ex);
 	}
 
+	tlog0 << "Found " << examples.size() << " examples.\n";
 	if(printInfo)
 	{
-		tlog0 << "Found " << examples.size() << " examples.\n";
 		BOOST_FOREACH(auto &ex, examples)
 		{
 			tlog0 << format("Battle on army %d for bonus %d of value %d has resultdiff %lf\n") % ex.i % ex.j % ex.k % ex.value;
@@ -471,11 +475,15 @@ public:
 	};
 
 	SSN();
+	SSN(string filename);
 	~SSN();
 
 	//returns mse after learning
 	double learn(const std::vector<Example> & input, const ParameterSet & params);
+	double learn(bool adjustParams = false);
 
+	SSN::ParameterSet getBestParams(vector<Example> &trainingSet);
+	SSN::ParameterSet getBestParams();
 	double test(const std::vector<Example> & input)
 	{
 		auto td = getTrainingData(input);
@@ -485,11 +493,17 @@ public:
 	double run(const DuelParameters &dp, CArtifactInstance * inst); 
 
 	void save(const std::string &filename);
+	void load(const std::string &filename);
 };
 
 SSN::SSN()
 {}
 
+SSN::SSN(string filename)
+{
+	load(filename);
+}
+
 void SSN::init(const ParameterSet & params)
 {
 	const float learning_rate = 0.7f;
@@ -517,7 +531,7 @@ double SSN::run(const DuelParameters &dp, CArtifactInstance * inst)
 	double * input = genSSNinput(dp.sides[0], inst, dp.bfieldType, dp.terType);
 	double * out = net.run(input);
 	double ret = *out;
-	free(out);
+	//free(out);
 
 	return ret;
 }
@@ -539,7 +553,6 @@ double SSN::learn(const std::vector<Example> & input, const ParameterSet & param
 	net.set_callback(ANNCallback, NULL);
 	net.train_on_data(*td, 1000, 1000, 0.01);
 
-
 // 	int exNum = 130;
 // 
 // 	for(int exNum =0; exNum<input.size(); ++exNum)
@@ -553,6 +566,25 @@ double SSN::learn(const std::vector<Example> & input, const ParameterSet & param
 	return net.test_data(*td);
 }
 
+double SSN::learn(bool adjustParams/* = false*/)
+{
+
+	cout << "Loading examples...\n";
+	auto trainingSet = Framework::loadExamples(false);
+	cout << "Looking for best learning parameters...\n";
+
+
+	auto params = adjustParams ? getBestParams(trainingSet) : getBestParams(); 
+
+	cout << "Learning...\n";
+
+	//saving of best network
+	double finalLmse = learn(trainingSet, params);
+	cout << "Learning done, LMSE=" << finalLmse << endl;
+	save("last_network.net");
+	return finalLmse;
+}
+
 double * SSN::genSSNinput(const DuelParameters::SideSettings & dp, CArtifactInstance * art, si32 bfieldType, si32 terType)
 {
 	double * ret = new double[num_input];
@@ -633,16 +665,17 @@ FANN::training_data * SSN::getTrainingData( const std::vector<Example> &input )
 	return ret;
 }
 
-void SSNRun()
+void SSN::load(const std::string &filename)
 {
-	//Framework::buildLearningSet();
-	double percentToTrain = 0.8;
+	net.create_from_file(filename);
+	cout << "Loaded a network from file " << filename << endl;
+}
 
-	auto trainingSet = Framework::loadExamples(false);
+SSN::ParameterSet SSN::getBestParams(vector<Example> &trainingSet)
+{
+	double percentToTrain = 0.8;
 
 	std::vector<Example> testSet;
-
-
 	for(int i=0, maxi = trainingSet.size()*(1-percentToTrain); i<maxi; ++i)
 	{
 		int ind = rand()%trainingSet.size();
@@ -650,9 +683,6 @@ void SSNRun()
 		trainingSet.erase(trainingSet.begin() + ind);
 	}
 
-	SSN network;
-
-
 	SSN::ParameterSet bestParams;
 	double besttMSE = 1e10;
 
@@ -661,12 +691,6 @@ void SSNRun()
 
 	FANN::activation_function_enum possibleFuns[] = {FANN::SIGMOID_SYMMETRIC_STEPWISE, FANN::LINEAR,
 		FANN::SIGMOID, FANN::SIGMOID_STEPWISE, FANN::SIGMOID_SYMMETRIC};
-// 
-// 	bestParams.actSteepHidden = 0.346;
-// 	bestParams.actSteepnessOutput = 0.449;
-// 	bestParams.hiddenActFun = FANN::SIGMOID_SYMMETRIC;
-// 	bestParams.outActFun = FANN::SIGMOID_SYMMETRIC;
-// 	bestParams.neuronsInHidden = 23;
 
 	for(int i=0; i<5000; i += 1)
 	{
@@ -677,9 +701,9 @@ void SSNRun()
 		ps.hiddenActFun = possibleFuns[rand()%ARRAY_COUNT(possibleFuns)];
 		ps.outActFun = possibleFuns[rand()%ARRAY_COUNT(possibleFuns)];
 
-		double lmse = network.learn(trainingSet, ps);
+		double lmse = learn(trainingSet, ps);
 
-		double tmse = network.test(testSet);
+		double tmse = test(testSet);
 		if(tmse < besttMSE)
 		{
 			besttMSE = tmse;
@@ -688,12 +712,199 @@ void SSNRun()
 
 		cout << "hid:\t" << i << " lmse:\t" << lmse << " tmse:\t" << tmse << std::endl;
 	}
-	//saving of best network
-	double debugMSE = network.learn(trainingSet, bestParams);
 
-	network.save("network_config_file.net");
+	return bestParams;
+}
+
+SSN::ParameterSet SSN::getBestParams()
+{
+	// 	bestParams.actSteepHidden = 0.346;
+	// 	bestParams.actSteepnessOutput = 0.449;
+	// 	bestParams.hiddenActFun = FANN::SIGMOID_SYMMETRIC;
+	// 	bestParams.outActFun = FANN::SIGMOID_SYMMETRIC;
+	// 	bestParams.neuronsInHidden = 23;
+
+
+	SSN::ParameterSet params;
+	params.actSteepHidden = 1.18;
+	params.actSteepnessOutput = 1.26;
+	params.hiddenActFun = FANN::SIGMOID_STEPWISE;
+	params.outActFun = FANN::SIGMOID_SYMMETRIC;
+	params.neuronsInHidden = 47;
+	return params;
 }
 
+struct SSN_Runner
+{
+	unique_ptr<SSN> ssn;
+	ArmyDescriptor ad;
+
+	void printHelp()
+	{
+		const char *cmds[] = {"help - prints this info", "create - creates a new ANN, needs to be learned then", "load <file> - loads ANN from file", "save <file> - saves current ANN to file", "learn - runs learning process using examples set", "ask <id> - evaluates given art", "exit - closes application",
+							"army clear - removes current army information", "army add <id> <count> - adds creature to army", "army remove <pos> - removes stack from position",
+							"army print - prints current army state", "army random - generates random army"};
+		cout << "Available commands:\n";
+		BOOST_FOREACH(auto cmd, cmds)
+			cout << "\t" << cmd << endl;
+	}
+
+	int run()
+	{
+		cout << "Welcome to the ANN interactive mode!\n";
+		printHelp();
+
+		while(1)
+		{
+			try
+			{
+				cout << "Please enter your command and press return.\n> ";
+				stringstream ss;
+				string input;
+				getline(cin, input);
+				ss.str(input);
+
+				string command, secondWord;
+				ss >> command >> secondWord;
+
+				if(command == "exit")
+				{
+					cout << "Ending...\n";
+					exit(0);
+				}
+				else if(command == "load")
+				{
+					if(secondWord.empty())
+						secondWord = "last_network.net";
+
+					ssn = unique_ptr<SSN>(new SSN(secondWord));
+				}
+				else if(command == "create")
+				{
+					ssn = unique_ptr<SSN>(new SSN());
+					cout << "Network successfully created. It still needs to be learnt.\n";
+				}
+				else if(command == "help")
+				{
+					printHelp();
+				}
+
+				else if(command == "army" && secondWord.size())
+				{
+					if(secondWord == "clear")
+					{
+						ad.clear();
+						cout << "Army is now empty.\n";
+					}
+					if(secondWord == "print")
+					{
+						cout << "Army contains " << ad.size() << " creatures.\n";
+						BOOST_FOREACH(auto &itr, ad)
+						{
+							cout << itr.first << " => " << itr.second.count << " of " << itr.second.type->namePl << endl;
+						}
+					}
+					if(secondWord == "erase")
+					{
+						int slot;
+						ss >> slot;
+						if(ad.find(slot) != ad.end())
+						{
+							ad.erase(slot);
+							cout << "Slot " << slot << " successfully erased.\n";
+						}
+					}
+					if(secondWord == "add")
+					{
+						int id, count;
+						ss >> id >> count;
+						int i = 0;
+						if(id < 0 || id >= 118)
+						{
+							throw std::runtime_error("Id has to be in <0,118>");
+						}
+						if(count <= 0)
+						{
+							throw std::runtime_error("Count has to be > 0");
+						}
+
+						while(ad.find(i++) != ad.end());
+						if(i >= ARMY_SIZE)
+						{
+							tlog1 << "Cannot add stack, army is full!\n";
+						}
+						else
+						{
+							ad[i] = CStackBasicDescriptor(id, count);
+							tlog0 << "Creature successfully added to slot " << i << endl;;
+						}
+					}
+					if(secondWord == "random")
+					{
+						srand(time(0));
+						ad.clear();
+						int stacks = rand() % 7 + 1;
+						for(int i = 0; i < stacks; i++)
+						{
+							CCreature *c = VLC->creh->creatures[rand() % 118];
+							ad[i] = CStackBasicDescriptor(c, c->growth);
+						}
+						cout << "Generated random army of " << stacks << " creatures.\n";
+					}
+				}
+
+				else if(!ssn)
+				{
+					cout << "Error: you need to create or load ANN from file first!\n";
+					continue;
+				}
+
+				else if(command == "learn")
+				{
+					ssn->learn();
+				}
+				else if(command == "save")
+				{
+					ssn->save(secondWord);
+				}
+				else if(command == "ask")
+				{
+					int artid = boost::lexical_cast<int>(secondWord);
+					CArtifact *art = VLC->arth->artifacts.at(artid);
+
+					DuelParameters dp = Framework::generateDuel(ad);
+
+					CArtifactInstance * artInst = new CArtifactInstance(art);
+					auto bonuses = art->getBonuses([](const Bonus*){return true;});
+					if(!bonuses->size())
+					{
+						tlog1 << "This artifact deosn't provide any bonuses. Please pick another one.";
+					}
+					else
+					{
+						BOOST_FOREACH(auto b, *bonuses)
+							artInst->addNewBonus(new Bonus(*b));
+					
+
+						auto val = ssn->run(dp, artInst);
+						cout << "ANN rates " << art->Name() << " to value = " << val << endl;
+					}
+				}
+				else
+					tlog1 << "Unknown command \""<<command <<"\"!\n";
+			}
+			catch(std::exception &e)
+			{
+				tlog1 << "Encountered error: " << e.what() << endl;
+			}
+			catch(...)
+			{
+				tlog1 << "Encountered unknown error!" << endl;
+			}
+		}
+	}
+};
+
 int main(int argc, char **argv)
 {
 	std::cout << "VCMI Odpalarka\nMy path: " << argv[0] << std::endl;
@@ -764,7 +975,8 @@ int main(int argc, char **argv)
 	VLC = new LibClasses();
 	VLC->init();
 
-	SSNRun();
+	SSN_Runner runner;
+	runner.run();
 
 	return EXIT_SUCCESS;
 }