// File: FactorizationSKL.cc -- lee.seung KL for sparse
// Copyright (C) 2003 Suvrit Sra (suvrit@cs.utexas.edu)

// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

#include <iostream>
#include <gsl/gsl_blas.h>
#include <math.h>

#include "FactorizationSKL.h"
#include "objective.h"

/**
 * Peform iterative updates to V, H maximizing the objective function:
 * a log vh - vh --> This is equivalent to minimizing the KL type
 * divergence as described by: a log a/vh - a + vh
 *
 * NOTE: This perform is tuned to work for a sparse A called sA
 */
int FactorizationSKL::perform()
{
  if (getRank() <= 0 or getRank() > sA->numCols()) {
    std::cerr << "FactorizationKL::perform() Invalid rank: "  << getRank() << std::endl;
    return -1;
  }

  std::cerr << "FactorizationKL::perform()\n";

  char *vhprefix = get_vh_file();
  if (vhprefix)
    init_matrices(Factorization::FRMFILE);
  else 
    init_matrices(Factorization::RANDOM);

  //init_matrices(Factorization::RANDOM);	// uses default seed in m_seed

  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);

  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  
  gsl_matrix* T3 = gsl_matrix_alloc(V->size1, V->size2);
  gsl_matrix* T4 = gsl_matrix_alloc(H->size1, H->size2);
  gsl_vector_view colV;

  SparseMatrix* T2 = sA->clone();

  objval = sA->compute_obj(P, OBJ_KL); // A call to compute obj is impt.

  double objold = objval + 10;
  double error = 1;
  int    iter = 1;

  std::cerr << "Init Objective := " << objval << std::endl;
  
  while (error > epsilon() && iter < getMaxiter()) {
    if (iter % objmodulo == 0)
      DBG("KL: " << iter << " Obj := " << objval << std::endl);
    
    if (objval > objold) {
      std::cerr << "KL: OBJVAL increased. Something's wrong\n";
    }

    sA->dotdiv(P, T2);		   // T2 = A ./ P
    T2->lmult(H, T3, false, true); // T3 = T2 * H'
    gsl_matrix_mul_elements(V, T3); // V = V .* T3

    // Now normalize columns of V to have unit sum
    for (uint j = 0; j < V->size2; j++) {
      colV = gsl_matrix_column(V, j);
      double sum = gsl_blas_dasum(&colV.vector);
      sum = 1.0/sum;
      gsl_blas_dscal(sum, &colV.vector);
    }

    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);

    sA->dotdiv(P, T2);		     // T2 = A ./ P
    T2->rmult(V, T4, false, true);   // T4 = V' * T2
    gsl_matrix_mul_elements(H, T4);  // H = H .* T4

    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
    
    if (iter % objmodulo == 0) {
      objold = objval;
      objval = sA->compute_obj(P, OBJ_KL);
      error  = fabs(objval - objold);
    }
    iter++;
  }

  std::cerr << "Final OBJ(A, VH) => " << objval << std::endl;
  gsl_matrix_free(P);
  gsl_matrix_free(T3);
  gsl_matrix_free(T4);
  //cerr << "Relative error => " << fnorm()/sA->fnorm() << endl;
  return 0;
}

double FactorizationSKL::compute_objonly(char* vhprefix)
{
  set_vh_file(vhprefix);
  init_file();
  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  double val = sA->compute_obj(P, OBJ_KL);
  std::cerr << "FactorizationSKL::objval ==> " << val << std::endl;
  gsl_matrix_free(P);
  return val;
}
