function M = hpdfactory(n)
% Returns a manifold structure to optimize over HPD matrices.
%
% function M = hpdfactory(n)
%
% The HPD manifold is the Riemannian manifold of Hermitian n-by-n
%    positive definite matrices, endowed with the metric
%    <U, V>_X = real(trace(X^-1 * U * X^-1 * V)).
% If n is omitted or empty, n = 1 is used.
%
% Original authors: Suvrit Sra and Reshad Hosseini, Aug, 02, 2013
% Change log:
%  Reshad Hosseini, Aug,30,2013: Implementing retr, transp, ehess2rhess
%  Reshad Hosseini, Jan,14,2014: Improving speed of transp using sqrtm_fast
%  Reshad Hosseini, Jun,26,2014: Improving mbfgs speed by adding transpF


% flag = true:  retraction is X*expm(X\(t*U)) and the vector transport is
%               the Riemannian one of the natural metric.
% flag = false: retraction is X + t*U and the vector transport is the
%               identity.
flag = true;

if ~exist('n', 'var') || isempty(n)
    n = 1;
end

% Hermitian part of a square matrix.
herm = @(X) (X+X')/2;

M.name = @() sprintf('HPD manifold (%d, %d)', n, n);

M.dim = @() n^2;

M.inner = @innerproduct;
    % Riemannian metric tr(X^-1 U X^-1 V). Summing the elementwise product
    % with a non-conjugate transpose gives the trace of the product without
    % forming it; real() discards round-off imaginary parts.
    function ip = innerproduct(X, U, V)
        ip = real(sum(sum( (X\U).' .* (X\V) )));
    end

% The norm must be the one induced by the metric above. Note that
% ||X\U||_F equals sqrt(tr(X^-2 U^2)), which is NOT this norm unless X and
% U commute, so it is computed through innerproduct instead.
M.norm = @riemnorm;
    function nrm = riemnorm(X, U)
        nrm = sqrt(innerproduct(X, U, U));
    end

M.dist = @riem;
    % Riemannian distance: 2-norm of the logarithms of the generalized
    % eigenvalues of the pair (X, Y).
    function d = riem(X, Y)
        d = eig(X, Y);
        d = norm(log(d));
    end

M.typicaldist = @() n;

M.proj = @projection;
    % Tangent vectors at X are Hermitian matrices: keep the Hermitian part.
    function Up = projection(X, U) %#ok<INUSL>
        Up = herm(U);
    end

M.tangent = M.proj;

% With the inner product tr(E X^-1 F X^-1), converting a Euclidean
% gradient into the Riemannian gradient gives X * herm(egrad) * X.
M.egrad2rgrad = @egrad2rgrad;
    function Up = egrad2rgrad(X, U)
        Up = X * herm(U) * X;
    end

M.ehess2rhess = @ehess2rhess;
    % Riemannian Hessian along eta. Differentiating X*herm(egrad)*X gives
    %   X*herm(ehess)*X + 2*herm(eta*herm(egrad)*X);
    % the correction term for the non-constant metric subtracts
    % herm(eta*herm(egrad)*X), leaving a single copy of it.
    function Hess = ehess2rhess(X, egrad, ehess, eta)
        Hess = X*herm(ehess)*X + herm(eta*herm(egrad)*X);
    end

M.retr = @retraction;
    % Retraction from X along t*U (t defaults to 1).
    function Y = retraction(X, U, t)
        if nargin < 3
            t = 1.0;
        end
        if flag
            % Exponential map; herm() removes round-off asymmetry.
            Y = herm(X * expm(X\(t*U)));
        else
            Y = X + t*U;
        end
    end

M.exp = @exponential;
    % Exponential map; shares its implementation with M.retr.
    function Y = exponential(X, U, t)
        if nargin < 3
            t = 1;
        end
        Y = retraction(X, U, t);
    end

M.log = @logarithm;
    % Logarithmic map: X * logm(X\Y), made Hermitian against round-off.
    function U = logarithm(X, Y)
        U = herm(X*logm(X\Y));
    end

M.hash = @(X) ['z' hashmd5(X(:))];

M.rand = @random;
    % Random point: Z*Z' with Z complex Gaussian.
    function X = random()
        X = randn(n)+1i*randn(n);
        X = (X*X');
    end

M.randvec = @randomvec;
    % Random Hermitian tangent vector at X with unit Riemannian norm.
    function U = randomvec(X)
        U = herm(randn(n)+1i*randn(n));
        U = U / riemnorm(X, U);
    end

M.lincomb = @lincomb;

M.zerovec = @(x) zeros(n);

M.transp = @transpvec;
    % Vector transport of E from X to Y (first output of transpvecf).
    function F = transpvec(X, Y, E)
        F = transpvecf(X, Y, E);
    end

% Applies the vector transport and optionally returns the transport
% factor (and its inverse) so transpF / itranspF can reuse them.
M.transpf = @transpvecf;
    function [F, expconstruct, iexpconstruct] = transpvecf(X, Y, E)
        if flag
            % F = S*E*S' with S = sqrtm_fast(Y/X); sqrtm_fast is an external
            % helper, presumably a matrix square root of Y*X^-1.
            expconstruct = sqrtm_fast(Y/X);
            F = expconstruct*E*expconstruct';
            if nargout > 2
                iexpconstruct = inv(expconstruct);
            end
        else
            % Identity transport. Original authors' note: it works for
            % LBFGS and has a convergence proof.
            F = E;
            if nargout > 1
                expconstruct = eye(size(X,1));
            end
            if nargout > 2
                iexpconstruct = eye(size(X,1));
            end
        end
    end

% Inverse of the vector transport: transport from Y back to X.
M.itransp = @itranspvec;
    function F = itranspvec(X, Y, E)
        F = transpvec(Y, X, E);
    end

% Fast vector transport reusing the factor returned by M.transpf.
M.transpF = @transpvecfast;
    function F = transpvecfast(expconstruct, E)
        if flag
            F = expconstruct*E*expconstruct';
        else
            F = E;
        end
    end

% Fast inverse vector transport reusing the inverse factor from M.transpf.
M.itranspF = @itranspvecfast;
    function F = itranspvecfast(iexpconstruct, E)
        if flag
            F = iexpconstruct*E*iexpconstruct';
        else
            F = E;
        end
    end

M.vec = @(x, u_mat) u_mat(:);

M.mat = @(x, u_vec) reshape(u_vec, [n, n]);

M.vecmatareisometries = @() false;

end

% Linear combination of tangent vectors
function d = lincomb(x, a1, d1, a2, d2) %#ok<INUSL>
% d = lincomb(x, a1, d1)         returns a1*d1.
% d = lincomb(x, a1, d1, a2, d2) returns a1*d1 + a2*d2.
% The base point x is unused. Any other number of inputs is an error.

if nargin == 3
    d = a1*d1;
elseif nargin == 5
    d = a1*d1 + a2*d2;
else
    error('Bad use of hpd.lincomb.');
end

end
