// File: FactorizationSFrob.cc -- lee.seung frob for sparse
// Copyright (C) 2003 Suvrit Sra (suvrit@cs.utexas.edu)

// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

#include <iostream>
#include <gsl/gsl_blas.h>
#include "FactorizationSFrob.h"
#include "objective.h"

#include <math.h>

// This function needs to be rewritten
int FactorizationSFrob::perform()
{
  char *vhprefix = get_vh_file();
  if (vhprefix)
    init_matrices(Factorization::FRMFILE);
  else 
    init_matrices(Factorization::RANDOM);

  //init_matrices(Factorization::RANDOM);	// uses seed in m_seed

  gsl_matrix* T1, *T2, *T3, *T4, *T5, *T6;
  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);


  T1 = gsl_matrix_alloc(V->size2, H->size2);
  T2 = gsl_matrix_alloc(V->size2, V->size2);
  T3 = gsl_matrix_alloc(V->size2, H->size2);
  T4 = gsl_matrix_alloc(V->size1, H->size1);
  T5 = gsl_matrix_alloc(H->size1, H->size1);
  T6 = gsl_matrix_alloc(V->size1, H->size1);

  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  //printmat(P);

  std::cerr << "Computing Initial objective value (slow....)\n";
  objval = sA->compute_obj(P, OBJ_FROB);

  // The initial objective function might be reported to be too small
  // because V * H does not have same sparsity pattern as A, to be more
  // accurate, I could try to esp. report objval for the first time around
  // ... 
  int    iter   = 0;
  double error  = 100;
  double objold = objval+1;

  // With small 'r' the lee/seung algo and in fact even NNLS give an
  // increase in the objval after the first iteration...how is that
  // possible? 
  while (error > epsilon() && iter < getMaxiter()) {
    if (iter % objmodulo == 0)
      DBG("FROB: " << iter << ": " << objval << std::endl);
    
    if (objval > objold) {
      std::cerr << "WARNING: OBJVAL Increased! Possible Bug in code" << std::endl;
    }

    // Computations for updating H
    sA->rmult(V, T1, false, true); // T1 = V'A

    if (debug) {
      printmat(V);
      sA->printmat(std::cerr);
      printmat(T1);
    }
    gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, V, V, 0.0, T2);    // T2 = V'V
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, T2, H, 0.0, T3); // T3 = (V'V)H

    // H = H .* T1 ./ T3
    gsl_matrix_mul_elements(H, T1);

    if (perturb)
      // Perturb T3 a little to prevent divide by zeros ...
      gsl_matrix_add_constant(T3, 10e-12);

    gsl_matrix_div_elements(H, T3);

    //if (iter == 0) printmat(H);
    // Computations for updating V
    sA->lmult(H, T4, false, true);  // T4 = AH'
    gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, H, H, 0.0, T5);    // T5 = HH'
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, T5, 0.0, T6); // T6 = V(HH')

    // V = V .* T4 ./ T6
    gsl_matrix_mul_elements(V, T4);
    if (perturb)
      // Perturb T6 a little to prevent divide by zeros ...
      gsl_matrix_add_constant(T6, 10e-12);

    gsl_matrix_div_elements(V, T6);


    // Recompute the objective function
    if (iter % objmodulo == 0) {
      gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
      objold  = objval;
      objval = sA->compute_obj(P, OBJ_FROB);
      error = fabs(objval - objold);
    }

    iter++;
  }
  
  if (normalize) {
    std::cerr << "Normalizing V, H\n";
    normalizeVH();
  }

  gsl_matrix_free(T1);
  gsl_matrix_free(T2);
  gsl_matrix_free(T3);
  gsl_matrix_free(T4);
  gsl_matrix_free(T5);
  gsl_matrix_free(T6);
  gsl_matrix_free(P);
  //printmat(P);
  relerror = objval/ sA->fnorm();
  std::cerr << "FactorizationSFrob::perform(): Relative error => " << objval/sA->fnorm() << std::endl;
  return 0;
}

double FactorizationSFrob::compute_objonly(char* vhprefix)
{
  set_vh_file(vhprefix);
  init_file();
  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  double val = sA->compute_obj(P, OBJ_FROB);
  std::cerr << "FactorizationSFrob::objval ==> " << val << std::endl;
  gsl_matrix_free(P);
  return val;
}
