// File: matrix_float.cc
// Author: Suvrit Sra
// Time-stamp: <22 April 2006 11:33:35 AM CDT --  suvrit>
// Matrix base class implementation

#include "matrix_float.h"
#include <cstdio>

/// Element-wise in-place addition: this(i,j) <- this(i,j) + b(i,j).
/// The sum of each pair is formed in double precision and stored back with set().
/// @param b  matrix that must have the same shape as this one; it is only
///           read through get().
/// @return 0 on success, -1 when the shapes differ (dimension_match reports
///         the mismatch through matrix_error).
int SSLib::matrix_float::default_add (matrix_float* b)
{
  if (dimension_match(b)) {
    const size_t rows = nrows();
    const size_t cols = ncols();
    for (size_t r = 0; r < rows; ++r) {
      for (size_t c = 0; c < cols; ++c) {
        const double sum = get(r, c) + b->get(r, c);
        set(r, c, sum);
      }
    }
    return 0;
  }
  return -1;
}


/// Element-wise in-place subtraction: this(i,j) <- this(i,j) - b(i,j).
/// Each difference is formed in double precision and stored back with set().
/// @param b  matrix that must have the same shape as this one; it is only
///           read through get().
/// @return 0 on success, -1 when the shapes differ (dimension_match reports
///         the mismatch through matrix_error).
int SSLib::matrix_float::default_sub (matrix_float* b)
{
  if (dimension_match(b)) {
    const size_t rows = nrows();
    const size_t cols = ncols();
    for (size_t r = 0; r < rows; ++r) {
      for (size_t c = 0; c < cols; ++c) {
        const double diff = get(r, c) - b->get(r, c);
        set(r, c, diff);
      }
    }
    return 0;
  }
  return -1;
}

/// Intended to compute the matrix product op(this) * op(b), where op(X) is X
/// or its transpose according to tranA / tranB.
/// TODO: not implemented yet. Once the dimension check passes, this returns 0
/// without computing the product and without modifying either matrix.
/// @param b      right-hand operand.
/// @param tranA  treat this matrix as transposed.
/// @param tranB  treat b as transposed.
/// @return 0 when the dimensions are compatible (no work is done), -1 when
///         can_be_multiplied rejects them.
int SSLib::matrix_float::default_mul (matrix_float* b, bool tranA, bool tranB)
{
  const bool compatible = can_be_multiplied(b, tranA, tranB);
  // TODO: compute op(this) * op(b) here.
  return compatible ? 0 : -1;
}

/// Element-wise in-place product: this(i,j) <- this(i,j) * b(i,j).
/// This is not a matrix product. Each entry is multiplied only by the entry of
/// b at the same position; the product is formed in double precision and
/// stored back with set().
/// @param b  matrix that must have the same shape as this one; it is only
///           read through get().
/// @return 0 on success, -1 when the shapes differ (dimension_match reports
///         the mismatch through matrix_error).
int SSLib::matrix_float::default_dot (matrix_float* b)
{
  if (dimension_match(b)) {
    const size_t rows = nrows();
    const size_t cols = ncols();
    for (size_t r = 0; r < rows; ++r) {
      for (size_t c = 0; c < cols; ++c) {
        const double prod = get(r, c) * b->get(r, c);
        set(r, c, prod);
      }
    }
    return 0;
  }
  return -1;
}

/// Element-wise in-place division: this(i,j) <- this(i,j) / b(i,j).
/// Each quotient is formed in double precision and stored back with set().
/// No check is made for zero entries in b.
/// @param b  matrix that must have the same shape as this one; it is only
///           read through get().
/// @return 0 on success, -1 when the shapes differ (dimension_match reports
///         the mismatch through matrix_error).
int SSLib::matrix_float::default_div (matrix_float* b)
{
  if (dimension_match(b)) {
    const size_t rows = nrows();
    const size_t cols = ncols();
    for (size_t r = 0; r < rows; ++r) {
      for (size_t c = 0; c < cols; ++c) {
        const double quot = get(r, c) / b->get(r, c);
        set(r, c, quot);
      }
    }
    return 0;
  }
  return -1;
}


/// Checks that b has exactly the same number of rows and columns as this.
/// On a mismatch, a message containing both shapes is passed to matrix_error.
/// @param b  matrix to compare against.
/// @return true if both row and column counts agree, false otherwise.
bool SSLib::matrix_float::dimension_match (matrix_float* b)
{
  if (nrows() != b->nrows() || ncols() != b->ncols()) {
    char t[256];
    // Cast explicitly so the arguments always match the %lu conversions,
    // whatever integer type nrows()/ncols() return. Passing a size_t to "%d"
    // is undefined behaviour and prints garbage on LP64 platforms.
    // snprintf bounds the write to the buffer.
    snprintf(t, sizeof(t),
             "Dimension of 'this' = {%lu, %lu} don't match with 'b' = {%lu, %lu}",
             static_cast<unsigned long>(nrows()),
             static_cast<unsigned long>(ncols()),
             static_cast<unsigned long>(b->nrows()),
             static_cast<unsigned long>(b->ncols()));
    std::string s = std::string(t);
    matrix_error(s);
    return false;
  }
  return true;
}

/// Checks whether op(this) * op(b) is defined, where op(X) is X or its
/// transpose according to tranA / tranB. The column count of op(this) must
/// equal the row count of op(b). On failure a fixed message is passed to
/// matrix_error.
/// @param b      right-hand operand.
/// @param tranA  treat this matrix as transposed.
/// @param tranB  treat b as transposed.
/// @return true when the inner dimensions agree, false otherwise.
bool SSLib::matrix_float::can_be_multiplied (matrix_float* b, bool tranA, bool tranB)
{
  // Columns of op(this) versus rows of op(b); transposition swaps which
  // accessor supplies each side.
  const bool inner_dims_agree =
      (tranA ? nrows() : ncols()) == (tranB ? b->ncols() : b->nrows());
  if (inner_dims_agree)
    return true;

  std::string s = "Incompatible dimensions for multiplication";
  matrix_error(s);
  return false;
}

