/* File: dense_matrix.h        -*- c++ -*-         */
// Author: Suvrit Sra
// Time-stamp: <08 July 2007 05:14:48 PM CDT --  suvrit>
// Wrapper around gsl_matrix to enable ease of use.


#ifndef _MY_DENSEMATRIX_H
#define _MY_DENSEMATRIX_H

#include <gsl/gsl_matrix.h>
#include <gsl/gsl_blas.h>
#include "matrix.h"
#include "exceptions.h"

namespace SSLib {

  /**
   * @class dense_matrix
   * Dense matrix backed by a gsl_matrix. It exists so that every matrix
   * representation can be a subclass of the top-level matrix class; this
   * subclass simply wraps a gsl_matrix so that GSL routines can be applied
   * to it directly (see get_matrix()).
   *
   * Ownership: the destructor does NOT release the underlying gsl_matrix;
   * call free_data() explicitly. The implicitly generated copy constructor
   * and assignment copy only the pointer (shallow copy), so copies share the
   * same gsl_matrix and only one of them should call free_data().
   */
  class dense_matrix : public matrix {
    gsl_matrix* M;              ///< Underlying storage (0 when empty / freed)
    gsl_vector_view row;        ///< Storage for the view handed out by get_row()
    gsl_vector_view col;        ///< Storage for the view handed out by get_col()
    gsl_vector_view diag_view;  ///< Storage for the view handed out by get_diag()
  public:
    /// Creates an empty 0 x 0 matrix with no storage.
    dense_matrix() : matrix(0, 0), M(0)
    {
    }

    /**
     * Allocates a rows x cols matrix.
     * @param zeros if true the entries are zero-initialised via
     *              gsl_matrix_calloc, otherwise gsl_matrix_alloc is used and
     *              the contents are left uninitialised.
     * @throws matrix_alloc if the GSL allocator returns a null pointer.
     *         NOTE: with GSL's default error handler installed, a failed
     *         allocation normally aborts inside GSL before this check runs.
     */
    dense_matrix(size_t rows, size_t cols, bool zeros = false)
      : matrix(rows, cols), M(0) {
      M = zeros ? gsl_matrix_calloc(rows, cols) : gsl_matrix_alloc(rows, cols);
      if (M == 0) {
        throw matrix_alloc("Memory allocation error in dense_matrix()"); 
      }
    }

    /**
     * Wraps an existing gsl_matrix without copying it. No check is made that
     * rows/cols agree with m->size1/m->size2; the caller must ensure it.
     */
    dense_matrix(gsl_matrix* m, size_t rows, size_t cols)
      : matrix(rows, cols), M(m) {
    }

    /// Does not free the underlying gsl_matrix; call free_data() for that.
    ~dense_matrix()
    {
      // free_data() is deliberately not called here, presumably because the
      // wrapped gsl_matrix may be shared (wrapping constructor, shallow
      // copies). Freeing is left to an explicit free_data() call.
      //free_data();
    }

    /// Releases the underlying gsl_matrix (if any) and resets the pointer
    /// to 0, so calling it more than once is harmless.
    virtual void free_data() {
      if (M != 0)
        gsl_matrix_free(M);
      M = 0;
    }
    virtual int load(const char*);
    virtual int load(const char*, bool);
    
    virtual void error(std::string s);

    /// Write out matrix in a specified way
    virtual int    save(const char*, bool asbin, int typ);

    /// Write out matrix to disk in CCS, CRS, etc. formats
    virtual int save(const char*, bool = false);
    virtual int save_as_ccs(const char*, bool = false);
    virtual int save_as_crs(const char*, bool = false);
    virtual int save_as_coord(const char*, bool = false);
    virtual int save_as_rowcoord(const char*, bool = false);
    virtual int save_as_colcoord(const char*, bool = false);
    virtual int save_as_matlab(const char*, bool = false);
    virtual int save_as_hb(const char*, bool = false);
    virtual int save_as_dense(const char*, bool = false);
    
  private:
    // Some helper functions
    int save_ccs_binary(const char*);
    int save_ccs_text  (const char*);
    int save_crs_binary(const char*);
    int save_crs_text  (const char*);
  public:
    /// Element (i, j); same as get(i, j).
    double operator()(size_t i, size_t j) {
      return gsl_matrix_get(M, i, j);
    }

    /// Element (i, j).
    double get(size_t i, size_t j) {
      return gsl_matrix_get(M, i, j);
    }

    /// Sets element (i, j) to v. Always returns 0.
    int set(size_t i, size_t j, double v) {
      gsl_matrix_set(M, i, j, v); return 0;
    }
    
    // Exposing the underlying data to enable speed in some algorithms. One
    // has to sacrifice cleanliness of code for speed at times.
    gsl_matrix* get_matrix() {
      return M;
    }
    
    /**
     * Points @p row at a view of row r (no copy; writes through the view
     * modify this matrix). The view is stored inside this object, so the
     * pointer stays valid until the next get_row() call on this object or
     * until the data is freed. Copy the row (get_row_copy) if you need two
     * rows at once.
     * @return 0
     */
    int get_row (size_t r, vector*& row) {
      this->row = gsl_matrix_row(M, r);
      row = &this->row.vector;
      return 0;
    }

    /// Copies row r into the caller-provided vector @p row.
    /// @return the status of gsl_vector_memcpy (0 on success, non-zero e.g.
    ///         when the lengths differ).
    int get_row_copy(size_t r, vector* row) {
      gsl_vector_view ro = gsl_matrix_row(M, r);
      return gsl_vector_memcpy(row, &ro.vector);
    }

    /**
     * Points @p col at a view of column c (no copy; writes through the view
     * modify this matrix). The view is stored inside this object, so the
     * pointer stays valid until the next get_col() call on this object or
     * until the data is freed.
     * @return 0
     */
    int get_col (size_t c, vector*& col) {
      this->col = gsl_matrix_column(M, c);
      col = &this->col.vector;
      return 0;
    }
    
    /// Copies column c into the caller-provided vector @p col.
    /// @return the status of gsl_vector_memcpy (0 on success).
    int get_col_copy(size_t c, vector* col) {
      gsl_vector_view co = gsl_matrix_column(M, c);
      return gsl_vector_memcpy(col, &co.vector);
    }

    /// Raw pointer to the element storage of the underlying gsl_matrix
    /// (GSL stores matrices row-major with row stride M->tda).
    double* get_data() {
      return M->data;
    }
    

    /**
     * Points @p diag at a view of the main diagonal (no copy). The view is
     * stored inside this object, so the pointer stays valid until the next
     * get_diag() call on this object or until the data is freed.
     * NOTE(review): @p p is currently ignored. The original documentation
     * says p == true should return the "second" diagonal, but that is not
     * implemented; the main diagonal is always returned.
     * @return 0
     */
    virtual int get_diag(bool p, vector*& diag) {
      diag_view = gsl_matrix_diagonal(M);
      diag = &diag_view.vector;
      return 0;
    }
    
    /// Sets the specified row to the given vector. Not implemented: always
    /// returns -1 (use set_row_copy).
    virtual int set_row(size_t r, vector*& row) {
      return -1;
    }

    /// Copies @p row into row r of this matrix.
    /// @return the status of gsl_vector_memcpy (0 on success).
    virtual int set_row_copy(size_t r, vector* row) {
      gsl_vector_view ro = gsl_matrix_row(M, r);
      return gsl_vector_memcpy(&ro.vector, row);
    }
    /// Sets the specified col to the given vector. Not implemented: always
    /// returns -1 (use set_col_copy).
    virtual int set_col(size_t c, vector*& col) {
      return -1;
    }

    /// Copies @p col into column c of this matrix.
    /// @return the status of gsl_vector_memcpy (0 on success).
    virtual int set_col_copy(size_t c, vector* col) {
      gsl_vector_view co = gsl_matrix_column(M, c);
      return gsl_vector_memcpy(&co.vector, col);
    }

    /// Sets the specified diagonal to the given vector. Not implemented:
    /// always returns -1.
    virtual int set_diag(bool p, vector*& d) {
      return -1;
    }
    
    /// Returns a submatrix that goes from row i1--i2 and cols j1--j2.
    /// Not implemented: always returns -1.
    virtual int submatrix(size_t i1, size_t j1, size_t i2, size_t j2, matrix*& m) {
      return -1;
    }

    /// Returns a submatrix given the row-col index sets.
    /// Not implemented: always returns -1.
    virtual int submatrix(gsl_vector_uint* I, gsl_vector_uint* J, matrix*& m) {
      return -1;
    }

    /** ****************************************************************
     *  Matrix Operations begin here. They are grouped separately into
     *  'self (essentially elementwise operations on itself)'
     *  'matrix-vector operations'
     *  'matrix-matrix operations'
     *  ****************************************************************
     */

    /** @name Elementwise
     * Operations on elements of this matrix
     */

    /*@{*/

    /// Vector l_p norms for this matrix
    virtual double norm (double p);
    virtual double norm (char*  p);

    /// Apply an arbitrary function elementwise to this matrix
    virtual double apply (double (* fn)(double));

    /// x_ij <- s * x_ij
    virtual int scale (double s);

    /// x_ij <- s + x_ij
    virtual int add_const(double s);

    /*@}*/

    
    /** @name RowColOps
     *  ****************************************************
     * Operations on rows and columns of this matrix
     *  ****************************************************
     */

    /*@{*/
    virtual double col_norm(size_t c);
    virtual double col_norm (size_t c, double p);
    virtual double col_dot  (size_t c1, size_t c2);
    virtual double col_dot  (size_t c, vector* v);
    virtual int    col_add  (size_t c1, size_t c2);
    virtual int    col_sub  (size_t c1, size_t c2);
    virtual int    col_scale(size_t c1, double s);
    virtual int    col_sub  (size_t c1, vector* v);
    virtual int    col_add  (size_t c1, vector* v);
    virtual int    col_scale(size_t c1, double s, vector* r);
    virtual int    col_sum  (vector*);

    virtual double row_norm(size_t r);
    virtual double row_norm (size_t r, double p);
    virtual double row_dot  (size_t r1, size_t r2);
    virtual double row_dot  (size_t r, vector* v);
    virtual double row_dot(size_t r, vector* x, size_t* idx, size_t);
    virtual int    row_add  (size_t r1, size_t r2);
    virtual int    row_sub  (size_t r1, size_t r2);
    virtual int    row_scale(size_t c1, double s);
    virtual int    row_add  (size_t r1, vector* r);
    virtual int    row_sub  (size_t r1, vector* r);
    virtual int    row_scale(size_t c1, double s, vector* r);
    virtual int    row_daxpy(size_t i, double a, vector* r);
    virtual int    row_sum  (vector*);

    /// dot ( row i, col j)
    virtual double row_col_dot (size_t r, size_t c);
    /*@}*/

    /**
     * Matrix - vector operations
     */
    //virtual int dot(bool tranA, double* x, double* r);
    virtual int dot (bool tranA, vector* x, vector* r);

    /*@{*/
    
    /** 
     * Matrix - Matrix operations
     */
    // this <- this + b
    virtual int add (matrix* b);
    // this <- this - b
    virtual int sub (matrix* b);
    // this <- this . b
    virtual int dot (matrix* b);
    // this <- this ./ b
    virtual int div (matrix* b);
    
    // this <- this * b
    virtual int mul (matrix* b);

    /// Transpose this matrix / return a transposed copy.
    virtual void transpose();
    virtual matrix* transpose_copy();

    /// Normalize the columns of the matrix to have unit L2 norm
    virtual void normalize_columns();
    /// Print out to file or stdout or whatever
    virtual void print();

#ifdef _INTERNAL_EIGS_
    /// Returns top 'k' eigenpairs 
    virtual eigenpairs* eigs (size_t k);
    
    /// Returns singular triplets
    virtual singular_triples* svds (size_t k);

#endif 
    int solve_psd_linear_system(vector* b, vector* x);
    int solve_least_squares(vector* b, vector* x);
    virtual int compute_AtA(matrix* result);
    /*@}*/




  };
}

#endif 
