// File: dense_matrix_float.cc 
// Author: Suvrit Sra
// Time-stamp: <23 February 2010 10:51:30 AM CET --  suvrit>
// Implements the dense matrix class on top of gsl_matrix_float.


#include "dense_matrix_float.h"
#include "util.h"
#include <typeinfo>

/**
 * Elementwise in-place addition: this += b.
 * @param b  matrix with the same dimensions as this one
 * @return -1 on dimension mismatch; otherwise the status of the underlying
 *         operation (0 on success).
 */
int SSLib::dense_matrix_float::add (matrix_float* b)
{
  if (!dimension_match(b))
    return -1;
  // Compare the dynamic types of the pointed-to objects. Comparing
  // typeid(b) with typeid(this) compares the static pointer types
  // (matrix_float* vs dense_matrix_float*), which never match, so the
  // GSL fast path was unreachable.
  if (typeid(*b) == typeid(*this))
    return gsl_matrix_float_add(M, static_cast<dense_matrix_float*>(b)->get_matrix());
  return default_add (b);
}

/**
 * Elementwise in-place subtraction: this -= b.
 * @param b  matrix with the same dimensions as this one
 * @return -1 on dimension mismatch; otherwise the status of the underlying
 *         operation (0 on success).
 */
int SSLib::dense_matrix_float::sub (matrix_float* b)
{
  if (!dimension_match(b))
    return -1;
  // Dynamic-type check of the pointees (typeid of the raw pointers compares
  // static pointer types and can never succeed).
  if (typeid(*b) == typeid(*this))
    return gsl_matrix_float_sub(M, static_cast<dense_matrix_float*>(b)->get_matrix());
  return default_sub (b);
}


/**
 * Elementwise (Hadamard) in-place product: this[i][j] *= b[i][j].
 * @param b  matrix with the same dimensions as this one
 * @return -1 on dimension mismatch; otherwise the status of the underlying
 *         operation (0 on success).
 */
int SSLib::dense_matrix_float::dot (matrix_float* b)
{
  if (!dimension_match(b))
    return -1;
  // Dynamic-type check of the pointees (typeid of the raw pointers compares
  // static pointer types and can never succeed).
  if (typeid(*b) == typeid(*this))
    return gsl_matrix_float_mul_elements(M, static_cast<dense_matrix_float*>(b)->get_matrix());
  return default_dot (b);
}


/**
 * Elementwise in-place division: this[i][j] /= b[i][j].
 * @param b  matrix with the same dimensions as this one
 * @return -1 on dimension mismatch; otherwise the status of the underlying
 *         operation (0 on success).
 */
int SSLib::dense_matrix_float::div (matrix_float* b)
{
  if (!dimension_match(b))
    return -1;
  // Dynamic-type check of the pointees (typeid of the raw pointers compares
  // static pointer types and can never succeed).
  if (typeid(*b) == typeid(*this))
    return gsl_matrix_float_div_elements(M, static_cast<dense_matrix_float*>(b)->get_matrix());
  return default_div (b);
}

int SSLib::dense_matrix_float::fload(const char* s, size_t m, size_t n)
{
   M = gsl_matrix_float_alloc(m, n);
   return SSUtil::read_gsl_matrix_float(M, s);
}

int SSLib::dense_matrix_float::load(const char* s)
{
  gsl_matrix_float* t = M;
  M = SSUtil::read_gsl_matrix_float(const_cast<char*>(s));
  if (M) {
    if (t != 0)
      gsl_matrix_float_free (t);
    matrix_setsize(M->size1, M->size2);
    return 0;
  } else {
    M = t;
    return 1;
  }
}

int SSLib::dense_matrix_float::load(const char* s, bool asbin)
{
  gsl_matrix_float* t = M;
  if (asbin)
    M = SSUtil::fread_gsl_matrix_float(const_cast<char*>(s));
  else 
    M = SSUtil::read_gsl_matrix_float(const_cast<char*>(s));

  if (M) {
    if (t != 0)
      gsl_matrix_float_free(t);
    matrix_setsize(M->size1, M->size2);
    return 0;
  } else {
    M = t;
    return 1;
  }
}

/**
 * Typed save is not supported for dense matrices.
 * Nothing is written; the arguments are ignored.
 * @return always -2.
 */
int SSLib::dense_matrix_float::save(const char* s, bool asbin, int typ)
{
  (void)s;
  (void)asbin;
  (void)typ;
  const int unsupported = -2;
  return unsupported;
}



/**
 * Writes the matrix to the file named s via SSUtil::fwrite_gsl_matrix_float.
 * NOTE(review): asbin is currently ignored; the same writer is used for
 * both values. Confirm whether a text writer should be used when !asbin.
 * @return whatever SSUtil::fwrite_gsl_matrix_float returns.
 */
int SSLib::dense_matrix_float::save(const char* s, bool asbin)
{
  (void)asbin;
  char* path = const_cast<char*>(s);
  return SSUtil::fwrite_gsl_matrix_float(M, path);
}

/// Export in CCS format. Not implemented for dense matrices:
/// nothing is written and -1 is always returned.
int SSLib::dense_matrix_float::save_as_ccs(const char* /*fname*/)
{
  const int not_implemented = -1;
  return not_implemented;
}

/// Export in CRS format. Not implemented for dense matrices:
/// nothing is written and -1 is always returned.
int SSLib::dense_matrix_float::save_as_crs(const char* /*fname*/)
{
  const int not_implemented = -1;
  return not_implemented;
}

/// Export in coordinate format. Not implemented for dense matrices:
/// nothing is written and -1 is always returned.
int SSLib::dense_matrix_float::save_as_coord(const char* /*fname*/)
{
  const int not_implemented = -1;
  return not_implemented;
}

/// Export in column-coordinate format. Not implemented for dense
/// matrices: nothing is written and -1 is always returned.
int SSLib::dense_matrix_float::save_as_col_coord(const char* /*fname*/)
{
  const int not_implemented = -1;
  return not_implemented;
}

/// Export in a MATLAB-oriented format. Not implemented for dense
/// matrices: nothing is written and -1 is always returned.
int SSLib::dense_matrix_float::save_as_matlab(const char* /*fname*/)
{
  const int not_implemented = -1;
  return not_implemented;
}

/// Export in HB format. Not implemented for dense matrices:
/// nothing is written and -1 is always returned.
int SSLib::dense_matrix_float::save_as_hb(const char* /*fname*/)
{
  const int not_implemented = -1;
  return not_implemented;
}

/// Export in dense format. Not implemented yet:
/// nothing is written and -1 is always returned.
int SSLib::dense_matrix_float::save_as_dense(const char* /*fname*/)
{
  const int not_implemented = -1;
  return not_implemented;
}

/* Elementwise functions on this matrix */


/**
 * Entrywise (vector) l_p norm of this matrix: all nrows()*ncols() stored
 * entries are treated as one vector.
 * @param p  the norm order. p == 1 and p == 2 are special-cased; any other
 *           value is forwarded to SSUtil::lp_norm.
 * @return the computed norm.
 * @todo Faster exponents for general p.
 */
double SSLib::dense_matrix_float::norm (double p)
{
  // When every entry is non-negative, the p == 1 and general-p cases use the
  // helpers without absolute values (the trailing `true` passed to lp_norm
  // presumably requests |x| -- confirm in util.h).
  const bool nonneg = is_non_negative();
  const size_t count = nrows() * ncols();

  if (p == 2)
    return SSUtil::euclidean_norm(M->data, count);

  if (p == 1) {
    if (nonneg)
      return SSUtil::sum(M->data, count);
    return SSUtil::abs_sum(M->data, count);
  }

  if (nonneg)
    return SSUtil::lp_norm(M->data, count, p);
  return SSUtil::lp_norm(M->data, count, p, true);
}

/**
 * Computes a named norm of this matrix.
 * @param p  one of:
 *           - "fro"  -> norm(2) over all entries
 *           - "l1"   -> norm(1) over all entries
 *           - "inf"  -> SSUtil::linf_norm over all entries
 *           - "matrix-l1", "matrix-inf" -> not yet implemented
 * @return the requested norm, or -1 (after reporting through matrix_error)
 *         if the name is unsupported or not implemented.
 */
double SSLib::dense_matrix_float::norm(char* p) 
{
  const std::string which(p);

  if (which == "fro")
    return norm(2);
  if (which == "l1")
    return norm(1);
  if (which == "inf")
    return SSUtil::linf_norm(M->data, nrows()*ncols());

  if (which == "matrix-l1")
    matrix_error("Not yet implemented!");
  else if (which == "matrix-inf")
    matrix_error ("Not yet implemented");
  else
    matrix_error ("Invalid or unsupported norm requested");
  return -1;
}

/**
 * Applies fn to every stored entry in place: x <- fn(x) for each of the
 * nrows()*ncols() entries.
 * @param fn  the function to apply
 * @return always 0.0
 */
float SSLib::dense_matrix_float::apply (float (* fn)(float))
{
  const size_t total = nrows() * ncols();
  if (total != 0) {
    float* cell = M->data;
    for (float* const end = cell + total; cell != end; ++cell)
      *cell = fn(*cell);
  }
  return 0.0;
}

/* Operations on columns of this matrix */


/**
 * @param c      the column whose norm to computer
 * @param p      compute l_p norm
 * @return
 *    -   -1 if invalid 'p' norm requested
 *    -   -2 if invalid column specified
 *    -   double value as the requested norm
 */
double SSLib::dense_matrix_float::col_norm (size_t c, double p)
{
  double n = 0.0;
  return n;
}

double SSLib::dense_matrix_float::col_dot  (size_t c1, size_t c2)
{
  float r;
  gsl_vector_float_view col1 = gsl_matrix_float_column(M, c1);
  gsl_vector_float_view col2 = gsl_matrix_float_column(M, c2);
  gsl_blas_sdot(&col1.vector, &col2.vector, &r);
  return r;
}

/**
 * Inner product of column c of this matrix with vector v.
 * @param c  column index
 * @param v  vector expected to have as many entries as the matrix has rows
 * @return the dot product; 0 if gsl_blas_sdot reports an error (e.g. a
 *         length mismatch) and the GSL error handler returns.
 */
double SSLib::dense_matrix_float::col_dot  (size_t c, gsl_vector_float* v)
{
  // gsl_blas_sdot leaves *result untouched on error, so initialise it to
  // avoid returning an indeterminate value.
  float r = 0.0f;
  gsl_vector_float_view col = gsl_matrix_float_column(M, c);
  gsl_blas_sdot(&col.vector, v, &r);
  return r;
}

int    SSLib::dense_matrix_float::col_add  (size_t c1, size_t c2)
{
  return 0;
}

int    SSLib::dense_matrix_float::col_sub  (size_t c1, size_t c2)
{
  return 0;
}

int    SSLib::dense_matrix_float::col_scale(size_t c1, float s)
{
  return 0;
}

/**
 * @todo
 */
int SSLib::dense_matrix_float::col_sub (size_t c1, vector_float* c2)
{
  return 0;
}

/**
 * @todo
 */
int SSLib::dense_matrix_float::col_add (size_t c1, vector_float* c2)
{
  return 0;
}

/**
 * Scaled col c1, and stores result in a preallocated vector r
 * @return
 *    -    0 if computtaion was successful
 *    -    1 if dimension mismatch
 *    -    2 if invalid column requested
 */
int SSLib::dense_matrix_float::col_scale(size_t c1, float s, vector_float* r)
{
  return 0;
}

/* Operations on rows of this matrix */
double SSLib::dense_matrix_float::row_norm (size_t r, double p)
{
  gsl_vector_float_view row = gsl_matrix_float_row(M,r);
  if (p == 1) {
    return gsl_blas_sasum(&row.vector);
  }
  if (p == 2) {
    return gsl_blas_snrm2(&row.vector);
  }
  return 0.0;
}

double SSLib::dense_matrix_float::row_dot  (size_t r1, size_t r2)
{
  gsl_vector_float_view row1 = gsl_matrix_float_row(M, r1);
  gsl_vector_float_view row2 = gsl_matrix_float_row(M, r2);
  float r;
  gsl_blas_sdot(&row1.vector, &row2.vector, &r);
  return r;
}

/**
 * Inner product of row r of this matrix with vector v.
 * @param r  row index
 * @param v  vector expected to have as many entries as the matrix has columns
 * @return the dot product; 0 if gsl_blas_sdot reports an error (e.g. a
 *         length mismatch) and the GSL error handler returns.
 */
double SSLib::dense_matrix_float::row_dot  (size_t r, vector_float* v)
{
  gsl_vector_float_view row = gsl_matrix_float_row(M, r);
  // gsl_blas_sdot leaves *result untouched on error; initialise it so an
  // indeterminate value is never returned.
  float res = 0.0f;
  gsl_blas_sdot(&row.vector, v, &res);
  return res;
}

int    SSLib::dense_matrix_float::row_add  (size_t r1, size_t r2)
{
  gsl_vector_float_view row1 = gsl_matrix_float_row(M, r1);
  gsl_vector_float_view row2 = gsl_matrix_float_row(M, r2);
  gsl_vector_float_add(&row1.vector, &row2.vector);
  return 0;
}

int    SSLib::dense_matrix_float::row_sub  (size_t r1, size_t r2)
{
  gsl_vector_float_view row1 = gsl_matrix_float_row(M, r1);
  gsl_vector_float_view row2 = gsl_matrix_float_row(M, r2);
  gsl_vector_float_sub(&row1.vector, &row2.vector);
  return 0;
}

int    SSLib::dense_matrix_float::row_scale(size_t c1, float s)
{
  gsl_vector_float_view row = gsl_matrix_float_row(M,c1);
  gsl_vector_float_scale(&row.vector, s);
  return 0;
}
 
int SSLib::dense_matrix_float::row_add (size_t r1, vector_float* v)
{
  gsl_vector_float_view row = gsl_matrix_float_row(M, r1);
  gsl_vector_float_add(&row.vector, v);
  return 0;
}

int SSLib::dense_matrix_float::row_sub (size_t r1, vector_float* v)
{
  gsl_vector_float_view row = gsl_matrix_float_row(M, r1);
  gsl_vector_float_sub(&row.vector, v);
  return 0;
}

int SSLib::dense_matrix_float::row_scale(size_t c1, float s, vector_float* r)
{
  gsl_vector_float_view row = gsl_matrix_float_row(M, c1);
  gsl_vector_float_set_zero(r);
  gsl_blas_saxpy(s, &row.vector, r);
  return 0;
}

int SSLib::dense_matrix_float::row_saxpy(size_t i, float a, gsl_vector_float* r)
{
  gsl_vector_float_view row = gsl_matrix_float_row(M,i);
  gsl_blas_saxpy(a, &row.vector, r);
  return 0;
}

/** @name MatrixFunctions
 *
 * Functions or operations involving entire matrix
 */
/*@{*/

/**
 * Function does result <- Ax
 * @param x The vector to be multiplied by this,
 * @param result the *UNALLOCATED* pointer result, which  will be allocated
 * to hold the result if the computation can be carried out.
 * @return
 *    -    0 if computation was ok
 *    -    1 if computation could not be carried otu (dimension mismatch)
 */
int SSLib::dense_matrix_float::dot (bool tranA, gsl_vector_float* x, 
                                    gsl_vector_float* result)
{
  if (ncols() != x->size)
    return 1;
  if (tranA)
    gsl_blas_sgemv(CblasTrans, 1.0, M, x, 0.0, result);
  else
    gsl_blas_sgemv(CblasNoTrans, 1.0, M, x, 0.0, result);
  return 0;
}


void SSLib::dense_matrix_float::transpose()
{
  return;
}

/// Not implemented: always returns a null pointer. Callers must check the
/// result before use.
SSLib::matrix_float* SSLib::dense_matrix_float::transpose_copy()
{
  SSLib::matrix_float* copy = 0;
  return copy;
}

/// Multiplies each of the nrows()*ncols() stored entries by s, in place.
/// Always returns 0.
int SSLib::dense_matrix_float::scale(float s)
{
  const size_t total = nrows() * ncols();
  // Entries are independent, so walking from the back gives the same result.
  for (size_t k = total; k-- > 0; )
    M->data[k] *= s;
  return 0;
}

void SSLib::dense_matrix_float::print()
{
  for (size_t i = 0; i < nrows(); i++) {
    for (size_t j = 0; j < ncols(); j++)
        std::cout << get(i, j) << " ";
    std::cout << std::endl;
  }
	
  return;
}

/// Adds the constant s to each of the nrows()*ncols() stored entries, in
/// place. Always returns 0.
int SSLib::dense_matrix_float::add_const(float s)
{
  const size_t total = nrows() * ncols();
  size_t k = 0;
  while (k < total) {
    M->data[k] += s;
    ++k;
  }
  return 0;
}

