// FactorizationFrobWt1.cc -- Implements elementwise weighted LS type algo
// Copyright (C) 2004 Suvrit Sra (suvrit@cs.utexas.edu)

// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

#include <iostream>
#include <gsl/gsl_blas.h>
#include "FactorizationFrobWt1.h"
#include "util.h"
#include <math.h>

// Runs the multiplicative-update iteration for the elementwise weighted
// Frobenius-norm factorization A ~ V H.
//
// NOTE (original author): this function needs to be rewritten.
//
// Steps, as implemented below ('.' is the elementwise product):
//   1. Initialise the matrices via init_matrices(): Factorization::FRMFILE
//      mode when get_vh_file() returns a prefix, Factorization::RANDOM
//      mode otherwise.
//   2. Read the weight matrix W from `wfile`.
//   3. Overwrite W with W . W, and A with A . (W . W).
//   4. Repeat   H <- H . (V'A) ./ (V'(VH . W))
//               V <- V . (AH') ./ ((VH . W)H')
//      (A and W being the overwritten versions from step 3) while the change
//      in objective exceeds epsilon() and fewer than getMaxiter() iterations
//      have run.
//
// Side effects: A and W are modified in place; V, H, objval and relerror
// are updated.
//
// Returns 0 on completion, or -1 when no weight file is set or
// SSUtil::read_gsl_matrix() hands back a null matrix.
int FactorizationFrobWt1::perform()
{
  char *vhprefix = get_vh_file();
  if (vhprefix)
    init_matrices(Factorization::FRMFILE);
  else
    init_matrices(Factorization::RANDOM);

  if (wfile == 0) {
    std::cerr << "FROBWT1: Weight matrix not set, quitting\n";
    return -1;
  }

  // Read in the weight matrix using a util.h function.
  W = SSUtil::read_gsl_matrix(wfile);
  if (W == 0) {
    return -1;
  }

  // Right now this routine is under-optimized: the Hadamard product sits
  // between the matrix multiplications, so the usual trick of choosing a
  // cheaper multiplication order (e.g. (V'V)H) cannot be applied.

  // Weighted Frobenius norm of A; must be taken before A and W are overwritten.
  double normA = fnorm(A, W);

  // W <- W . W
  gsl_matrix_mul_elements(W, W);

  // A <- A . W   (W is already squared at this point)
  gsl_matrix_mul_elements(A, W);

  // Work buffers. All of them, including P, are released before returning.
  gsl_matrix* P  = gsl_matrix_alloc(V->size1, H->size2);  // V H
  gsl_matrix* T1 = gsl_matrix_alloc(V->size2, H->size2);  // numerator for H
  gsl_matrix* T3 = gsl_matrix_alloc(V->size2, H->size2);  // denominator for H
  gsl_matrix* T4 = gsl_matrix_alloc(V->size1, H->size1);  // numerator for V
  gsl_matrix* T6 = gsl_matrix_alloc(V->size1, H->size1);  // denominator for V

  // Compute the objective once up front so the first debug message is
  // meaningful.
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  objval = compute_obj(P);

  int    iter = 0;
  double error = 1;
  double objold = objval+10;	// objval has at least some value

  // On entry to every pass P holds the plain product V H.
  while (error > epsilon() && iter < getMaxiter()) {
    if (iter % objmodulo == 0)
      DBG("FROBWT1: " << iter << ": " << objval << std::endl);

    if (objval > objold) {
      // Report the increase but keep iterating.
      std::cerr << "OBJVAL Increased!" << std::endl;
      relerror = objval/normA;
    }

    // ---- Update of H:  H = H .* T1 ./ T3 ----
    gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, V, A, 0.0, T1);    // V'(A.W)

    gsl_matrix_mul_elements(P, W);                                   // VH . W . W
    gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, V, P, 0.0, T3);    // V' P

    gsl_matrix_mul_elements(H, T1);

    if (perturb)
      // Perturb T3 a little to prevent divide by zeros (10e-12 == 1e-11).
      gsl_matrix_add_constant(T3, 10e-12);

    gsl_matrix_div_elements(H, T3);

    // ---- Update of V:  V = V .* T4 ./ T6 ----
    gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, A, H, 0.0, T4);    // AH'
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);   // VH (new H)
    gsl_matrix_mul_elements(P, W);                                   // VH . W . W
    gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, P, H, 0.0, T6);    // (VH.W.W)H'

    gsl_matrix_mul_elements(V, T4);

    if (perturb)
      // Perturb T6 a little to prevent div. by zero.
      gsl_matrix_add_constant(T6, 10e-12);

    gsl_matrix_div_elements(V, T6);

    // Recompute V H with the fresh factors: it is needed both for the
    // objective and as the starting P of the next pass.
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);   // VH
    objold  = objval;
    objval = compute_obj(P);
    error = fabs(objval - objold);
    iter++;
  }

  gsl_matrix_free(T1);
  gsl_matrix_free(T3);
  gsl_matrix_free(T4);
  gsl_matrix_free(T6);
  gsl_matrix_free(P);		// previously leaked on every call

  relerror = objval/normA;	// In super class...
  std::cerr << "||A . W - V H . W|| / || A . W || ==> " << relerror << std::endl;

  return 0;
}

// Computes the weighted objective  sqrt( sum_i (a_i - p_i)^2 * w_i^2 )
// for the product matrix `prod` (p = V H).
//
// When called from perform(), A has been overwritten with A . W . W and W
// with W . W. The term is therefore evaluated as
//     (a*w^2 - p*w^2)^2 / w^2
// using the stored (already scaled) values, which equals (a - p)^2 * w^2.
//
// Entries whose stored weight is exactly zero are skipped: their true
// contribution is 0, but the expression above would evaluate 0/0 = NaN and
// poison the whole sum.
//
// The matrices are walked through their raw `data` arrays; this assumes A,
// W and prod share the same dimensions and are stored contiguously
// (tda == size2) -- TODO confirm for matrices not created by gsl_matrix_alloc.
//
// Returns the square root of the accumulated sum.
double FactorizationFrobWt1::compute_obj(const gsl_matrix* prod)
{
  const size_t count = A->size1 * A->size2;
  double obj = 0.0;
  for (size_t i = 0; i < count; i++) {
    const double wsq = W->data[i];	// W is already W . W so...
    if (wsq == 0.0)
      continue;				// zero weight => zero contribution
    const double diff = A->data[i] - (prod->data[i] * wsq);
    obj += (diff * diff) / wsq;
  }
  return sqrt(obj);
}

double FactorizationFrobWt1::compute_objonly(char* vhprefix)
{
  set_vh_file(vhprefix);
  init_file();
  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  double val = compute_obj(P);
  std::cerr << "FactorizationFrobWt1::objval ==> " << val << std::endl;
  gsl_matrix_free(P);
  return val;
}
