// FactorizationBreg.cc -- Implements bregman updates
// Copyright (C) 2003 Suvrit Sra (suvrit@cs.utexas.edu)

// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

#include <iostream>
#include <gsl/gsl_blas.h>
#include "FactorizationBreg.h"
#include "util.h"
#include <math.h>

// This function needs to be rewritten
int FactorizationBreg::doPerform()
{
  
  char *vhprefix = get_vh_file();
  if (vhprefix)
    init_matrices(Factorization::FRMFILE);
  else 
    init_matrices(Factorization::RANDOM);

  if (phizetaset < 3) {
    std::cerr << "BREG: Convex function \\phi and derivatives not given!\n";
    return -1;
  }

  gsl_matrix* P = gsl_matrix_alloc(A->size1, A->size2);
  double bnormA = 1; //bregDivFromZero(phi, psi, A);
  // Just for correct debugging messages we shd compute obj at least once
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);
  objval = compute_obj(P);

  gsl_matrix *T1, *T2, *T3, *T4, *T5, *T6;

  size_t m = A->size1;
  size_t n = A->size2;
  size_t k = V->size2;

  T1 = gsl_matrix_alloc(m, n);
  T2 = gsl_matrix_alloc(m, n);
  T3 = gsl_matrix_alloc(k, n);
  T4 = gsl_matrix_alloc(k, n);
  T5 = gsl_matrix_alloc(m, k);
  T6 = gsl_matrix_alloc(m, k);

  gsl_vector_view cola;			// Col of A
  cola = gsl_matrix_column(A, 0);
  
  gsl_vector_view c;
  gsl_vector* t1 = gsl_vector_alloc(m);
  gsl_vector* t2 = gsl_vector_alloc(m);

  gsl_vector_view x;
  gsl_vector_view y;

  //gsl_vector* vbtch = gsl_vector_alloc(m);
  
  int    iter = 0;
  double error = 1;
  double objold = objval+10;	// objval has at least some value

  double dabc;
  double dabd;

  while (error > epsilon() && iter < getMaxiter()) {
    if (iter % objmodulo == 0)
      DBG("BREG: " << iter << ": " << objval << std::endl);
    
    if (objval > objold) {
      std::cerr << "OBJVAL Increased!" << std::endl;
      relerror = objval/bnormA;
      //return 0;
    }

	x    = gsl_matrix_column(P, 0);
	dabc = bregDiv(&cola.vector, &x.vector);


	c = gsl_matrix_column(H, 0);
	gsl_blas_dgemv(CblasNoTrans, 1.0, V, &c.vector, 0.0, t1); // V*c

    // ITERATIVE UPDATE OF H
    SSUtil::apply(zeta, P, T1);		// T1 = zeta(P)
    gsl_matrix_memcpy(T2, T1);  // T2 = T1
    gsl_matrix_mul_elements(T2, P); // T2 = T2 . VH
    gsl_matrix_mul_elements(T1, A); // T1 = T1 . A
    gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, V, T1, 0.0, T3);    // V' T1
    gsl_blas_dgemm(CblasTrans, CblasNoTrans, 1.0, V, T2, 0.0, T4);    // V'T2


    // H = H .* T3 ./ T4
    gsl_matrix_mul_elements(H, T3);

    if (perturb)
      // Perturb T3 a little to prevent divide by zeros ...
      gsl_matrix_add_constant(T4, 10e-12);

    gsl_matrix_div_elements(H, T4);
    //if (iter == 0) printmat(H);

	c = gsl_matrix_column(H, 0);
	gsl_blas_dgemv(CblasNoTrans, 1.0, V, &c.vector, 0.0, t2); // V*chat

	double sum = 0.0;
	double sum2 = 0.0;

    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);   // VH
	y = gsl_matrix_column(P, 0);

	dabd = bregDiv(&cola.vector, &y.vector);
	for (uint i = 0; i < t2->size; i++) {
	  double ai = gsl_vector_get(&cola.vector, i);
	  double bi = gsl_vector_get(t1, i);
	  double bhi  = gsl_vector_get(t2, i);
	  sum = sum + (ai - bhi)*(psi(bhi) - psi(bi));
	  sum2 += (ai - bhi)*(zeta(bi))*(bhi - bi);
	}
	/*std::cout << "SUM := " 
			  << sum 
			  << " SUM2 := " 
			  << sum2 
			  << " DEL := " 
			  << dabc - dabd 
			  << std::endl;
*/

    // Computations for iterative update of V

    SSUtil::apply(zeta, P, T1);				// T1 = zeta(P)
    gsl_matrix_memcpy(T2, T1);
    gsl_matrix_mul_elements(T1, A);
    gsl_matrix_mul_elements(T2, P);

    gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, T1, H, 0.0, T5);
    gsl_blas_dgemm(CblasNoTrans, CblasTrans, 1.0, T2, H, 0.0, T6);

    // V = V .* T5 ./ T6
    gsl_matrix_mul_elements(V, T5);

    if (perturb)
      // Perturb T6 a little to prevent div. by zero..
      gsl_matrix_add_constant(T6, 10e-12);

    gsl_matrix_div_elements(V, T6);
    
    // Compute the objective function. ...
    // Earlier computation below was conditional but now since we have to
    // do VH each time around ...
    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);   // VH
    objold  = objval;
    objval = compute_obj(P);
    error = fabs(objval - objold);
    iter++;
  }

  gsl_matrix_free(T1);
  gsl_matrix_free(T2);
  gsl_matrix_free(T3);
  gsl_matrix_free(T4);
  gsl_matrix_free(T5);
  gsl_matrix_free(T6);

  relerror = objval/bnormA;	
  std::cerr << "D(A|| VH) => " << relerror << std::endl;

  return 0;
}


double FactorizationBreg::compute_obj(const gsl_matrix* prod)
{
  double obj = 0.0;
  double tmp;
  for (uint i = 0; i < A->size1 * A->size2; i++) {
    tmp = bregmanDivergence(phi, psi, A->data[i], prod->data[i]);
    // FOR DEBUGGING RITENOW
    if (tmp < 0) {
      std::cerr << "WARNING!!! DIVERGENCE NEGATIVE...." << tmp << "\n";
      tmp = 0;
    }
    obj += tmp;
  }
  return obj;
}

// Evaluates the objective for an existing factorization without running
// any updates.
//
// vhprefix is handed to set_vh_file() and init_file() is then called
// (presumably this loads V and H from files with that prefix -- confirm
// against Factorization).  The product V*H is formed in a temporary matrix,
// scored with compute_obj(), logged to std::cerr and returned.
double FactorizationBreg::compute_objonly(char* vhprefix)
{
  set_vh_file(vhprefix);
  init_file();

  gsl_matrix* P = gsl_matrix_alloc(V->size1, H->size2);
  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, V, H, 0.0, P);   // P = V*H

  double val = compute_obj(P);
  // Label the output with this class (it previously named a different one).
  std::cerr << "FactorizationBreg::objval ==> " << val << std::endl;

  gsl_matrix_free(P);
  return val;
}

// Bregman divergence between two vectors: accumulates
// bregmanDivergence(phi, psi, a[i], x[i]) for every index i of a.
// x is read at the same indices, so it needs at least a->size elements.
double FactorizationBreg::bregDiv(const gsl_vector* a, const gsl_vector* x)
{
  double total = 0.0;
  unsigned int idx = 0;
  while (idx < a->size) {
    const double lhs = gsl_vector_get(a, idx);
    const double rhs = gsl_vector_get(x, idx);
    total += bregmanDivergence(phi, psi, lhs, rhs);
    ++idx;
  }
  return total;
}

