|
@@ -1,6 +1,7 @@
|
|
|
//#include "../global.h"
|
|
|
#include "StdInc.h"
|
|
|
#include "../lib/VCMI_Lib.h"
|
|
|
+#include "boost/tuple/tuple.hpp"
|
|
|
namespace po = boost::program_options;
|
|
|
|
|
|
|
|
@@ -204,7 +205,16 @@ double runSSN(FANN::neural_net & net, const DuelParameters dp, CArtifactInstance
|
|
|
return ret;
|
|
|
}
|
|
|
|
|
|
-void learnSSN(FANN::neural_net & net, const std::vector<std::pair<DuelParameters, CArtifactInstance *> > & input)
|
|
|
+int ANNCallback(FANN::neural_net &net, FANN::training_data &train,
|
|
|
+ unsigned int max_epochs, unsigned int epochs_between_reports,
|
|
|
+ float desired_error, unsigned int epochs, void *user_data)
|
|
|
+{
|
|
|
+ //cout << "Epochs " << setw(8) << epochs << ". "
|
|
|
+ // << "Current Error: " << left << net.get_MSE() << right << endl;
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+void learnSSN(FANN::neural_net & net, const std::vector<boost::tuple<DuelParameters, CArtifactInstance *, double> > & input)
|
|
|
{
|
|
|
FANN::training_data td;
|
|
|
|
|
@@ -212,13 +222,13 @@ void learnSSN(FANN::neural_net & net, const std::vector<std::pair<DuelParameters
|
|
|
double ** outputs = new double *[input.size()];
|
|
|
for(int i=0; i<input.size(); ++i)
|
|
|
{
|
|
|
- inputs[i] = genSSNinput(input[i].first, input[i].second);
|
|
|
+ inputs[i] = genSSNinput(input[i].get<0>(), input[i].get<1>());
|
|
|
outputs[i] = new double;
|
|
|
- *(outputs[i]) = rateArt(input[i].first, input[i].second);
|
|
|
+ *(outputs[i]) = input[i].get<2>();
|
|
|
}
|
|
|
td.set_train_data(input.size(), num_input, inputs, 1, outputs);
|
|
|
-
|
|
|
- net.train_epoch(td);
|
|
|
+ net.set_callback(ANNCallback, NULL);
|
|
|
+ net.train_on_data(td, 1000, 1000, 0.01);
|
|
|
}
|
|
|
|
|
|
void initNet(FANN::neural_net & ret)
|
|
@@ -317,15 +327,22 @@ void SSNRun()
|
|
|
auto arts = genArts(btt);
|
|
|
|
|
|
//evaluate
|
|
|
- std::vector<std::pair<DuelParameters, CArtifactInstance *> > setups;
|
|
|
+ std::vector<boost::tuple<DuelParameters, CArtifactInstance *, double> > setups;
|
|
|
+
|
|
|
+ std::ofstream desOuts("desiredOuts.dat");
|
|
|
+
|
|
|
for(int i=0; i<dps.size(); ++i)
|
|
|
{
|
|
|
for(int j=0; j<arts.size(); ++j)
|
|
|
{
|
|
|
- setups.push_back(std::make_pair(dps[i], arts[j]));
|
|
|
+ setups.push_back(boost::make_tuple(dps[i], arts[j], rateArt(dps[i], arts[i])));
|
|
|
+ desOuts << (*setups.rbegin()).get<2>() << " ";
|
|
|
}
|
|
|
+ desOuts << std::endl;
|
|
|
}
|
|
|
+
|
|
|
learnSSN(network, setups);
|
|
|
+ network.save("network_config_file.net");
|
|
|
}
|
|
|
|
|
|
int main(int argc, char **argv)
|