// FactorizationNNLS.cc -- ALS algorithm for min. frob. error
// Copyright (C) 2003 Suvrit Sra (suvrit@cs.utexas.edu)

// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

#include "nnls.h"
#include "FactorizationNNLS.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <vector>

// This program is highly underoptimized.
// Alternating non-negative least squares.  Starting from an initial V and H,
// each pass (1) holds V fixed and solves one nnls problem per column of A to
// obtain the matching column of H, then (2) holds H fixed and solves one nnls
// problem per row of A to obtain the matching row of V, so as to reduce
// ||V*H - A||_F.  The loop ends once the change in the objective between two
// evaluations is no larger than epsilon(), or the iteration counter reaches
// getMaxiter().
//
// Side effects: overwrites V and H, updates objval and relerror, emits
// progress through DBG and diagnostics on std::cerr.
// Returns 0 unconditionally; an nnls failure is reported but is not fatal.
int FactorizationNNLS::perform()
{
  // Use the file-based initialisation when a V/H file prefix has been set,
  // otherwise the random one.
  if (get_vh_file())
    init_matrices(Factorization::FRMFILE);
  else
    init_matrices(Factorization::RANDOM);

  // Shorthand for commonly used quantities (nnls takes plain ints).
  const int m = static_cast<int>(A->size1);
  const int r = getRank();
  const int n = static_cast<int>(A->size2);

  // Element counts are formed in size_t so that m*r and r*n cannot overflow
  // int on large problems.
  const size_t um = static_cast<size_t>(m);
  const size_t un = static_cast<size_t>(n);
  const size_t ur = static_cast<size_t>(r);
  const size_t vlen = um * ur;           // elements in V
  const size_t hlen = ur * un;           // elements in H
  const size_t longest = std::max(std::max(um, un), ur);

  // Scratch storage.  std::vector releases everything on every exit path.
  // w/zz/idx are sized generously because the exact lengths nnls needs
  // depend on its implementation -- NOTE(review): confirm against nnls.h.
  std::vector<double> rhs(std::max(um, un));          // a column of A or of A'
  std::vector<double> sol(ur);                        // solution of one nnls call
  std::vector<double> w(longest);
  std::vector<double> zz(longest);
  std::vector<int>    idx(longest);
  std::vector<double> packed(std::max(vlen, hlen));   // pristine coefficient matrix
  std::vector<double> work(packed.size());            // copy handed to nnls
  double rn = 0.0;                                    // residual norm slot passed to nnls

  // Initial objective value.
  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  objval = compute_obj(P);

  int iter = 1;                    // Iteration number
  double error = 100;              // Delta(objval)
  double objold = objval + 10;
  bool objStale = false;           // true when V/H changed since objval was computed

  while (error > epsilon() && iter < getMaxiter()) {
    if (iter % objmodulo == 0) {
      DBG("NNLS: " << iter << ": " << objval << std::endl);
    }

    // ---- Fix V, find H minimising ||V*H - A||_F, one column at a time ----
    // V is stored row-major but is passed to nnls column-major.  V does not
    // change during this phase, so the column-major image is built once.
    {
      size_t pos = 0;
      for (int k = 0; k < r; k++)
        for (int i = 0; i < m; i++)
          packed[pos++] = gsl_matrix_get(V, i, k);
    }

    for (int j = 0; j < n; j++) {
      for (int i = 0; i < m; i++)
        rhs[i] = gsl_matrix_get(A, i, j);

      // nnls modifies its matrix argument, so give it a fresh copy each time.
      std::copy(packed.data(), packed.data() + vlen, work.data());

      int flg = 0;
      nnls(work.data(), m, m, r, rhs.data(), sol.data(), &rn,
           w.data(), zz.data(), idx.data(), &flg);

      for (int i = 0; i < r; i++)
        gsl_matrix_set(H, i, j, sol[i]);

      if (flg != 1) {
        std::cerr << "NNLS: nnls failed to find solution\n";
      }
    }

    // ---- Fix H, find V minimising ||H'*V' - A'||_F, one row at a time ----
    // H' in column-major order is H in row-major order.  Reading through
    // gsl_matrix_get keeps this correct even if H's tda differs from size2.
    {
      size_t pos = 0;
      for (int a = 0; a < r; a++)
        for (int b = 0; b < n; b++)
          packed[pos++] = gsl_matrix_get(H, a, b);
    }

    for (int j = 0; j < m; j++) {
      // The j-th column of A' is the j-th row of A.
      for (int i = 0; i < n; i++)
        rhs[i] = gsl_matrix_get(A, j, i);

      std::copy(packed.data(), packed.data() + hlen, work.data());

      int flg = 0;
      nnls(work.data(), n, n, r, rhs.data(), sol.data(), &rn,
           w.data(), zz.data(), idx.data(), &flg);
      if (flg != 1) {
        std::cerr << "NNLS: nnls failed to find solution\n";
      }

      // The solution is the j-th column of V', i.e. the j-th row of V.
      for (int i = 0; i < r; i++)
        gsl_matrix_set(V, j, i, sol[i]);
    }

    // The objective is only evaluated every objmodulo iterations.
    if (iter % objmodulo == 0) {
      gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
      objold = objval;
      objval = compute_obj(P);
      error = std::fabs(objval - objold);
      objStale = false;

      // Checked only on a fresh evaluation, so one increase yields one
      // warning rather than one per iteration until the next evaluation.
      if (objval > objold) {
        std::cerr << "NNlS: Objval increased\n";
      }
    } else {
      objStale = true;
    }
    iter++;
  }

  // If the loop stopped on the iteration cap between evaluations, objval
  // does not describe the final V and H; bring it up to date before it is
  // reported.
  if (objStale) {
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
    objval = compute_obj(P);
  }

  gsl_matrix_free(P);
  relerror = objval/fnorm(A);
  std::cerr << "NNLS: Complete. Relative error ==> " << relerror << std::endl;
  return 0;
}

double FactorizationNNLS::compute_obj(const gsl_matrix* prod)
{
  double obj = 0.0;
  for (uint i = 0; i < A->size1 * A->size2; i++) {
    double diff = A->data[i] - prod->data[i];
    obj += (diff*diff);
  }
  return sqrt(obj);
}


double FactorizationNNLS::compute_objonly(char* vhprefix)
{
  set_vh_file(vhprefix);
  init_file();
  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  double val = compute_obj(P);
  std::cerr << "FactorizationNNLS::objval ==> " << val << std::endl;
  gsl_matrix_free(P);
  return val;
}
