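// FactorizationFrobSlow.cc -- Implements the Lee-Seung multiplicative-update
// NMF algorithm for the Frobenius-norm objective, using BLAS level-1/2
// (vector) updates instead of the BLAS-3 (matrix) updates in factorizationfrob.cc.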
// Copyright (C) 2003 Suvrit Sra (suvrit@cs.utexas.edu)

// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

#include <iostream>
#include <gsl/gsl_vector.h>
#include <gsl/gsl_matrix.h>

#include <gsl/gsl_blas.h>
#include "FactorizationFrobSlow.h"
#include <math.h>

// Run Lee-Seung multiplicative updates for the objective ||A - V H||_F.
// The BLAS-3 matrix updates are replaced by BLAS-2 matrix-vector products.
// Each sweep visits a = 0 .. V->size2-1 and does two things:
//   - rescales row a of H by (V(:,a)' A) ./ (V(:,a)' V H);
//   - then rescales column a of V by (A H(a,:)') ./ (V H H(a,:)').
//
// V and H come from file when get_vh_file() is non-NULL, otherwise from
// random initialisation (both via init_matrices()).
//
// The objective is recomputed every `objmodulo` sweeps. Iteration stops when
// the change in the objective is <= epsilon() or after getMaxiter() sweeps.
// On exit, `objval` holds the last computed objective and `relerror` is
// objval / fnorm(A).
//
// When `perturb` is set, 1e-12 is added to both update denominators to avoid
// division by zero.
//
// Returns 0.
int FactorizationFrobSlow::perform()
{
  char *vhprefix = get_vh_file();
  if (vhprefix)
    init_matrices(Factorization::FRMFILE);
  else
    init_matrices(Factorization::RANDOM);

  // Scratch buffer for the full product V*H, used only for the objective.
  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);

  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  objval = compute_obj(P);

  int    iter   = 0;
  double error  = 1;
  double objold = objval + 10;  // start above objval so the first check cannot fire

  // Work vectors, allocated once and reused in every sweep.
  // Sizes follow the BLAS calls below (V is m x k, H is k x n, A is m x n).
  gsl_vector* t1    = gsl_vector_alloc(V->size1);  // V(:,a)                 (m)
  gsl_vector* t2    = gsl_vector_alloc(H->size2);  // H' V' V(:,a)           (n)
  gsl_vector* t3    = gsl_vector_alloc(V->size1);  // A H(a,:)'              (m)
  gsl_vector* t4    = gsl_vector_alloc(V->size1);  // V H H(a,:)'            (m)
  gsl_vector* temp  = gsl_vector_alloc(H->size1);  // V' V(:,a)              (k)
  gsl_vector* temp2 = gsl_vector_alloc(V->size2);  // H H(a,:)'              (k)
  gsl_vector* t1p   = gsl_vector_alloc(H->size2);  // A' V(:,a)              (n)

  while (error > epsilon() && iter < getMaxiter()) {
    if (iter % objmodulo == 0) {
      DBG("FROB: " << iter << ": " << objval << std::endl);
    }

    if (objval > objold) {
      // Multiplicative updates should never increase the objective.
      // The early return is deliberately disabled: relerror is recorded
      // and iteration continues despite the message text.
      std::cerr << "OBJVAL Increased! Quitting! Bug in code" << std::endl;
      relerror = objval/fnorm(A);
      //return 0;
    }

    // Do one row/col at a time updates...
    for (size_t a = 0; a < V->size2; a++) {
      // ---- Update row a of H: H(a,:) .*= (V(:,a)' A) ./ (V(:,a)' V H)
      gsl_matrix_get_col(t1, V, a);                          // t1   = V(:,a)
      gsl_blas_dgemv(CblasTrans, 1.0, V, t1, 0.0, temp);     // temp = V' t1
      gsl_blas_dgemv(CblasTrans, 1.0, H, temp, 0.0, t2);     // t2   = H' temp
      gsl_blas_dgemv(CblasTrans, 1.0, A, t1, 0.0, t1p);      // t1p  = A' t1

      gsl_vector_view rowaH = gsl_matrix_row(H, a);
      gsl_vector_mul(&rowaH.vector, t1p);
      if (perturb)
        gsl_vector_add_constant(t2, 1e-12);   // keep the denominator away from 0
      gsl_vector_div(&rowaH.vector, t2);

      // ---- Update column a of V: V(:,a) .*= (A H(a,:)') ./ (V H H(a,:)')
      // Uses the freshly updated row a of H.
      gsl_blas_dgemv(CblasNoTrans, 1.0, A, &rowaH.vector, 0.0, t3);     // t3    = A H(a,:)'
      gsl_blas_dgemv(CblasNoTrans, 1.0, H, &rowaH.vector, 0.0, temp2);  // temp2 = H H(a,:)'
      gsl_blas_dgemv(CblasNoTrans, 1.0, V, temp2, 0.0, t4);             // t4    = V temp2

      gsl_vector_view colaV = gsl_matrix_column(V, a);
      gsl_vector_mul(&colaV.vector, t3);
      if (perturb)
        gsl_vector_add_constant(t4, 1e-12);   // same denominator guard as the H update
      gsl_vector_div(&colaV.vector, t4);
    }

    // Recompute the objective only every `objmodulo` sweeps, to save the cost
    // of forming the full product V*H.
    if (iter % objmodulo == 0) {
      gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);   // P = V H
      objold = objval;
      objval = compute_obj(P);
      error  = fabs(objval - objold);
    }
    iter++;
  }

  gsl_vector_free(t1);
  gsl_vector_free(t1p);
  gsl_vector_free(t2);
  gsl_vector_free(t3);
  gsl_vector_free(t4);
  gsl_vector_free(temp);
  gsl_vector_free(temp2);
  gsl_matrix_free(P);

  relerror = objval/fnorm(A);	// In super class...
  std::cerr << "Relative error => " << relerror << std::endl;

  return 0;
}



// Hope now it gets computed faster...
double FactorizationFrobSlow::compute_obj(const gsl_matrix* prod)
{
  double obj = 0.0;
  for (uint i = 0; i < A->size1 * A->size2; i++) {
    double diff = A->data[i] - prod->data[i];
    obj += (diff*diff);
  }
  return sqrt(obj);
}

double FactorizationFrobSlow::compute_objonly(char* vhprefix)
{
  set_vh_file(vhprefix);
  init_file();
  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  double val = compute_obj(P);
  std::cerr << "FactorizationFrobSlow::objval ==> " << val << std::endl;
  gsl_matrix_free(P);
  return val;
}
