https://github.com/OHBA-analysis/HMM-MAR
Raw File
Tip revision: 1ef15b26afaa7c4766f45409864c12748c634949 authored by Diego Vidaurre on 21 June 2017, 13:49:58 UTC
Create LICENSE
Tip revision: 1ef15b2
mlhmmmar.m
function hmm = mlhmmmar (X,T,hmm0,Gamma,completelags)
% Given the state time courses estimation, does a last estimation of each MAR
% using (local) maximum likelihood, i.e. Gamma-weighted least squares per state.
%
% INPUT
% X             observations; either a numeric matrix (time by channels) or a
%               cell of per-file inputs that are handed to loadfile
% T             length of series; numeric vector, or a cell with one entry per
%               element of X when X is a cell
% hmm0          HMM-MAR structure
% Gamma         p(state given X) - has to be fully defined; one column per state.
%               If omitted, a single all-ones column (one state) is used.
% completelags  if 1, the lags are made linear with timelag=1 (i.e. a complete set);
%               defaults to 0
%
% OUTPUT
% hmm           HMM-MAR structure with the coefficients (state(k).W.Mu_W) and the
%               covariance parameters (Omega.Gam_shape / Omega.Gam_rate, either
%               per state or shared, depending on hmm.train.covtype) updated to
%               follow a maximum-likelihood estimation
%
% Author: Diego Vidaurre, OHBA, University of Oxford

hmm = hmm0;

if hmm.train.uniqueAR
    error('mlhmmmar not yet implemented for uniqueAR')
end

% T2 is always a numeric column/row of segment lengths, whether T is a cell or not
if iscell(T)
    T2 = cell2mat(T); T2 = T2(:);
else
    T2 = T;
end
% Default Gamma: one state that owns every modelled sample
% (the original referenced an undefined variable sumT here)
if nargin<4, Gamma = ones(sum(T2)-length(T2)*hmm.train.order,1); end
if nargin<5, completelags = 0; end

K = size(Gamma,2);
hmm.K = K; N = length(T);

if completelags
    maxorder0 = hmm.train.maxorder;
    hmm.train.orderoffset=0; hmm.train.timelag=1; hmm.train.exptimelag=0;
    hmm.train.maxorder = hmm.train.order;
    maxorderd = hmm.train.maxorder - maxorder0;
    if maxorderd>0 % trim Gamma if maxorder has changed
        % Each segment loses maxorderd extra samples at its start. T2 is used
        % (instead of T) so that this also works when T is a cell.
        Gamma0 = Gamma;
        Gamma = zeros(sum(T2)-hmm.train.maxorder*length(T2),K);
        for j = 1:length(T2)
            t00 = sum(T2(1:j-1)) - (j-1)*maxorder0 + 1;
            t10 = sum(T2(1:j)) - j*maxorder0;
            t0 = sum(T2(1:j-1)) - (j-1)*hmm.train.maxorder + 1;
            t1 = sum(T2(1:j)) - j*hmm.train.maxorder;
            Gamma(t0:t1,:) = Gamma0(t00+maxorderd:t10,:);
        end
    end
end

[hmm.train.orders,order] = formorders(hmm.train.order,hmm.train.orderoffset,hmm.train.timelag,hmm.train.exptimelag);

% States without their own training options inherit the global ones
for k=1:K
    if ~isfield(hmm.state(k),'train') || isempty(hmm.state(k).train)
        hmm.state(k).train = hmm.train;
    end
end

if isfield(hmm.train,'B'), hmm.train.B = []; end
if isfield(hmm.train,'V'), hmm.train.V = []; end

if iscell(X)

    % --- Pass 1: accumulate the Gamma-weighted normal equations over files ---
    c = 0; % number of Gamma rows consumed by the previous files
    for i = 1:N
        [~,XX,Y] = loadfile(X{i},T{i},hmm.train);
        if i==1, ndim = size(Y,2); end
        ind = (1:sum(T{i})-length(T{i})*order) + c;
        c = c + length(ind);
        if i==1
            XX2 = zeros(size(XX,2),size(XX,2),K);
            XXY = zeros(size(XX,2),ndim,K);
        end
        for k=1:K
            XX2(:,:,k) = XX2(:,:,k) + (XX' .* repmat(Gamma(ind,k)',size(XX,2),1)) * XX;
            XXY(:,:,k) = XXY(:,:,k) + (XX' .* repmat(Gamma(ind,k)',size(XX,2),1)) * Y;
        end
    end
    for k=1:K
        % small ridge term keeps the system solvable when XX2 is (near) singular
        hmm.state(k).W.Mu_W = (XX2(:,:,k) + 1e-6 * eye(size(XX2,2))) \ XXY(:,:,k);
        % the posterior covariance of W is dropped: only the point estimate is kept
        if isfield(hmm.state(k).W,'S_W'), hmm.state(k).W = rmfield(hmm.state(k).W,'S_W'); end
        if isfield(hmm.state(k).W,'iS_W'), hmm.state(k).W = rmfield(hmm.state(k).W,'iS_W'); end
    end

    % --- Pass 2: accumulate the Gamma-weighted residual statistics ---
    c = 0;
    for i = 1:N
        [~,XX,Y] = loadfile(X{i},T{i},hmm.train);
        ind = (1:sum(T{i})-length(T{i})*order) + c;
        c = c + length(ind);
        if i==1 % reset the accumulators before the first file
            if strcmp(hmm.train.covtype,'uniquediag')
                hmm.Omega.Gam_rate = zeros(1,ndim);
                hmm.Omega.Gam_shape = 0;
            elseif strcmp(hmm.train.covtype,'uniquefull')
                hmm.Omega.Gam_rate = zeros(ndim);
                hmm.Omega.Gam_shape = 0;
            elseif strcmp(hmm.train.covtype,'diag')
                for k=1:K
                    hmm.state(k).Omega.Gam_rate = zeros(1,ndim);
                    hmm.state(k).Omega.Gam_shape = 0;
                end
            else % full
                for k=1:K
                    hmm.state(k).Omega.Gam_rate = zeros(ndim);
                    hmm.state(k).Omega.Gam_shape = 0;
                end
            end
        end
        for k=1:K
            e = Y - XX * hmm.state(k).W.Mu_W;
            if strcmp(hmm.train.covtype,'diag')
                hmm.state(k).Omega.Gam_shape = hmm.state(k).Omega.Gam_shape + 0.5 * sum(Gamma(ind,k));
                hmm.state(k).Omega.Gam_rate = hmm.state(k).Omega.Gam_rate + ...
                    0.5 *  sum( repmat(Gamma(ind,k),1,ndim) .* e.^2 );
            elseif strcmp(hmm.train.covtype,'full')
                hmm.state(k).Omega.Gam_shape = hmm.state(k).Omega.Gam_shape + sum(Gamma(ind,k));
                hmm.state(k).Omega.Gam_rate = hmm.state(k).Omega.Gam_rate + ...
                    (e' .* repmat(Gamma(ind,k)',ndim,1)) * e;
            elseif strcmp(hmm.train.covtype,'uniquediag')
                hmm.Omega.Gam_shape = hmm.Omega.Gam_shape + 0.5 * sum(Gamma(ind,k));
                hmm.Omega.Gam_rate = hmm.Omega.Gam_rate + 0.5 *  sum( repmat(Gamma(ind,k),1,ndim) .* e.^2 );
            else % uniquefull
                hmm.Omega.Gam_shape = hmm.Omega.Gam_shape + sum(Gamma(ind,k));
                hmm.Omega.Gam_rate = hmm.Omega.Gam_rate + (e' .* repmat(Gamma(ind,k)',ndim,1)) * e;
            end
        end

    end

else

    % NOTE: setxx and setstateoptions are scripts that run in this workspace;
    % the variable names used below must therefore not be changed.
    ndim = size(X,2);
    S = hmm.train.S==1; regressed = sum(S,1)>0;
    Sind = formindexes(hmm.train.orders,hmm.train.S); hmm.train.Sind = Sind;
    if ~hmm.train.zeromean, Sind = [true(1,size(X,2)); Sind]; end
    Y = getresiduals(X,T,Sind,hmm.train.maxorder,hmm.train.order,hmm.train.orderoffset,...
        hmm.train.timelag,hmm.train.exptimelag,hmm.train.zeromean);
    pred = zeros(size(Y));
    setxx; % build XX

    for k=1:K
        setstateoptions;
        if hmm.train.uniqueAR
            % NOTE(review): unreachable, because uniqueAR raises an error at the
            % top of this function. omega and XXGXX are not defined in this
            % file - presumably they would come from the helper scripts; confirm
            % before enabling this branch.
            XY = zeros(size(XX{kk},1)*ndim,1);
            XGX = zeros(size(XX{kk},2)/ndim,size(XX{kk},2)/ndim);
            for n=1:ndim
                ind = n:ndim:size(XX{kk},2);
                iomegan = omega.Gam_shape / omega.Gam_rate(n);
                XGX = XGX + iomegan * XXGXX{k}(ind,ind);
                XY = XY + (iomegan * XX{kk}(:,ind)' .* repmat(Gamma(:,k)',sum(ind),1)) * Y(:,n);
            end
            hmm.state(k).W.Mu_W = XGX \ XY;
            predk = XX{kk} * repmat(hmm.state(k).W.Mu_W,1,ndim);
        elseif all(S(:)==1)
            % weighted least squares: both regressors and responses are scaled
            % by sqrt(Gamma) so that each sample is weighted by Gamma
            hmm.state(k).W.Mu_W = pinv(XX{kk} .* repmat(sqrt(Gamma(:,k)),1,size(XX{kk},2))) ...
                * ( repmat(sqrt(Gamma(:,k)),1,size(Y,2)) .* Y);
            predk = XX{kk} * hmm.state(k).W.Mu_W;
        else
            % One row per regressor (column of XX), so that XX * Mu_W is valid;
            % the original allocated one row per time point.
            hmm.state(k).W.Mu_W = zeros(size(XX{kk},2),ndim);
            for n=1:ndim
                if ~regressed(n), continue; end
                % Same weighted least squares as above, restricted to the
                % regressors selected for channel n. The response must be
                % scaled by sqrt(Gamma) too (it was left unweighted before).
                hmm.state(k).W.Mu_W(Sind(:,n),n) = pinv(XX{kk}(:,Sind(:,n)) .* ...
                    repmat(sqrt(Gamma(:,k)),1,sum(Sind(:,n)))) * (sqrt(Gamma(:,k)) .* Y(:,n));
            end
            predk = XX{kk} * hmm.state(k).W.Mu_W;
        end

        % Gamma-weighted mixture of the state predictions (used for the shared covariance)
        pred = pred + repmat(Gamma(:,k),1,ndim) .* predk;
        e = Y(:,regressed) - predk(:,regressed);
        if strcmp(hmm.train.covtype,'diag')
            hmm.state(k).Omega.Gam_shape = 0.5 * sum(Gamma(:,k));
            hmm.state(k).Omega.Gam_rate = zeros(1,ndim);
            hmm.state(k).Omega.Gam_rate(regressed) = 0.5 * ...
                sum( repmat(Gamma(:,k),1,sum(regressed)) .* e.^2 );
        elseif strcmp(hmm.train.covtype,'full')
            hmm.state(k).Omega.Gam_shape = sum(Gamma(:,k));
            hmm.state(k).Omega.Gam_rate = zeros(ndim);
            hmm.state(k).Omega.Gam_rate(regressed,regressed) =  ...
                (e' .* repmat(Gamma(:,k)',sum(regressed),1)) * e;
            hmm.state(k).Omega.Gam_irate(regressed,regressed) = ...
                inv(hmm.state(k).Omega.Gam_rate(regressed,regressed));
        end
    end

    % Shared (state-independent) covariance, from the residuals of the mixture prediction.
    % The shape counts the samples actually modelled (total Gamma mass), which
    % matches what the cell-input branch accumulates; sum(T) would also count
    % the initial samples of each segment that are lost to the lags.
    e = Y(:,regressed) - pred(:,regressed);
    if strcmp(hmm.train.covtype,'uniquediag')
       hmm.Omega.Gam_shape = 0.5 * sum(Gamma(:));
       hmm.Omega.Gam_rate = zeros(1,ndim);
       hmm.Omega.Gam_rate(regressed) = 0.5 * sum( e.^2 );
    elseif strcmp(hmm.train.covtype,'uniquefull')
       hmm.Omega.Gam_shape = sum(Gamma(:));
       hmm.Omega.Gam_rate = zeros(ndim);
       hmm.Omega.Gam_rate(regressed,regressed) =  (e' * e);
       hmm.Omega.Gam_irate(regressed,regressed) = inv(hmm.Omega.Gam_rate(regressed,regressed));
    end

end

end

back to top