// FactorizationAKL.cc -- implement AKL algorithm
// Copyright (C) 2003 Suvrit Sra (suvrit@cs.utexas.edu)

// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.


#include <iostream>
#include <gsl/gsl_blas.h>
#include <math.h>

#include "FactorizationAKL.h"
#include "objective.h"


int FactorizationAKL::perform()
{
  if (getRank() <= 0 or (getRank() > std::min((int)A->size1, (int)A->size2))) {
    std::cerr << "FactorizationAKL::perform() Invalid rank: " 
	 << getRank() << std::endl;
    return -1;
  }
  
  DBG("FactorizationAKL::perform()\n");
  DBG("UNDER CONSTRUCTION\n");

  char *vhprefix = get_vh_file();
  if (vhprefix)
    init_matrices(Factorization::FRMFILE);
  else 
    init_matrices(Factorization::RANDOM);

  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  
  objval = compute_obj(P); // A call to compute obj is impt.

  double objold = objval + 100;
  double error = 1;
  int    iter = 0;

  //DBG("Init obj = " << objval << endl);
  while (error > epsilon() && iter < getMaxiter()) {
    if (iter % objmodulo == 0)
      DBG("AKL: " << iter << " Obj := " << objval << std::endl);

    if (objval > objold) {
      DBG("AKL: OBJVAL increased at iteration ==> " << iter << std::endl);
    }
   
    // Fix V and compute H so that we min D(A || VH)
    //mykl(A, V, H, arg1);

    // Fix H and compute V so that we min D(A || VH)
    // mykl(A, V, H, arg2)
    
    if (iter % objmodulo == 0) {
      objold  = objval;
      objval  = compute_obj(P);
      error   = fabs(objval - objold);
    }
    ++iter;
  }
  set_relerror(10);		// Just A PLACEHOLDER!!!!!!!!!!!!!!!
  return 0;
}


/**
 * Compute the KL type divergence
 */
double FactorizationAKL::compute_obj(const gsl_matrix* prod)
{
  double obj = 0.0;
  double diff = 0;
  
  for (uint i = 0; i < A->size1 * A->size2; i++) {
    double aij = A->data[i];
    if (aij <= 0) 
      diff = 0; 
    else  
      diff = aij*log(aij/prod->data[i]) - aij + prod->data[i];
    
    if (prod->data[i] == 0 and aij != 0) 
      std::cerr << "NASTINESS" << std::endl;
    obj += diff;
  }
  return obj;
}

double FactorizationAKL::compute_objonly(char* vhprefix)
{
  set_vh_file(vhprefix);
  init_file();
  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  double val = compute_obj(P);
  std::cerr << "FactorizationAKL::objval ==> " << val << std::endl;
  gsl_matrix_free(P);
  return val;
}
