// -*- c++ -*-
// Copyright (C) 2003 Suvrit Sra (suvrit@cs.utexas.edu)
// Time-stamp: <14 January 2004 03:28:37 PM CST --  suvrit>

// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

#ifndef SPARSEMATRIX_H
#define SPARSEMATRIX_H

#include <iostream>
#include <fstream>
#include <string>

// Had to include this since I was adding extra functionality to this
// class. 
#include <gsl/gsl_vector.h>
#include <gsl/gsl_matrix.h>

#include "objective.h"

/**
 * CLASS: SparseMatrix
 * Author: Suvrit Sra
 * Implemented to deal with the CCS (compressed column storage) files that
 * we use a lot here (cs.utexas).
 * IDEA: I want to implement a function called 'apply' for the sparse matrix
 * that allows a sort of extension to the sparse-matrix functionality while
 * still not sacrificing efficiency.
 * E.g., this->apply(compute_obj, arg) would apply the function
 * compute_obj to each element of the sparse matrix. This is a fairly
 * restricted function, but it can be useful without exposing all the details
 * of the sparse matrix and without coming back and modifying the code of the
 * SparseMatrix class once it is stable. Still, this idea will pose some
 * implementation challenges, so I will defer it for now.
 *
 */
class SparseMatrix {
private:
  /* Dimensions of the matrix */
  size_t m_rows, m_cols;

  /* number of non-zeros */
  size_t m_nz;
     
  /* file name associated with this matrix */
  std::string fname;
  std::string txx;		
     
  /* The actual non-zeroes themselves (m_nz entries when allocated by
   * the sized constructor) */
  double* m_val;

  /* Colptrs for CCS structure (m_cols + 1 entries when allocated by
   * the sized constructor) */
  long* m_colptrs;
  
  /* Row indices for CCS structure (m_nz entries when allocated by the
   * sized constructor) */
  long* m_rowindx;
  
  /* Is the data in externally allocated arrays */
  bool m_is_external;

  /* Hack needed to do correct comp. of object. fn*/
  bool first_time;

  /* We compute norm only once to save time */
  double fnormA;

  /* Private function to carry out computation of fnorm of matrx */
  void compute_fnorm();
  bool fnorm_avail;		// is it avail or needs to be computed.

  /* Return a freshly allocated copy of the first n elements of src, or a
   * null pointer when src itself is null. */
  template <class T>
  static T* dup_array(const T* src, size_t n) {
    if (src == 0) return 0;
    T* dst = new T[n];
    for (size_t k = 0; k < n; ++k) dst[k] = src[k];
    return dst;
  }

  /* Free the three CCS arrays and null the pointers. The arrays are
   * allocated with new[], so they must be released with delete[]. */
  void release() {
    delete[] m_colptrs;
    delete[] m_rowindx;
    delete[] m_val;
    m_colptrs = 0;
    m_rowindx = 0;
    m_val     = 0;
  }

  /* Scan column j for row i. Returns the position of the entry in
   * m_val / m_rowindx, or -1 if (i, j) is not a stored entry. The index
   * is a long because m_colptrs holds longs. */
  long find_entry(int i, int j) const {
    for (long t = m_colptrs[j]; t < m_colptrs[j+1]; t++) {
      if (m_rowindx[t] == i) return t;
    }
    return -1;
  }

public:
  // Empty matrix: zero dimensions and no arrays. Every member is
  // initialized so that destroying or copying the object is well defined.
  SparseMatrix()
    : m_rows(0), m_cols(0), m_nz(0), fname(), txx(),
      m_val(0), m_colptrs(0), m_rowindx(0),
      m_is_external(false), first_time(false),
      fnormA(0.0), fnorm_avail(false)
  {}

  // Matrix with r rows, c columns and room for n non-zeros. The CCS
  // arrays are allocated here and zero-filled; the caller is expected to
  // fill them in (e.g. through getPointr()/getIndx()/getVal()).
  SparseMatrix(int r, int c, int n)
    : m_rows(r), m_cols(c), m_nz(n), fname(""), txx(""),
      m_val(0), m_colptrs(0), m_rowindx(0),
      m_is_external(false), first_time(false),
      fnormA(0.0), fnorm_avail(false)
  { 
    //assert (r > 0 and c > 0 and n >= 0);
    try {
      // Allocate the arrays
      m_colptrs = new long[c+1]();
      m_rowindx = new long[n]();
      m_val     = new double[n]();
    } catch (...) {
      // Do not leak the arrays that were already allocated.
      release();
      throw;
    }
  }
  
  //SparseMatrix (int* c, int* r, double* v, int ro, int co, int nz) :
  //  m_colptrs(c), m_rowindx(r), m_val(v),
  //  fname("factorization"), txx(""), m_is_external(true)
  //{ m_rows = ro; m_cols = co; m_nz = nz; }

  // Deep copy. The implicit member-wise copy would leave two objects
  // pointing at (and later freeing) the same arrays. The copy's arrays are
  // allocated here, so it is never flagged as external. Assumes the source
  // arrays hold m_cols + 1 column pointers and m_nz row indices / values.
  SparseMatrix(const SparseMatrix& other)
    : m_rows(other.m_rows), m_cols(other.m_cols), m_nz(other.m_nz),
      fname(other.fname), txx(other.txx),
      m_val(0), m_colptrs(0), m_rowindx(0),
      m_is_external(false), first_time(other.first_time),
      fnormA(other.fnormA), fnorm_avail(other.fnorm_avail)
  {
    try {
      m_colptrs = dup_array(other.m_colptrs, m_cols + 1);
      m_rowindx = dup_array(other.m_rowindx, m_nz);
      m_val     = dup_array(other.m_val, m_nz);
    } catch (...) {
      release();
      throw;
    }
  }

  // Deep-copy assignment. The copy is built first, so *this is left
  // untouched if an allocation fails; its arrays are then taken over.
  SparseMatrix& operator = (const SparseMatrix& other) {
    if (this == &other) return *this;

    SparseMatrix tmp(other);
    release();

    m_colptrs = tmp.m_colptrs;  tmp.m_colptrs = 0;
    m_rowindx = tmp.m_rowindx;  tmp.m_rowindx = 0;
    m_val     = tmp.m_val;      tmp.m_val     = 0;

    m_rows = tmp.m_rows;
    m_cols = tmp.m_cols;
    m_nz   = tmp.m_nz;
    fname.swap(tmp.fname);
    txx.swap(tmp.txx);
    m_is_external = tmp.m_is_external;
    first_time    = tmp.first_time;
    fnormA        = tmp.fnormA;
    fnorm_avail   = tmp.fnorm_avail;
    return *this;
  }

  // Frees the CCS arrays.
  // NOTE(review): m_is_external is not consulted; as before, the arrays
  // are always freed. Confirm against the code that sets the flag whether
  // externally supplied arrays should be skipped here.
  ~SparseMatrix() {
    //cout << "~SparseMatrix()" << endl;
    release();
  }

  // Defined elsewhere; presumably returns a newly allocated copy of this
  // matrix -- confirm ownership of the result against the implementation.
  SparseMatrix* clone();
  bool isExternal() const { return m_is_external;}
  int  read_ccs_file(char*);
  // Defined elsewhere; presumably writes this matrix in dense form into
  // the given gsl_matrix -- confirm against the implementation.
  void makefull(gsl_matrix*);

  // This operator multiplies This . x 
  gsl_vector* operator * (const gsl_vector* x);

  //  v'col dot prod.
  double dot(int col, gsl_vector* v);

  // do a saxpy
  void saxpy(double alpha, int col, gsl_vector* v);

  // This does a dot product of col i with col j
  double col_dotprod(int i, int j);

  // This computes the cosine between col i and col j.
  double col_cosine(int i, int j);

  // This computes the cosine of col i with input vector v
  double col_cosine2(int i, double* v);

  // Calculate the norm of column (i)
  double col_norm(int i);
  
  // return the frob norm of matrix; compute_fnorm() is called only while
  // fnorm_avail is false, otherwise the cached fnormA is returned.
  double fnorm() {
    if (!fnorm_avail) compute_fnorm();
    return fnormA;
  }

  double col_diff(int i, int j);
  double col_delta(int, double*);

  std::string getFileName() const { return fname;}
  std::string getTxx()      const { return txx;}

  // Fill input vector with column of A
  void getcol(int, double*);


  void output_matlab(std::ofstream&, float);

  // Getting bulky with all these functions....well well
  double compute_obj(gsl_matrix*, tObjtype);

  // This calculates the transpose of this. Not yet implemented.
  SparseMatrix transpose();
     
  // This multiplies A^T . x
  gsl_vector* tran_times (const gsl_vector* x);

  // This calculates the matrix product this . B
  SparseMatrix& operator * (const SparseMatrix& B);

  // Write out matrix to output stream
  void print(std::ostream&);
  void printmat(std::ostream&);

  void dotdiv(SparseMatrix* l, SparseMatrix* r);
  void dotdiv(gsl_matrix*, SparseMatrix* r);

  // this * gsl_matrix
  void lmult(gsl_matrix*, gsl_matrix*,
		    bool trans1=false, bool trans2=false); 
  // gsl_matrix * this
  void rmult(gsl_matrix*, gsl_matrix*,
		    bool trans1=false, bool trans2=false);

  // Overwrite the stored entry at (i, j) with val. Only entries already
  // present in the sparsity pattern can be written; otherwise a warning
  // is printed to std::cerr and the matrix is left unchanged.
  inline void set(int i, int j, double val) {
    const long t = find_entry(i, j);
    if (t >= 0) {
      m_val[t] = val;
      return;
    }
    // If we are here a 0 is being destroyed and that will make life
    // bad. as of now just warn
    std::cerr << "A zero being destroyed at (" << i << ", " << j << ")" << std::endl;
  }
  

  void debuginfo();

  // Value at (i, j); 0 when the entry is not stored. No bounds checking.
  inline double dataAt(int i, int j) const {
    //assert(i >= 0 && i < numRows() && j >= 0 && j < numCols());
    const long t = find_entry(i, j);
    return (t >= 0) ? m_val[t] : 0;
  }

  // Same lookup as dataAt(i, j).
  inline double operator () (int i, int j) const  {
    //assert(i >= 0 && i < numRows() && j >= 0 && j < numCols());
    const long t = find_entry(i, j);
    return (t >= 0) ? m_val[t] : 0;
  }

  // This exposure was needed for libmyssvd
  long* getPointr() { return m_colptrs;}
  long* getIndx()   { return m_rowindx;}
  double* getVal()  { return m_val;}

  bool getObjFlag() const { return first_time;}
  void setObjFlag(bool f) { first_time = f;}

  // Dimensions and non-zero count, narrowed from size_t to int.
  int numRows() const { return m_rows;}
  int numCols() const { return m_cols;}
  int numNz  () const { return m_nz;  }
  int read_ccs_file(char* fname, char* txx);
};

#endif
