// FactorizationHybrid2.cc (mix of NNLS and FrobSlow)
// Copyright (C) 2003-2004 Suvrit Sra (suvrit@cs.utexas.edu)

// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>
#include <gsl/gsl_blas.h>
#include "FactorizationHybrid2.h"
#include "nnls.h"
#include <math.h>

// In this hybrid scheme we first run a few NNLS iterations (their number,
// nnlsiter, is controlled from the command line interface) and then alternate
// NNLS and Frobenius-norm phases until the objective stops changing.
// Runs the hybrid factorization A ~= VH: an initial NNLS phase (which also
// initializes V and H), followed by alternating NNLS / multiplicative
// Frobenius-norm phases until a Frobenius phase changes the objective by no
// more than epsilon(). Always returns 0; results are left in V, H, objval and
// relerror.
int FactorizationHybrid2::perform()
{
  std::cerr << "NNLSITER ==>\t" << nnlsiter << std::endl;

  // P (a member) holds the product VH; compute_obj() compares it entrywise
  // with A. It lives only for the duration of this call.
  P = gsl_matrix_alloc(V->size1, H->size2);

  // First round: performnnls(0) initializes V and H before iterating.
  performnnls(0);
  double oldobj = objval;       // objective after the NNLS phase
  performfrob(1);

  // Keep alternating while the Frobenius phase still moves the objective.
  // NOTE(review): there is no cap on the number of outer rounds.
  while (fabs(oldobj - objval) > epsilon()) {
    performnnls(1);
    oldobj = objval;
    performfrob(1);
  }

  gsl_matrix_free(P);
  // Do not leave a dangling member pointer behind: any later use or a second
  // free of P would otherwise touch released memory.
  P = NULL;

  return 0;
}

// Copied from FactorizationNNLS::perform().
// Alternating non-negative least squares for A ~= VH.
//
// Each sweep first fixes V and solves one NNLS problem per column j of A,
//   min_{h >= 0} || V h - A(:,j) ||_2,
// writing the solution into column j of H. It then fixes H and solves one
// NNLS problem per row j of A,
//   min_{v >= 0} || H' v - A(j,:)' ||_2,
// writing the solution into row j of V. The objective ||A - VH||_F is
// re-evaluated every objmodulo sweeps. Iteration stops once it changes by at
// most epsilon(), or once the sweep budget is used up.
//
// Parameters:
//   i - 0 on the first call: V and H are initialized first, with
//       init_matrices(FRMFILE) when get_vh_file() returns a prefix and
//       init_matrices(RANDOM) otherwise. Any other value continues from the
//       current V and H.
// Returns objval / fnorm(A) (the relative error), which is also stored in
// relerror. Side effects: updates V, H, P (= VH), objval and relerror.
//
// NOTE(review): iter starts at 1 and the loop requires iter < nnlsiter, so at
// most nnlsiter - 1 sweeps run (none when nnlsiter <= 1). Confirm this is intended.
double FactorizationHybrid2::performnnls(int i)
{
  if (i == 0) {
    char *vhprefix = get_vh_file();
    if (vhprefix)
      init_matrices(Factorization::FRMFILE);
    else
      init_matrices(Factorization::RANDOM);
  }

  // Objective value for the current V and H.
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  objval = compute_obj(P);

  // Short hand for commonly used quantities.
  const int m = A->size1;
  const int r = getRank();
  const int n = A->size2;

  int iter = 1;                  // Iteration number
  double error = 100;            // Delta(objval)
  double objold = objval + 10;   // Anything larger than objval

  // Work buffers. std::vector releases them on every exit path.
  // The nnls() scratch space (w, zz, idx) is sized to cover both subproblems'
  // row counts (m and n) as well as their number of unknowns (r).
  const int work = std::max(std::max(m, n), r);
  std::vector<double> aj(m);       // A(:,j): right-hand side of the H subproblem
  std::vector<double> atj(n);      // A(j,:)': right-hand side of the V subproblem
  std::vector<double> vtj(r);      // solution of the V subproblem (row j of V)
  std::vector<double> w(work);
  std::vector<double> zz(work);
  std::vector<int>    idx(work);
  std::vector<double> ht(r * n);   // H' column-major (n x r); also holds each H-column solution
  std::vector<double> cdv(m * r);  // V column-major (m x r)
  double rn;                       // nnls() output (presumably the residual norm); unused

  while (error > epsilon() && iter < nnlsiter) {
    if (iter % objmodulo == 0)
      DBG("NNLS: " << iter << ": " << objval << std::endl);

    // Fix V and compute H so that we min ||VH - A||_F, one column of H at a time.
    for (int j = 0; j < n; j++) {
      for (int row = 0; row < m; row++)
        aj[row] = gsl_matrix_get(A, row, j);

      // nnls() modifies its matrix argument, so lay V out again in
      // column-major order before every call.
      int pos = 0;
      for (int k = 0; k < r; k++)
        for (int row = 0; row < m; row++)
          cdv[pos++] = gsl_matrix_get(V, row, k);

      int flg = 0;
      // The first r entries of ht receive the solution, i.e. column j of H.
      nnls(&cdv[0], m, m, r, &aj[0], &ht[0], &rn, &w[0], &zz[0], &idx[0], &flg);
      for (int k = 0; k < r; k++)
        gsl_matrix_set(H, k, j, ht[k]);

      if (flg != 1)
        std::cerr << "NNLS: nnls failed to find solution\n";
    }

    // Fix H and compute V: row j of V solves min ||H' v - A(j,:)'||, v >= 0.
    for (int j = 0; j < m; j++) {
      // Column j of A' is row j of A.
      for (int col = 0; col < n; col++)
        atj[col] = gsl_matrix_get(A, j, col);

      // Rebuild H' (n x r, column-major) before every call, since nnls()
      // modifies its matrix argument. gsl_matrix_get respects H's row stride.
      for (int k = 0; k < r; k++)
        for (int col = 0; col < n; col++)
          ht[k * n + col] = gsl_matrix_get(H, k, col);

      int flg = 0;
      nnls(&ht[0], n, n, r, &atj[0], &vtj[0], &rn, &w[0], &zz[0], &idx[0], &flg);
      if (flg != 1)
        std::cerr << "NNLS: nnls failed to find solution\n";

      // Copy into V'
      for (int k = 0; k < r; k++)
        gsl_matrix_set(V, j, k, vtj[k]);
    }

    if (iter % objmodulo == 0) {
      gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
      objold = objval;
      objval = compute_obj(P);
      error = fabs(objval - objold);
      // Warn only right after a fresh evaluation. Outside this block objval and
      // objold are stale, and the warning would just repeat.
      if (objval > objold)
        std::cerr << "NNlS: Objval increased\n";
    }
    iter++;
  }

  double retval = objval / fnorm(A);
  relerror = retval;
  std::cerr << "NNLS: Complete. Relative error ==> " << retval << std::endl;
  return retval;
}


// This function is called after performnnls(), so it does not need to initialize V and H.
double FactorizationHybrid2::performfrob(int i)
{
  
  std::cerr << "FactorizationHybrid2::performfrob(" << i << ")\n";
  
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  objval = compute_obj(P);
  
  //  std::cerr << "OBJVALUE at entry == " << objval << std::endl;
  int    iter = 0;
  double error = 1;
  double objold = objval+10;	// objval has at least some value

  gsl_vector*     t1;
  gsl_vector*     t2;
  gsl_vector*     t3;
  gsl_vector*     t4;

  gsl_vector_view rowaH;
  gsl_vector_view colaV;


  t1 = gsl_vector_alloc(V->size1);
  t2 = gsl_vector_alloc(H->size2);
  t3 = gsl_vector_alloc(V->size1);
  t4 = gsl_vector_alloc(V->size1);

  gsl_vector* temp  = gsl_vector_alloc(H->size1);
  gsl_vector* temp2 = gsl_vector_alloc(V->size2);
  gsl_vector* t1p   = gsl_vector_alloc(H->size2);

  while (error > epsilon() && iter < frobiter) {
    if (iter % objmodulo == 0)
      DBG("FROB2: " << iter << ": " << objval << std::endl);
    
    if (objval > objold) {
      std::cerr << "OBJVAL Increased! Quitting! Bug in code" << std::endl;
      relerror = objval/fnorm(A);
      //return 0;
    }

    // Do one row/col at a time updates...
    for (size_t a = 0; a < V->size2; a++) {
      // Update H
      gsl_matrix_get_col(t1, V, a); // t1 = Col(a, V)

      gsl_blas_dgemv(CblasTrans, 1.0, V, t1, 0.0, temp); // temp = V' t1
      gsl_blas_dgemv(CblasTrans, 1.0, H, temp, 0.0, t2); // t2 = temp*H
      gsl_blas_dgemv(CblasTrans, 1.0, A, t1, 0.0, t1p);   // t1p = t1*A

      rowaH = gsl_matrix_row(H, a);
      gsl_vector_mul(&rowaH.vector, t1p);
      
      if (perturb)
	gsl_vector_add_constant(t2, 1e-12);
      gsl_vector_div(&rowaH.vector, t2);

      // Update V
      gsl_blas_dgemv(CblasNoTrans, 1.0, A, &rowaH.vector, 0.0, t3); //t3 = A * rowaH.vector;
      gsl_blas_dgemv(CblasNoTrans, 1.0, H, &rowaH.vector, 0.0, temp2); //temp2 = H * rowaH.vector;

      gsl_blas_dgemv(CblasNoTrans, 1.0, V, temp2, 0.0, t4); //t4 = V * temp2;

      colaV = gsl_matrix_column(V, a);
      
      gsl_vector_mul(&colaV.vector, t3);

      if (perturb)
	gsl_vector_add_constant(t4, 1e-12);

      gsl_vector_div(&colaV.vector, t4);
    }

    // Compute the objective function.
    // NOTE: Have to make this block conditional so that we can save obj function
    // computations.
    if (iter % objmodulo == 0) {
      gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);   // VH
      objold  = objval;
      objval = compute_obj(P);
      error = fabs(objval - objold);
    }
    iter++;
  }

  gsl_vector_free(t1);
  gsl_vector_free(t1p);
  gsl_vector_free(t2);
  gsl_vector_free(t3);
  gsl_vector_free(t4);
  gsl_vector_free(temp);
  gsl_vector_free(temp2);

  relerror = objval/fnorm(A);	// In super class...
  std::cerr << "Hybrid2::performfrob(): Relative error => " << relerror << std::endl;

  return relerror;
}

// Computed with a direct loop over the matrix entries for speed.
double FactorizationHybrid2::compute_obj(const gsl_matrix* prod)
{
  double obj = 0.0;
  for (uint i = 0; i < A->size1 * A->size2; i++) {
    double diff = A->data[i] - prod->data[i];
    obj += (diff*diff);
  }
  return sqrt(obj);
}

double FactorizationHybrid2::compute_objonly(char* vhprefix)
{
  set_vh_file(vhprefix);
  init_file();
  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  double val = compute_obj(P);
  std::cerr << "FactorizationFrob::objval ==> " << val << std::endl;
  gsl_matrix_free(P);
  return val;
}
