// File: metricKL.cc
// Author: Suvrit Sra
// (c) 2004, Suvrit Sra
// Impl. of KL metric nearness

#include "metricKL.h"

/**
 * Perform fast KL metric nearness
 *
 */
int MetricKL::execute()
{
  std::cout << "MetricKL::execute() Entered\n";

  uint n = d->size1;
  uint N = n * (n - 1)/2;

  uint r = N * (n - 2);

  // Allocate z
  if (alloc_dual_variables() < 0) {
    return -2;
  }
  double zsum = 10;
  uint iter    = 1;
  if (maxiter == 0)
    maxiter = 5*N;

  // Record time here
  clock_t elapse = clock();
  uint t;                     // Current triangle


  while (zsum > tol && iter < maxiter) {
    //std::cout << "Outer iteration: " << iter
    //        << ", zsum = " << zsum << std::endl;
    zsum = 0.0;
    // t index the current triangle ...
    t = 0;
    // Go thru all the triangles in the foll. order
    // Perhaps better to go thru triangles in some random order.
    // 012, 120, 201, 013, 130, 301,...
    // The following order seems very bad. Perhaps a FW type ordering.
    for (uint i = 0; i < n; i++) {
      for (uint j = i+1; j < n; j++) {
        for (uint k = j+1; k < n; k++) {
          fixOneTriangle(i, j, k, t, zsum);
        }
      }
    }
    ++iter;
  }

  elapse = clock() - elapse;
  unsigned long int flops = iter * r * 20;
  double ns = elapse * 1.0 / CLOCKS_PER_SEC;

  double obj = compute_obj();

  // double relerror = sqrt(obj) / fnorm(D);

  std::cout << "MetricKL::execute(): Elapsed time = " 
            << ns   << " secs, "
            << "Iterations:= " << iter
            << ", zsum:= " << zsum << "\n";
  std::cout << "Objective value (INCORRECT) = " << obj 
            << ", FLOPS = " << flops << std::endl;
  //<< ", RelError = " << relerror << "\n";
  //printon(d, std::cout);
  std::cout << "SUMMARY:" 
            << n << " " << ns 
            << " " << obj << std::endl;
  return 0;
}


namespace {

/**
 * Apply one multiplicative KL correction to the edges of a triangle.
 *
 * The step actually taken is theta = min(mu, dual). `grow` is scaled by
 * exp(theta) and the other two edges by exp(-theta). The dual variable is
 * reduced by theta and |theta| is added to zsum.
 */
void klApplyCorrection(double mu, double& dual, double& grow,
                       double& shrinkA, double& shrinkB, double& zsum)
{
  double theta = std::min(mu, dual);
  double y     = exp(theta);
  grow    *= y;
  shrinkA /= y;
  shrinkB /= y;
  dual    -= theta;
  zsum    += fabs(theta);
}

}  // namespace

/**
 * Fix the three directed triangle constraints of triangle (i, j, k) at once.
 *
 * This procedure is different from that in other files because it fixes all
 * three directed triangles at the same time. Every step size
 * mu = 0.5*log(ratio) is computed from the edge lengths on entry, and the
 * resulting multiplicative corrections are then applied in turn. Each mu is
 * the log-scale step that would make that edge equal to the sum of the other
 * two (e.g. ab*y == (bc + ca)/y for y = exp(mu1)). A constraint whose edge is
 * zero is skipped, since its log ratio is undefined.
 *
 * Precondition: i < j < k. Edges are read from and written back to
 * ab = d(i,j), bc = d(j,k), ca = d(i,k).
 *
 * @param t     index of this triangle's first dual variable in z. It is
 *              advanced by 3 on return: one slot per directed constraint,
 *              whether or not that constraint was applied.
 *              NOTE(review): assumes alloc_dual_variables() sized z with 3
 *              entries per triangle (r = N*(n-2) in execute()) -- TODO confirm.
 * @param zsum  accumulates |theta| of every applied correction.
 * @return      the updated zsum.
 */
double MetricKL::fixOneTriangle(uint i, uint j, 
                                uint k, uint& t, double& zsum) 
{
  double ab = gsl_matrix_get( d, i, j );
  double bc = gsl_matrix_get( d, j, k );
  double ca = gsl_matrix_get( d, i, k );

  // Entry values: all three step sizes are based on these (simultaneous fix).
  const double ab0 = ab;
  const double bc0 = bc;
  const double ca0 = ca;

  if (ab != 0)   // constraint ab <= bc + ca
    klApplyCorrection(0.5*log((ca0 + bc0)/ab0), z[t],     ab, bc, ca, zsum);
  if (bc != 0)   // constraint bc <= ca + ab
    klApplyCorrection(0.5*log((ca0 + ab0)/bc0), z[t + 1], bc, ab, ca, zsum);
  if (ca != 0)   // constraint ca <= ab + bc
    klApplyCorrection(0.5*log((ab0 + bc0)/ca0), z[t + 2], ca, bc, ab, zsum);

  // Fixed layout of three dual slots per triangle, so a skipped constraint
  // never shifts later triangles onto the wrong dual variables.
  t += 3;

  // Store all three edges: every correction also rescales the two edges it
  // does not target, so writing back only the targeted edge loses updates.
  gsl_matrix_set( d, i, j, ab );
  gsl_matrix_set( d, j, k, bc );
  gsl_matrix_set( d, i, k, ca );
  return zsum;
}


/**
 * Objective value reported by execute().
 *
 * NOTE(review): placeholder. No terms are accumulated, so this always
 * returns 0.0, which is why execute() labels the printed value "INCORRECT".
 * A real implementation would presumably sum a KL-type divergence between
 * d and the input matrix, but that input is not visible here -- TODO confirm.
 */
inline double MetricKL::compute_obj ()
{
  const double total = 0.0;   // no objective terms accumulated yet
  return total * 0.5;
}
