// FactorizationFrobWt2.cc -- Implements a weighted least-squares factorization
// algorithm that minimizes ||L (A - VH) R||_F for given weight matrices L, R.
// Getting the weighted update rules right is subtle...
// Copyright (C) 2003 Suvrit Sra (suvrit@cs.utexas.edu)

// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

#include <iostream>
#include <gsl/gsl_blas.h>
#include "FactorizationFrobWt2.h"
#include "util.h"
#include <cmath>


// This function needs to be rewritten
int FactorizationFrobWt2::perform()
{
  
  char *vhprefix = get_vh_file();
  if (vhprefix)
    init_matrices(Factorization::FRMFILE);
  else 
    init_matrices(Factorization::RANDOM);

  if (lfile == 0 && rfile == 0) {
    std::cerr << "FROBWT2: Neither Weight matrix is set, Use -a 0 instead of -a 8\n";
    return -1;
  }

  // read in the weight matrix...using a util.h function
  if (lfile) {
    std::cout << "FROBWT2: Reading in L ==> " << lfile << std::endl;
    L = SSUtil::read_gsl_matrix(lfile);
    if (L == 0) {
      return -1;
    }
    if (L->size2 != A->size1)
      return -1;
    // LtL <- L'L
    LtL = gsl_matrix_alloc(L->size2, L->size2);
    gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, L, L, 0.0, LtL);
  } 

  if (rfile) {
    std::cout << "FROBWT2: Reading in R ==> " << rfile << std::endl;
    R = SSUtil::read_gsl_matrix(rfile);
    if (R == 0) {
      return -1;
    }
    if (R->size1 != H->size2)
      return -1;
    // RRt <- RR'
    RRt = gsl_matrix_alloc(R->size1, R->size1);
    gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, R, R, 0.0, RRt);
  } 
  
  //init_matrices(Factorization::FRMFILE);
  gsl_matrix* T1, *T2, *T3, *T4, *T5, *T6;
  gsl_matrix* T7, *T8, *T9;
  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);
  gsl_matrix* ARRt;

  size_t m  = A->size1;
  size_t n  = A->size2;
  size_t k =  V->size2;

  if (lfile)
    T1 = gsl_matrix_alloc(k, m);

  if (rfile)
    T2 = gsl_matrix_alloc(k, n);

  T3 = gsl_matrix_alloc(k, k);
  T4 = gsl_matrix_alloc(k, n);
  T5 = gsl_matrix_alloc(k, n);
  T6 = gsl_matrix_alloc(m, k);
  if (rfile)
    T7 = gsl_matrix_alloc(n, k);

  T8 = gsl_matrix_alloc(k, k);
  T9 = gsl_matrix_alloc(m, k);

  gsl_matrix* lV, *Hr, *lA;

  lV = 0; Hr = 0; lA = 0;

  if (lfile or rfile) {
    Ah = gsl_matrix_alloc(m, n);
    gsl_matrix_memcpy(Ah, A);
  }

  if (lfile) {
    lV = gsl_matrix_alloc(m, k);
    lA = gsl_matrix_alloc(m, n);
  }

  if (rfile) {
    Hr = gsl_matrix_alloc(k, n);
    ARRt = gsl_matrix_alloc(m, n);
  }


  // Just for correct debugging messages we shd compute obj at least once
  // We need to do L VH R
  if (lfile) 
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, L, V, 0.0, lV);
  else 
    lV = V;

  if (rfile)
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, H, R, 0.0, Hr);
  else
    Hr = H;

  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, lV, Hr, 0.0, P);

  if (lfile)
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, L, A, 0.0, lA);
  else
    lA = A;

  if (rfile)
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, lA, R, 0.0, Ah);
  else {
    if (Ah)
      gsl_matrix_free(Ah);
    Ah = lA;
  }

  objval = compute_obj(P);
  
  // ARRt <- ARR'
  if (rfile)
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, A, RRt, 0.0, ARRt);  
  else
    ARRt = A;

  int    iter = 0;
  double error = 1;
  double objold = objval+10;	// objval has at least some value

  while (error > epsilon() && iter < getMaxiter()) {
    if (iter % objmodulo == 0)
      DBG("FROBWT2: " << iter << ": " << objval << std::endl);
    
    if (objval > objold) {
      std::cerr << "OBJVAL Increased! " << std::endl;
      relerror = objval/fnorm(A);
      //return 0;
    }

    // Do update for H
    if (lfile) 
      // V'L'L -- k x m
      gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, V, LtL, 0.0, T1); 
    else
      T1 = V;
    
    if (rfile)
      // HRR'  -- k x n
      gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, H, RRt, 0.0, T2);
    else
      T2 = H;

    if (lfile)
      // V'L'LV -- k x k
      gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, T1, V, 0.0, T3);
    else
      gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, T1, V, 0.0, T3);

    if (lfile)
      // V'L'LARR' -- k x n
      gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, T1, ARRt, 0.0, T4);
    else 
      // V'ARR' -- k x n
      gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, V, ARRt, 0.0, T4);

    // V'L'LVHRR' -- k x n
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, T3, T2, 0.0, T5);

    // H = H .* T4 ./ T5
    gsl_matrix_mul_elements(H, T4);

    if (perturb)
      // Perturb T3 a little to prevent divide by zeros ...
      gsl_matrix_add_constant(T5, 10e-12);
    
    gsl_matrix_div_elements(H, T5);

    if (normalize)
      normalizeVH();

    // Computations for iterative update of V
    // ARR'H' -- m x k
    gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, ARRt, H, 0.0, T6);

    if (rfile)
      // RR'H' -- n x k
      gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, RRt, H, 0.0, T7); 
    else
      T7 = H;

    if (rfile)
      // HRR'H' -- k x k
      gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, H, T7, 0.0, T8);
    else
      // HH'  -- k x k
      gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, H, H, 0.0, T8);
    
    // VHRR'H' -- m x k
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, T8, 0.0, T9);

    // V = V .* T6 ./ T9
    gsl_matrix_mul_elements(V, T6);

    if (perturb)
      // Perturb T6 a little to prevent div. by zero..
      gsl_matrix_add_constant(T9, 10e-12);

    gsl_matrix_div_elements(V, T9);
    
    // Compute the objective function.
    // NOTE: Have to make this block conditional so that we can save obj function
    // computations.
    if (iter % objmodulo == 0) {
      if (lfile)
	gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, L, V, 0.0, lV);
      else
	lV = V;

      if (rfile)
	gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, H, R, 0.0, Hr);
      else
	Hr = H;

      gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, lV, Hr, 0.0, P);

      objold  = objval;
      objval = compute_obj(P);
      error = fabs(objval - objold);
    }
    iter++;
  }

  if (lfile)
  gsl_matrix_free(T1);

  if (rfile)
    gsl_matrix_free(T2);

  gsl_matrix_free(T3);
  gsl_matrix_free(T4);
  gsl_matrix_free(T5);
  gsl_matrix_free(T6);

  if (rfile)
    gsl_matrix_free(T7);

  gsl_matrix_free(T8);
  gsl_matrix_free(T9);

  if (lV != V)
    gsl_matrix_free(lV);

  if (Hr != H)
    gsl_matrix_free(Hr);

  if (Ah != A and Ah != lA)
    gsl_matrix_free(Ah);
  
  if (lA != A)
    gsl_matrix_free(lA);
  
  gsl_matrix_free(P);
  relerror = objval/fnorm(Ah);	// In super class...
  std::cerr << "|L (A - VH) R|/ |LAR| " << relerror << std::endl;

  return 0;
}

/*
 * FIXME:
 */
double FactorizationFrobWt2::compute_obj(const gsl_matrix* prod)
{
  double obj = 0.0;
  for (uint i = 0; i < Ah->size1 * Ah->size2; i++) {
    double diff = Ah->data[i] - prod->data[i];
    obj += (diff*diff);
  }
  return sqrt(obj);
}

/*
 * compute_objonly -- evaluate the objective for factors stored on disk.
 *
 * Records `vhprefix` via set_vh_file() and calls init_file() (presumably this
 * loads V and H from those files; confirm in the base class).  It then forms
 * P = V H and returns compute_obj(P), also printing the value to std::cerr.
 *
 * TBD: the weight matrices are not yet incorporated.  P is the plain product
 * V H, not L V H R.
 * NOTE(review): compute_obj() compares P against the member Ah, and nothing
 * here sets Ah.  Confirm that callers have Ah pointing at a valid matrix of
 * the same shape as P before calling this.
 */
double FactorizationFrobWt2::compute_objonly(char* vhprefix)
{
  set_vh_file(vhprefix);
  init_file();

  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  const double val = compute_obj(P);
  gsl_matrix_free(P);

  std::cerr << "FactorizationFrobWt2::objval ==> " << val << std::endl;
  return val;
}
