// -*- c++ -*-
/* ************************************************************
 * File: Factorization.h
 * Author: Suvrit Sra
 * Revision History:
 * 02/25/03  Creation
 *
 * ************************************************************/

// Copyright (C) 2003 Suvrit Sra (suvrit@cs.utexas.edu)

// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

#ifndef FACTORIZATION_H
#define FACTORIZATION_H

#include <time.h>

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iostream>

#include <gsl/gsl_rng.h>
#include <gsl/gsl_matrix.h>
#include <gsl/gsl_vector.h>
#include <gsl/gsl_blas.h>

#include "DriverOpts.h"

// Shorthand alias: `unsigned` on its own names the type `unsigned int`.
typedef unsigned uint;

#if !defined(DBG)

/*
 * DBG(x): write `x` to std::cout when a variable named `verbose`, visible
 * at the expansion site, is true.
 *
 * `x` is pasted verbatim after `std::cout <<`, so it may be a chained
 * stream expression, e.g. DBG("iter " << i << "\n").
 *
 * The empty-then / else form means a following `else` binds to the
 * caller's own `if` instead of being captured by the macro's internal
 * `if` (the dangling-else hazard of a bare `if (verbose) {...}`).  The
 * expansion still ends in `}`, so it works with or without a trailing
 * semicolon exactly as before.
 */
#define DBG(x) if (!(verbose)) {} else \
{\
  std::cout << x;\
}

#endif

class Factorization {
private:
  int rank;
  int maxiter;
  double eps;
protected:
  bool verbose;
  bool debug;
  bool normalize;
  bool perturb;
  DriverOpts* opts;
  gsl_matrix* A;
  gsl_matrix* V, *H;
  double objval;
  unsigned long int m_seed;
  int objmodulo;
  char* vhprefix;
  char* resfile;
  double relerror;		// Relative error;
  double svderror;

  enum tIniType {RANDOM, FRMFILE, OTHER};
public:
  /** This constructor is here so that one can derive a factoriaztion
   * class from this class without having to implement V, H as Matrix
   * objs.*/
  Factorization(DriverOpts* o = 0)
  {
    opts = o;
    rank = opts->rank;
    perturb = opts->perturb;
    maxiter = opts->maxiter;
    eps     = opts->epsilon;
    verbose = opts->verbose;
    A       = opts->A;
    V = H = 0;
    if (A) {
      V = gsl_matrix_alloc(A->size1, rank);
      H = gsl_matrix_alloc(rank, A->size2);
    }
    objval = 0.0;
    relerror = 100;
    debug  = false;
    m_seed = opts->rseed;
    objmodulo = opts->objmodulo;
    vhprefix = opts->vhprefix;
    normalize = opts->normalize;
    resfile   = opts->resfile;
  }

  
  Factorization(int r, int m, double e, bool v = false, unsigned long int s = 1, int om = 1, char* vh = 0, bool no = false) :
    rank(r), maxiter(m), eps(e), verbose(v)
  { objval = 0; A = 0; V = 0; H = 0; debug = false; m_seed = s; objmodulo = om; vhprefix = vh; normalize = no;}

  Factorization(int r, int m, double e, gsl_matrix* a, 
		bool v = false, unsigned long int s = 1, int om = 1, char* vh = 0, bool no = false) : 
    rank(r), maxiter(m), eps(e), verbose(v), A(a) {
    V = gsl_matrix_alloc(A->size1, r);
    H = gsl_matrix_alloc(r, A->size2);
    objval = 0.0;
    debug = false;
    m_seed = s;
    objmodulo = om;
    vhprefix = vh;
    normalize = no;
  }


  virtual ~Factorization() { 
    if (V != 0)
      gsl_matrix_free(V);
    if (H != 0)
      gsl_matrix_free(H);
    if (A != 0)
      gsl_matrix_free(A);
  }

  virtual int  perform() = 0;

  virtual void init_matrices(tIniType);
  virtual void init_random();
  virtual void init_file();

  void set_relerror(double val) { relerror = val;}
  double get_relerror() const { return relerror;}

  void set_svderror(double val) { svderror = val;}
  double get_svderror() const { return svderror; }

  void set_vh_file(char* name) {
    if (vhprefix == 0)
      vhprefix = strdup(name);
    else {
      free(vhprefix);
      vhprefix = strdup(name);
    }
  }

  char* get_vh_file() { return vhprefix;}

  void printmat(gsl_matrix*);
  Factorization() {}

  int getRank()    const { return rank;   }
  virtual bool checkRank() {
    return (getRank() <= std::min ((int)A->size1, (int)A->size2));
  }
  int getMaxiter() const { return maxiter;}
  double epsilon()  const { return eps;    }
  double fnorm(const gsl_matrix*);
  double fnorm(const gsl_matrix*, const gsl_matrix*);

  void normalizeVH();
  virtual void writeResults(char*, bool bin = false);
  virtual double compute_objonly(char*) = 0;

  double bregmanDivergence(double (*f)(double),  double (*y)(double), gsl_matrix* A, gsl_matrix*B);
  double bregmanDivergence(double (*f)(double),  double (*y)(double), double, double);
  double bregDivFromZero  (double (*f)(double),  double (*y)(double), gsl_matrix*);
  /**
   * Exception classes.
   */
  class FactMemoryAllocation {
  public:
    FactMemoryAllocation() {}
  };
  class FactExceptionInvalid { 
  public:
    FactExceptionInvalid(char *s) { 
      std::cerr << "Factorization exception: " << s << std::endl;
    } 
  };
};

#endif
