/* Function for calculating the square root of an upper triangular matrix.
 * Written by Reshad Hosseini */
#include "mex.h"
#include <math.h> //needed for sqrt
#include <string.h> //needed for memcpy

/*
 * Fill Ts with the upper-triangular square root of T using the column-wise
 * recurrence
 *
 *     Ts(j,j) = sqrt(T(j,j))
 *     Ts(i,j) = (T(i,j) - sum_{k=i+1}^{j-1} Ts(i,k)*Ts(k,j))
 *               / (Ts(i,i) + Ts(j,j))          for i < j
 *
 * T  : n-by-n input, column-major (element (i,j) lives at T[i + n*j]).
 * Ts : n-by-n output, column-major. Only the diagonal and the strict upper
 *      triangle are written; whatever the caller stored below the diagonal
 *      is left untouched.
 * n  : matrix order.
 *
 * No guarding is done on the arithmetic: a negative diagonal entry yields
 * NaN from sqrt, and Ts(i,i) + Ts(j,j) == 0 yields a division by zero
 * (inf/NaN), exactly as the formulas above dictate.
 */
static void sqrtm_triu_kernel(const double *T, double *Ts, size_t n)
{
    size_t i, j, k;

    /* Diagonal first: every off-diagonal entry depends on it. */
    for (j = 0; j < n; j++) {
        Ts[j + n*j] = sqrt(T[j + n*j]);
    }

    /* Columns left to right; inside a column walk upwards from the
     * diagonal, so Ts(i,k) (earlier columns) and Ts(k,j) (rows below i in
     * this column) are already final when they are read. */
    for (j = 1; j < n; j++) {
        /* "i-- > 0" visits i = j-1, ..., 0 without ever needing a negative
         * value, which an unsigned index cannot represent. */
        for (i = j; i-- > 0; ) {
            double s = 0;
            for (k = i + 1; k < j; k++) {
                s += Ts[i + n*k] * Ts[k + n*j];
            }
            Ts[i + n*j] = (T[i + n*j] - s) / (Ts[i + n*i] + Ts[j + n*j]);
        }
    }
}

/*
 * MEX gateway:  Ts = sqrtm_triu(T)
 *
 * nlhs, plhs : number of / array of output arguments (at most one).
 * nrhs, prhs : number of / array of input arguments (exactly one).
 *
 * prhs[0] must be a full (non-sparse), square, 2-D matrix. Its data is read
 * through mxGetPr, so it is expected to hold real doubles; the element type
 * itself is not verified here. plhs[0] receives a freshly created real
 * double matrix of the same size: diagonal and strict upper triangle hold
 * the computed square root, entries below the diagonal are copied from T
 * unchanged.
 *
 * Raises a MATLAB error (mexErrMsgTxt) on a wrong argument count, a sparse
 * input, an input that is not 2-D, or a non-square input.
 */
void mexFunction(int nlhs, mxArray *plhs[],
        int nrhs, const mxArray *prhs[]) {
    size_t m, n;   /* size_t matches what mxGetM/mxGetN return; avoids int truncation */
    double *T, *Ts;

    if(nrhs != 1 || nlhs > 1)
        mexErrMsgTxt("Usage: Ts = sqrtm_triu(T)");

    /* Validate the one and only argument before touching its data.
     * (There is no prhs[1]: nrhs == 1 is guaranteed at this point.) */
    if(mxIsSparse(prhs[0])) {
        mexErrMsgTxt("Can not handle sparse matrices yet.");
    }
    if(mxGetNumberOfDimensions(prhs[0]) != 2) {
        mexErrMsgTxt("Arguments must be matrices.");
    }

    m = mxGetM(prhs[0]);   /* rows */
    n = mxGetN(prhs[0]);   /* columns */
    if(m != n) mexErrMsgTxt("matrix must be square");

    T = mxGetPr(prhs[0]);

    /* Allocate the output and get a pointer to its column-major data. */
    plhs[0] = mxCreateDoubleMatrix(m, n, mxREAL);
    Ts = mxGetPr(plhs[0]);

    /* Empty input: the empty output is already the answer, and the data
     * pointers may be NULL, which must not be handed to memcpy. */
    if(n == 0) return;

    /* Start from a full copy of T so entries below the diagonal carry over;
     * the size is computed in size_t, so it cannot overflow a signed int. */
    memcpy(Ts, T, n*n*sizeof(double));

    sqrtm_triu_kernel(T, Ts, n);
}