瀏覽代碼

[programming challenge, SSN]
* caching and saving desired outputs
* saving of ANN
* full learning, not just one epoch
* stub of reporting of learning function (ANNCallback)

mateuszb 13 年之前
父節點
當前提交
867d01dc34
共有 1 個文件被更改,包括 24 次插入7 次删除
  1. 24 7
      Odpalarka/main.cpp

+ 24 - 7
Odpalarka/main.cpp

@@ -1,6 +1,7 @@
 //#include "../global.h"
 #include "StdInc.h"
 #include "../lib/VCMI_Lib.h"
+#include "boost/tuple/tuple.hpp"
 namespace po = boost::program_options;
 
 
@@ -204,7 +205,16 @@ double runSSN(FANN::neural_net & net, const DuelParameters dp, CArtifactInstance
 	return ret;
 }
 
-void learnSSN(FANN::neural_net & net, const std::vector<std::pair<DuelParameters, CArtifactInstance *> > & input)
+int ANNCallback(FANN::neural_net &net, FANN::training_data &train,
+	unsigned int max_epochs, unsigned int epochs_between_reports,
+	float desired_error, unsigned int epochs, void *user_data)
+{
+	//cout << "Epochs     " << setw(8) << epochs << ". "
+	//	<< "Current Error: " << left << net.get_MSE() << right << endl;
+	return 0;
+}
+
+void learnSSN(FANN::neural_net & net, const std::vector<boost::tuple<DuelParameters, CArtifactInstance *, double> > & input)
 {
 	FANN::training_data td;
 
@@ -212,13 +222,13 @@ void learnSSN(FANN::neural_net & net, const std::vector<std::pair<DuelParameters
 	double ** outputs = new double *[input.size()];
 	for(int i=0; i<input.size(); ++i)
 	{
-		inputs[i] = genSSNinput(input[i].first, input[i].second);
+		inputs[i] = genSSNinput(input[i].get<0>(), input[i].get<1>());
 		outputs[i] = new double;
-		*(outputs[i]) = rateArt(input[i].first, input[i].second);
+		*(outputs[i]) = input[i].get<2>();
 	}
 	td.set_train_data(input.size(), num_input, inputs, 1, outputs);
-	
-	net.train_epoch(td);
+	net.set_callback(ANNCallback, NULL);
+	net.train_on_data(td, 1000, 1000, 0.01);
 }
 
 void initNet(FANN::neural_net & ret)
@@ -317,15 +327,22 @@ void SSNRun()
 	auto arts = genArts(btt);
 
 	//evaluate
-	std::vector<std::pair<DuelParameters, CArtifactInstance *> > setups;
+	std::vector<boost::tuple<DuelParameters, CArtifactInstance *, double> > setups;
+
+	std::ofstream desOuts("desiredOuts.dat");
+
 	for(int i=0; i<dps.size(); ++i)
 	{
 		for(int j=0; j<arts.size(); ++j)
 		{
-			setups.push_back(std::make_pair(dps[i], arts[j]));
+			setups.push_back(boost::make_tuple(dps[i], arts[j], rateArt(dps[i], arts[i])));
+			desOuts << (*setups.rbegin()).get<2>() << " ";
 		}
+		desOuts << std::endl;
 	}
+
 	learnSSN(network, setups);
+	network.save("network_config_file.net");
 }
 
 int main(int argc, char **argv)