// FactorizationHybrid.cc (mix of NNLS and Frob)
// Copyright (C) 2003 Suvrit Sra (suvrit@cs.utexas.edu)

// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

#include <algorithm>
#include <iostream>
#include <vector>
#include <math.h>
#include <gsl/gsl_blas.h>
#include "FactorizationHybrid.h"
#include "nnls.h"

// In this hybrid scheme we first run a few NNLS iterations (how many is controlled from the
// command line via nnlsiter) and then continue with multiplicative Frobenius updates.
// Hybrid driver.
// 1. Run one NNLS phase, which also initialises V and H.
// 2. Run one multiplicative (Frobenius) phase.
// 3. Keep alternating NNLS and multiplicative phases until the objective after
//    an NNLS phase and the objective after the following multiplicative phase
//    differ by no more than epsilon().
// Always returns 0.
// NOTE(review): the alternation loop has no iteration cap; if the two phases
// keep disagreeing by more than epsilon() it will not terminate.
int FactorizationHybrid::perform()
{
  std::cerr << "NNLSITER ==>\t" << nnlsiter << std::endl;

  // Scratch buffer for the product VH, shared by performnnls/performfrob.
  P = gsl_matrix_alloc(V->size1, H->size2);

  performnnls(0);           // 0 => initialise V and H first
  double oldobj = objval;
  performfrob(1);

  while (fabs(oldobj - objval) > epsilon()) {
    performnnls(1);
    oldobj = objval;
    performfrob(1);
  }

  gsl_matrix_free(P);
  P = 0;  // don't leave a dangling pointer in the member after freeing it

  return 0;
}

// Copied from FactorizationNNLS::perform().
double FactorizationHybrid::performnnls(int i)
{
  if (i == 0) {
    char *vhprefix = get_vh_file();
    if (vhprefix)
      init_matrices(Factorization::FRMFILE);
    else 
      init_matrices(Factorization::RANDOM);
  }
  // Initialize V, H and compute init objective value
  //init_matrices(Factorization::RANDOM);
  
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  objval = compute_obj(P);

  // Short hand for commonly used quantities.
  int m = A->size1;
  int r = getRank();
  int n = A->size2;

  int iter = 1;			// Iteration number
  double error = 100;		// Delta(objval)
  double objold = objval+10;;

  double* aj  = new double[m];	// Hold a col of A
  double* atj = new double[n];  // hold a col of A^t
  double* vtj = new double[r];  // hold a col of V'
  double rn;
  double* w   = new double[std::max(m,n)];
  double* zz  = new double[std::max(m,n)];
  int* idx    = new int   [std::max(m,n)];
  double* ht  = new double[r*n]; // H' holder
  double* cdv = new double[m*r]; // copy of dv

  while (error > epsilon() && iter < nnlsiter) {
    if (iter % objmodulo == 0)
      DBG("NNLS: " << iter << ": " << objval << std::endl);

    // Fix V and compute H so that we min ||VH - A||_F
    for (int j = 0; j < n; j++) {
      for (int i = 0; i < m; i++) {
	aj[i] = gsl_matrix_get(A, i,j);
      }
      int flg;
      // Now dv will get modified so let us make a copy or sth.
      
      // Here things get messy since V is row major, have to unlay it out in col major or nnls call
      int counter = 0;
      for (int k = 0; k < r; k++) {
	for (int i = 0; i < m; i++) {
	  cdv[counter++] = gsl_matrix_get(V, i, k);
	}
      }
      //for (int i = 0; i < m*r; i++)
      //cdv[i] = V->data[i];
      // For now hold result in ht and then trf. it to H_j column
      nnls(cdv, m, m, r, aj, ht, &rn, w, zz, idx, &flg);
      for (int i = 0; i < r; i++)
	gsl_matrix_set(H, i, j, ht[i]);

      //cerr << rn << " ";
      if (flg != 1) {
	std::cerr << "NNLS: nnls failed to find solution\n";
      }
    }
    //int k = 0;
    int counter = 0;
    for (int k = 0; k < r*n; k++)
      ht[counter++] = H->data[k];
    
    for (int j = 0; j < m; j++) {
      // Now jth column of A' is just jth row of A
      for (int i = 0; i < n; i++) {
	atj[i] = gsl_matrix_get(A, j,i);
      }
      
      int flg;
      nnls(ht, n, n, r, atj, vtj, &rn, w, zz, idx, &flg);
      if (flg != 1) {
	std::cerr << "NNLS: nnls failed to find solution\n";
      }
      counter = 0;
      for (int k = 0; k < r*n; k++)
      	ht[counter++] = H->data[k];
      
      // Copy into V'
      for (int i = 0; i < r; i++)
	gsl_matrix_set(V, j, i, vtj[i]);
    }

    if (iter % objmodulo == 0) {
      gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
      objold = objval;
      objval = compute_obj(P);
      error = fabs(objval - objold);
    }
    if (objval > objold) {
      std::cerr << "NNlS: Objval increased\n";
    }
    iter++;
  }
  //DBG("NNLS: objval == " << objval);
  double retval = objval/ fnorm(A);
  relerror = retval;
  std::cerr << "NNLS: Complete. Relative error ==> " << retval << std::endl;

  // Let us free some memory!!!
  delete[] aj;
  delete[] atj;
  delete[] vtj;
  delete[] w;
  delete[] zz;
  delete[] idx;
  delete[] ht;
  delete[] cdv;
  return retval;
}


// This function is called after performnnls(), so it does not need to initialise V and H.
double FactorizationHybrid::performfrob(int i)
{
  
  std::cerr << "FactorizationHybrid::performfrob(" << i << ")\n";
  gsl_matrix* T1, *T2, *T3, *T4, *T5, *T6;
  
  T1 = gsl_matrix_alloc(V->size2, H->size2);
  T2 = gsl_matrix_alloc(V->size2, V->size2);
  T3 = gsl_matrix_alloc(V->size2, H->size2);
  T4 = gsl_matrix_alloc(V->size1, H->size1);
  T5 = gsl_matrix_alloc(H->size1, H->size1);
  T6 = gsl_matrix_alloc(V->size1, H->size1);

  // Just for correct debugging messages we shd compute obj at least once
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  objval = compute_obj(P);
  
  int    iter = 0;
  double error = 1;
  double objold = objval+10;	// objval has at least some value

  while (error > epsilon() && iter < frobiter) {
    if (iter % objmodulo == 0)
      DBG("FROB: " << iter << ": " << objval << std::endl);
    
    if (objval > objold && verbose)  {
      std::cerr << "WARNING: OBJVAL Increased!" << std::endl;
      //return -1;
    }

    
    // Computations for iterative update of H
    gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, V, A, 0.0, T1);    // V'A
    gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, V, V, 0.0, T2);    // V'V
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, T2, H, 0.0, T3); // (V'V)H

    // H = H .* T1 ./ T3
    gsl_matrix_mul_elements(H, T1);
    if (perturb)
      // Perturb T3 a little to prevent divide by zeros ...
      gsl_matrix_add_constant(T3, 10e-12);
    gsl_matrix_div_elements(H, T3);
    //if (iter == 0) printmat(H);

    // Computations for iterative update of V
    gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, A, H, 0.0, T4);    // AH'
    //if (iter == 0) printmat(T4);
    gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, H, H, 0.0, T5);    // HH'
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, T5, 0.0, T6); // V(HH')

    // V = V .* T4 ./ T6
    gsl_matrix_mul_elements(V, T4);
    if (perturb)
      // Perturb T6 a little to prevent div. by zero..
      gsl_matrix_add_constant(T6, 10e-12);
    gsl_matrix_div_elements(V, T6);
    
    // Compute the objective function.
    // NOTE: Have to make this block conditional so that we can save obj function
    // computations.
    if (iter % objmodulo == 0) {
      gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);   // VH
      objold  = objval;
      objval = compute_obj(P);
      error = fabs(objval - objold);
    }
    iter++;
  }

  gsl_matrix_free(T1);
  gsl_matrix_free(T2);
  gsl_matrix_free(T3);
  gsl_matrix_free(T4);
  gsl_matrix_free(T5);
  gsl_matrix_free(T6);
  //  DBG("Objval = " << objval << std::endl);
  double retval = objval/fnorm(A);
  relerror = retval;
  std::cerr << "Relative error => " << retval << std::endl;
  return retval;
}


// Frobenius-norm objective ||A - prod||_F, computed directly on the matrix storage for speed.
double FactorizationHybrid::compute_obj(const gsl_matrix* prod)
{
  double obj = 0.0;
  for (uint i = 0; i < A->size1 * A->size2; i++) {
    double diff = A->data[i] - prod->data[i];
    obj += (diff*diff);
  }
  return sqrt(obj);
}

double FactorizationHybrid::compute_objonly(char* vhprefix)
{
  set_vh_file(vhprefix);
  init_file();
  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  double val = compute_obj(P);
  std::cerr << "FactorizationFrob::objval ==> " << val << std::endl;
  gsl_matrix_free(P);
  return val;
}
