// Program description handed to CmdLine::info() in main() and shown with the
// command-line usage. Kept as adjacent literals, which the compiler joins into
// one string identical to the original spliced literal.
const char *help =
  "BAG (c) Trebolloc & Co 2001\n"
  "\n"
  "This program will train a bagging of MLPs with tanh outputs for\n"
  "classification and linear outputs for regression\n";

#include <cstdio>
#include <cstring>

#include "FileDataSet.h"
#include "MseCriterion.h"
#include "Tanh.h"
#include "MseMeasurer.h"
#include "ClassMeasurer.h"
#include "TwoClassFormat.h"
#include "StochasticGradient.h"
#include "GMTrainer.h"
#include "CmdLine.h"
#include "MLP.h"
#include "WeightedSumMachine.h"
#include "Bagging.h"

using namespace Torch;

int main(int argc, char **argv)
{
  char *train_file;
  char *test_file;
  int n_inputs;
  int n_targets;
  int n_hu;
  int n_bag;
  int max_load_train;
  int max_load_test;
  int seed_value;
  real accuracy;
  real learning_rate;
  real decay;
  int max_iter;
  bool regression;
  char *dir_name;
  char *load_model;
  char *save_model;

  bool norm_out;

  //=================== The command-line ==========================

  CmdLine cmd;

  cmd.info(help);

  cmd.addText("\nArguments:");
  cmd.addICmdArg("n_inputs", &n_inputs, "input dimension of the data");
  cmd.addICmdArg("n_targets", &n_targets, "output dimension of the data");
  cmd.addSCmdArg("train_file", &train_file, "the train file");
  cmd.addSCmdArg("test_file", &test_file, "the test file");

  cmd.addText("\nModel Options:");
  cmd.addICmdOption("-nhu", &n_hu, 25, "number of hidden units");
  cmd.addICmdOption("-nbag_max", &n_bag, 10, "maximum number of MLP in the bag");
  cmd.addBCmdOption("-rm", &regression, false, "regression mode");

  cmd.addText("\nLearning Options:");
  cmd.addICmdOption("-iter", &max_iter, 25, "max number of iterations");
  cmd.addRCmdOption("-lr", &learning_rate, 0.01, "learning rate");
  cmd.addRCmdOption("-e", &accuracy, 0.00001, "end accuracy");
  cmd.addRCmdOption("-lrd", &decay, 0, "learning rate decay");

  cmd.addText("\nMisc Options:");
  cmd.addICmdOption("-load_train", &max_load_train, -1, "max number of train examples to load");
  cmd.addICmdOption("-load_test", &max_load_test, -1, "max number of test examples to load");
  cmd.addICmdOption("-seed", &seed_value, -1, "initial seed for random generator");
  cmd.addSCmdOption("-dir", &dir_name, ".", "directory to save measures");
  cmd.addSCmdOption("-lm", &load_model, "", "start from given model file");
  cmd.addSCmdOption("-sm", &save_model, "", "save results into given model file");

  cmd.addBCmdOption("-norm_out", &norm_out, false, "normalize targets");

  cmd.read(argc, argv);

  if (seed_value == -1)
    seed();
  else
    manual_seed((long)seed_value);

  //=================== DataSets ===================

  FileDataSet data(train_file, n_inputs, n_targets, false, max_load_train);
  data.setBOption("normalize inputs", true);
  if (norm_out)
    data.setBOption("normalize targets", true);
  data.init();

  FileDataSet test_data(test_file, n_inputs, n_targets, false, max_load_test);
  test_data.init();
  test_data.normalizeUsingDataSet(&data);

  // how is the class encoded in the datasets: TwoClassFormat
  TwoClassFormat* class_format=NULL;
  if (!regression)
    class_format = new TwoClassFormat(&data);

  //=================== The Model ==================
  // there will be "n_bag" MLPs created for the Bagging. For each
  // of these MLP, we want to train them, using a MSE criterion and a 
  // Stochastic gradient optimizer, given to a Trainer.

  MLP **mlp = new MLP *[n_bag];
  MseCriterion **mse = new MseCriterion *[n_bag];
  StochasticGradient **opt = new StochasticGradient *[n_bag];
  Trainer **trainer = new Trainer *[n_bag];
  for (int i=0;i<n_bag;i++) {
    mlp[i] = new MLP(n_inputs,n_hu,n_targets);
    mlp[i]->setBOption("tanh outputs",!regression);
    mlp[i]->init();

    mse[i] = new MseCriterion(n_targets);
    mse[i]->init();

    opt[i] = new StochasticGradient();
    opt[i]->setIOption("max iter", max_iter);
    opt[i]->setROption("end accuracy", accuracy);
    opt[i]->setROption("learning rate", learning_rate);
    opt[i]->setROption("learning rate decay", decay);

    trainer[i] = (Trainer*)new GMTrainer(mlp[i], &data, mse[i], opt[i]);
  }

  WeightedSumMachine bagmachine(trainer,n_bag,NULL);
  bagmachine.init();

  // We also want to measure the performance of the Bagging itself

  List *measurers = NULL;

  char bag_mse_name[100];
  sprintf(bag_mse_name,"%s/the_bag_mse%d",dir_name,n_hu);
  MseMeasurer msebag(bagmachine.outputs,&data, bag_mse_name);
  msebag.init();
  addToList(&measurers,1,&msebag);

  char bag_test_mse_name[100];
  sprintf(bag_test_mse_name,"%s/test_bag_mse%d",dir_name,n_hu);
  MseMeasurer msebag_test(bagmachine.outputs,&test_data, bag_test_mse_name);
  msebag_test.init();
  addToList(&measurers,1,&msebag_test);

  char bag_class_name[100];
  sprintf(bag_class_name,"%s/the_bag_class_err%d",dir_name,n_hu);
  ClassMeasurer classbag(bagmachine.outputs,&data, class_format,bag_class_name);
  classbag.init();
  if(!regression && (n_targets == 1))
    addToList(&measurers,1,&classbag);

  char bag_test_class_name[100];
  sprintf(bag_test_class_name,"%s/test_bag_class_err%d",dir_name,n_hu);
  ClassMeasurer classbag_test(bagmachine.outputs,&test_data, class_format,bag_test_class_name);
  classbag_test.init();
  if(!regression && (n_targets == 1))
    addToList(&measurers,1,&classbag_test);

  Bagging bagging(&bagmachine,&data);


  // =========== Training and/or Testing ===========

  if (strcmp(load_model,"")) {
    char load_model_name[100];
    sprintf(load_model_name,"%s/%s",dir_name,load_model);
    bagging.load(load_model_name);
    for (int i=0;i<n_bag;i++) {
      bagmachine.n_trainers = i;
      bagging.n_trainers = i;
      bagging.test(measurers);
    }
  } else {
    bagging.train(measurers);
  }
  if (strcmp(save_model,"")) {
    char save_model_name[100];
    sprintf(save_model_name,"%s/%s",dir_name,save_model);
    bagging.save(save_model_name);
  }

  //=================== The End =====================

  for (int i=0;i<n_bag;i++) {
    delete mlp[i];
    delete mse[i];
    delete trainer[i];
    delete opt[i];
  }

  if (!regression)
    delete class_format;

  delete[] mlp;
  delete[] mse;
  delete[] trainer;
  delete[] opt;

  freeList(&measurers);

  return(0);
}
