// File: driver.cc
// Author: Suvrit Sra
// (c) 2004, Suvrit Sra
// Driver code for processing options etc.

#include <iostream>
#include "driver.h"
#include "metric.h"
#include "metricL1.h"
#include "metricL2.h"
#include "metricL2B.h"
#include "metricL1V.h"
#include "metricL2FW.h"
#include "metricKL.h"

// Algorithm selectors accepted by -a / --alg. Driver::execute() switches on
// these values to choose which Metric subclass to instantiate; the
// descriptions mirror the text printed by Driver::showHelp().
#define L1MN 0   // MetricL1:   L_1 error metric nearness
#define L2MN 1   // MetricL2:   L_2 error metric nearness (processOptions' default)
#define L2B  2   // MetricL2B:  L_2 error with Bregman's algorithm
#define L1V  3   // MetricL1V:  L_1 error, vectorized code
#define L2FW 4   // MetricL2FW: L_2 error, Floyd-Warshall order of triangles
#define KL   5   // MetricKL:   KL error, Bregman's algorithm

// Debug helper: streams its argument to std::cerr followed by std::endl
// (newline + flush). Because the argument is pasted unparenthesized, a chain
// of << operands works, e.g. _dbg("Error: Returned code: " << rv).
// NOTE(review): names beginning with '_' are reserved in the global
// namespace; consider renaming this macro.
#define _dbg(x) std::cerr << x << std::endl


// Long-option table handed to getopt_long(). Every entry has flag == 0 and
// val == 0, so getopt_long() returns 0 for any long option and reports which
// one matched through its longindex argument. Driver::processOptions()
// relies on that index, so the ORDER of these entries is significant.
// The last entry is the all-zero terminator getopt_long() requires.
const option Driver::long_opts[9] = {
    { "file",    required_argument, 0, 0 },   // index 0
    { "alg",     required_argument, 0, 0 },   // index 1
    { "max",     required_argument, 0, 0 },   // index 2
    { "tol",     required_argument, 0, 0 },   // index 3
    { "help",    no_argument,       0, 0 },   // index 4
    { "version", no_argument,       0, 0 },   // index 5
    { "scale",   required_argument, 0, 0 },   // index 6
    { "out",     optional_argument, 0, 0 },   // index 7
    { 0,         0,                 0, 0 }    // terminator
};


// Print the usage synopsis and a description of every option to std::cerr.
// Only the program name varies at run time; the remainder is one string
// literal that the compiler assembles from the adjacent pieces below.
void Driver::showHelp()
{
    static const char optionSummary[] =
        ": -f FILE [-a [0-5]] [-h]\n\n"
        "\t-f, --file=FILE      Filename containing non-metric matrix.\n"
        "\t-a, --alg\n"
        "\t\t -a 0 ==> L_1 error metric nearness.\n"
        "\t\t -a 1 ==> L_2 error metric nearness.\n"
        "\t\t -a 2 ==> L_2 error with Bregman's algo.\n"
        "\t\t -a 3 ==> L_1 error, vectorized code\n"
        "\t\t -a 4 ==> L_2 error, Floyd-Warshall order of triangles\n"
        "\t\t -a 5 ==> KL  error, Bregman's algo\n"
        "\t-h, --help           Display this help message\n"
        "\t-m, --max=ITERS      Maximum iterations\n"
        "\t-o, --out=[FILE]     Write output matrix to FILE (default = stdout)\n"
        "\t-t, --tol=TOL        Convergence tolerance\n"
        "\t-s, --scale=SCALE    Scale factor for epsilon in L1 code\n"
        "\t--version            Display version number and exit.\n";

    std::cerr << progname;
    std::cerr << optionSummary;
}


// Parse argc/argv with getopt_long() and store the results in the Driver's
// members (file, algo, maxiter, tol, scale, outfile, output_matrix).
//
// Each long option is translated to its single-letter equivalent, so both
// spellings share one handler. Previously the two paths were duplicated and
// had drifted apart: --help returned 1, which execute() treated as "carry
// on", while -h returned -1.
//
// Returns:
//    0  options parsed and a non-empty -f/--file was given;
//   -1  help was requested, or no input file was supplied (help printed);
//   -2  --version was requested (version string printed).
int Driver::processOptions()
{
    // Code used internally for --version, which has no short form. Short
    // options come back from getopt_long() as character values, so 256
    // cannot collide with any of them.
    const int kVersionOpt = 256;

    // Short-option equivalent of each long_opts entry, indexed by the
    // option_idx reported by getopt_long(). Must follow long_opts' order.
    const int longToShort[] = {
        'f',            // --file
        'a',            // --alg
        'm',            // --max
        't',            // --tol
        'h',            // --help
        kVersionOpt,    // --version
        's',            // --scale
        'o'             // --out
    };
    const int numLongOpts =
        static_cast<int>(sizeof(longToShort) / sizeof(longToShort[0]));

    bool cmdok = false;
    algo = L2MN;                // default algo

    for (;;) {
        int option_idx = 0;
        int c = getopt_long(argc, argv, opt_string, long_opts, &option_idx);
        if (c == -1)
            break;              // all options consumed

        if (c == 0) {
            // Every long_opts entry has flag == 0 and val == 0, so a long
            // option is reported as 0 together with its table position.
            if (option_idx < 0 || option_idx >= numLongOpts) {
                std::cerr << "What the hell??" << std::endl;
                continue;
            }
            c = longToShort[option_idx];
        }

        // Options that take an argument ignore a missing or empty one.
        const bool haveArg = optarg and *optarg != '\0';

        switch (c) {
        case 'f':
            if (haveArg) {
                cmdok = true;
                // NOTE(review): unbounded copy into `file`. If `file` is a
                // fixed-size buffer, a long path overflows it; bound the copy
                // once the member's declared size is confirmed.
                strcpy(file, optarg);
            }
            break;
        case 'a':
            if (haveArg)
                algo = atoi(optarg);
            break;
        case 'm':
            if (haveArg)
                maxiter = atoi(optarg);
            break;
        case 't':
            if (haveArg)
                tol = strtod(optarg, 0);
            break;
        case 's':
            if (haveArg)
                scale = strtod(optarg, 0);
            break;
        case 'o':
            // FILE is optional; matrix output is requested either way.
            if (haveArg)
                outfile = strdup(optarg);
            output_matrix = true;
            break;
        case 'h':
            showHelp();
            return -1;          // same code for -h and --help
        case kVersionOpt:
            std::cout << version;
            return -2;
        default:
            // Unrecognised options are ignored; getopt_long() prints its own
            // diagnostic while opterr is non-zero.
            break;
        }
    }

    if (!cmdok) {
        showHelp();
        return -1;
    }

    return 0;
}


int Driver::execute()
{
  int r = processOptions();

  if (r < 0)
    return r;

    Metric* m = 0;

    switch (algo) {
    case L1MN:
      m = new MetricL1();
      ((MetricL1*)m)->setScaleEps(scale);
      break;
    case L2MN:
      m = new MetricL2();
      break;
    case L2B:
      m = new MetricL2B();
      break;
    case L1V:
      m = new MetricL1V();
      ((MetricL1V*)m)->setScaleEps(scale);
      break;
    case L2FW:
      m = new MetricL2FW();
      break;
    case KL:
      m = new MetricKL();
      break;
    default:
      m = 0;
      break;
    }

    if (m == 0)
      return -1;

    gsl_matrix* d;

    

    if ( (d = read_gsl_matrix(file)) == 0) {
        std::cerr << progname << " error reading matrix\n";
        return -1;
    }  

    m->set_matrix(d);
    // Copies d into D
    m->make_copy();

    std::cout << "Loaded data matrix. " 
              << d->size1 << " x " 
              << d->size2 << "\n";

    m->set_maxiter(maxiter);
    m->set_tol(tol);

    int rv = m->execute();               // Dispatch will happen automatically.
    if (rv < 0) {
      _dbg("Error: Returned code: " << rv);
      return rv;
    } else {
      if (output_matrix) {
        if (outfile) {
          fp = fopen(outfile, "w");
          if (!fp) { 
            _dbg("Could not open: " << outfile << " for writing");
            fp = stdout; 
          }
        }
        for (uint i = 0; i < d->size1; i++) {
          for (uint j = 0; j < d->size2; j++) 
            fprintf(fp, "%f ", gsl_matrix_get(d, i, j));
          fprintf(fp, "\n");
        }
      }
      return rv;
    }
}
