// File: FactorizationKL.cc -- lee.seung KL algorithm
// Copyright (C) 2003 Suvrit Sra (suvrit@cs.utexas.edu)

// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

#include <iostream>
#include <gsl/gsl_blas.h>
#include <math.h>

#include "FactorizationKL.h"
#include "objective.h"

/**
 * Peform iterative updates to V, H maximizing the objective function:
 * a log vh - vh --> This is equivalent to minimizing the KL type
 * divergence as described by: a log a/vh - a + vh
 *
 */
int FactorizationKL::perform()
{
  if (getRank() <= 0 or (getRank() > std::min((int)A->size1, (int)A->size2))) {
    std::cerr << "FactorizationKL::perform() Invalid rank: " 
	 << getRank() << std::endl;
    return -1;
  }

  std::cerr << "FactorizationKL::perform()\n";
  char *vhprefix = get_vh_file();
  if (vhprefix)
    init_matrices(Factorization::FRMFILE);
  else 
    init_matrices(Factorization::RANDOM);

  //init_matrices(Factorization::RANDOM);

  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  
  gsl_matrix* T2 = gsl_matrix_calloc(A->size1, A->size2);
  gsl_matrix* T3 = gsl_matrix_alloc(V->size1, V->size2);
  gsl_matrix* T4 = gsl_matrix_alloc(H->size1, H->size2);
  gsl_vector_view colV;
  
  objval = compute_obj(P); // A call to compute obj is impt.

  double objold = objval + 100;
  double error = 1;
  int    iter = 1;

  std::cerr << "Init obj = " << objval << std::endl;
  while (error > epsilon() && iter < getMaxiter()) {
    if (iter % objmodulo == 0)
      DBG("KL: " << iter << " Obj := " << objval << std::endl);

    if (objval > objold) {
      std::cerr << "KL: OBJVAL increased at iteration ==> " << iter << std::endl;
    }
    
    // T2 <== A ./ P
    gsl_matrix_memcpy(T2, A);
    gsl_matrix_div_elements(T2, P);
    

    // T3 = T2*H'
    gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, T2, H, 0.0, T3);

    // V = V .* T3
    gsl_matrix_mul_elements(V, T3); // V = V .* T3

    // Now normalize columns of V to have unit sum
    for (uint j = 0; j < V->size2; j++) {
      colV = gsl_matrix_column(V, j);
      double sum = gsl_blas_dasum(&colV.vector);
      sum = 1.0/sum;
      gsl_blas_dscal(sum, &colV.vector);
    }

    // P = V*H since
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
    // T2 = A ./ P
    gsl_matrix_memcpy(T2, A);	// How to eliminate this???
    gsl_matrix_div_elements(T2, P);

    // T4 = V'T2
    gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, V, T2, 0.0, T4);
    
    // H = H .* T4
    gsl_matrix_mul_elements(H, T4); // H = H .* T4

    // P = V * H
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
    
    if (iter % objmodulo == 0) {
      objold  = objval;
      objval  = compute_obj(P);
      error   = fabs(objval - objold);
    }
    iter++;
  }
  std::cerr << "Final OBJ(A, VH) => " << objval << std::endl;
  gsl_matrix_free(P);
  gsl_matrix_free(T2);
  gsl_matrix_free(T3);
  gsl_matrix_free(T4);
  //cerr << "Relative error => " << fnorm(err)/fnorm(*A) << endl;
  return 0;
}

/**
 * Compute the KL type divergence
 */
double FactorizationKL::compute_obj(const gsl_matrix* prod)
{
  double obj = 0.0;
  double diff = 0;
  
  for (uint i = 0; i < A->size1 * A->size2; i++) {
    double aij = A->data[i];
    if (aij <= 0) 
      diff = 0; 
    else  
      diff = aij*log(aij/prod->data[i]) - aij + prod->data[i];
    
    if (prod->data[i] == 0 and aij != 0) 
      std::cerr << "NASTINESS" << std::endl;
    obj += diff;
  }
  return obj;
}

double FactorizationKL::compute_objonly(char* vhprefix)
{
  set_vh_file(vhprefix);
  init_file();
  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  double val = compute_obj(P);
  std::cerr << "FactorizationKL::objval ==> " << val << std::endl;
  gsl_matrix_free(P);
  return val;
}
