https://github.com/OHBA-analysis/HMM-MAR
Revision df84f129c4d6a2f36757f137da28466b7ae5c3d5 authored by robecker on 05 December 2016, 18:12:29 UTC, committed by robecker on 05 December 2016, 18:12:29 UTC
Tip revision: df84f129c4d6a2f36757f137da28466b7ae5c3d5 authored by robecker on 05 December 2016, 18:12:29 UTC
Merge branch 'master' of https://github.com/OHBA-analysis/HMM-MAR into robert_stochastic
Merge branch 'master' of https://github.com/OHBA-analysis/HMM-MAR into robert_stochastic
Tip revision: df84f12
mlhmmmar.m
function hmm = mlhmmmar (X,T,hmm0,Gamma,completelags)
% Given the state time courses estimation, does a last estimation of each MAR using (local) ML
%
% INPUT
% X             observations: a data matrix, or a cell with one entry per subject
%               (each entry is read with loadfile)
% T             length of series: a vector of segment lengths, or (when X is a cell)
%               a cell with one vector of segment lengths per subject
% hmm0          HMM-MAR structure
% Gamma         p(state given X) - has to be fully defined; one row per modelled
%               time point, i.e. sum(T) - length(T)*hmm0.train.maxorder rows, and
%               one column per state. If omitted or empty, a single all-ones
%               column is used (one state)
% completelags  if 1, the lags are made linear with timelag=1 (i.e. a complete set);
%               default 0
%
% OUTPUT
% hmm           HMM-MAR structure with the coefficients and covariance matrices updated to
%               follow a maximum-likelihood estimation
%
% Author: Diego Vidaurre, OHBA, University of Oxford
hmm = hmm0;
if hmm.train.uniqueAR
    error('mlhmmmar not yet implemented for uniqueAR')
end
% flat list of segment lengths, valid whether T is a vector or a cell
if iscell(T)
    T2 = cell2mat(T); T2 = T2(:);
else
    T2 = T;
end
% default: one state covering every modelled time point (Y has
% sum(T) - length(T)*maxorder rows)
if nargin<4 || isempty(Gamma)
    Gamma = ones(sum(T2)-length(T2)*hmm.train.maxorder,1);
end
if nargin<5 || isempty(completelags), completelags = 0; end
ndim = size(hmm.state(1).W.Mu_W,2);
K = size(Gamma,2);
hmm.K = K; N = length(T);
if completelags
    maxorder0 = hmm.train.maxorder;
    hmm.train.orderoffset=0; hmm.train.timelag=1; hmm.train.exptimelag=0;
    hmm.train.maxorder = hmm.train.order;
    if hmm.train.multipleConf
        for k=1:K
            if isfield(hmm.state(k),'train') && ~isempty(hmm.state(k).train)
                hmm.state(k).train.orderoffset=0;
                hmm.state(k).train.timelag=1;
                hmm.state(k).train.exptimelag=0;
                hmm.train.maxorder = max(hmm.train.maxorder,hmm.state(k).train.order);
            end
        end
    end
    maxorderd = hmm.train.maxorder - maxorder0;
    if maxorderd>0 % trim Gamma if maxorder has changed
        % Gamma is stacked segment after segment; use the flat segment lengths
        % T2 so this also works when T is a cell
        Gamma0 = Gamma;
        Gamma = zeros(sum(T2)-hmm.train.maxorder*length(T2),K);
        for j = 1:length(T2)
            t00 = sum(T2(1:j-1)) - (j-1)*maxorder0 + 1;
            t10 = sum(T2(1:j)) - j*maxorder0;
            t0 = sum(T2(1:j-1)) - (j-1)*hmm.train.maxorder + 1;
            t1 = sum(T2(1:j)) - j*hmm.train.maxorder;
            % drop the first maxorderd rows of this segment
            Gamma(t0:t1,:) = Gamma0(t00+maxorderd:t10,:);
        end
    end
end
[hmm.train.orders,order] = formorders(hmm.train.order,hmm.train.orderoffset,hmm.train.timelag,hmm.train.exptimelag);
if isfield(hmm.train,'B'), hmm.train.B = []; end
if isfield(hmm.train,'V'), hmm.train.V = []; end
if iscell(X)
    % first pass: accumulate the Gamma-weighted normal equations over subjects
    c = 0;
    for i = 1:N
        [~,XX,Y] = loadfile(X{i},T{i},hmm.train);
        ind = (1:sum(T{i})-length(T{i})*order) + c;
        c = c + length(ind);
        if i==1
            XX2 = zeros(size(XX,2),size(XX,2),K);
            XXY = zeros(size(XX,2),ndim,K);
        end
        for k=1:K
            XXg = XX' .* repmat(Gamma(ind,k)',size(XX,2),1); % XX' weighted by Gamma
            XX2(:,:,k) = XX2(:,:,k) + XXg * XX;
            XXY(:,:,k) = XXY(:,:,k) + XXg * Y;
        end
    end
    for k=1:K
        % small ridge term keeps the linear solve well conditioned
        hmm.state(k).W.Mu_W = (XX2(:,:,k) + 1e-6 * eye(size(XX2,2))) \ XXY(:,:,k);
        %hmm.state(k).W.Mu_W = XX2(:,:,k) \ XXY(:,:,k);
        % remove posterior covariances inherited from hmm0
        if isfield(hmm.state(k).W,'S_W'), hmm.state(k).W = rmfield(hmm.state(k).W,'S_W'); end
        if isfield(hmm.state(k).W,'iS_W'), hmm.state(k).W = rmfield(hmm.state(k).W,'iS_W'); end
    end
    % second pass: Gamma-weighted residual statistics for the noise model
    c = 0;
    for i = 1:N
        [~,XX,Y] = loadfile(X{i},T{i},hmm.train);
        ind = (1:sum(T{i})-length(T{i})*order) + c;
        c = c + length(ind);
        if i==1
            if strcmp(hmm.train.covtype,'uniquediag')
                hmm.Omega.Gam_rate = zeros(1,ndim);
                hmm.Omega.Gam_shape = 0;
            elseif strcmp(hmm.train.covtype,'uniquefull')
                hmm.Omega.Gam_rate = zeros(ndim);
                hmm.Omega.Gam_shape = 0;
            elseif strcmp(hmm.train.covtype,'diag')
                for k=1:K
                    hmm.state(k).Omega.Gam_rate = zeros(1,ndim);
                    hmm.state(k).Omega.Gam_shape = 0;
                end
            else % full
                for k=1:K
                    hmm.state(k).Omega.Gam_rate = zeros(ndim);
                    hmm.state(k).Omega.Gam_shape = 0;
                end
            end
        end
        for k=1:K
            e = Y - XX * hmm.state(k).W.Mu_W;
            if strcmp(hmm.train.covtype,'diag')
                hmm.state(k).Omega.Gam_shape = hmm.state(k).Omega.Gam_shape + 0.5 * sum(Gamma(ind,k));
                hmm.state(k).Omega.Gam_rate = hmm.state(k).Omega.Gam_rate + ...
                    0.5 * sum( repmat(Gamma(ind,k),1,ndim) .* e.^2 );
            elseif strcmp(hmm.train.covtype,'full')
                hmm.state(k).Omega.Gam_shape = hmm.state(k).Omega.Gam_shape + sum(Gamma(ind,k));
                hmm.state(k).Omega.Gam_rate = hmm.state(k).Omega.Gam_rate + ...
                    (e' .* repmat(Gamma(ind,k)',ndim,1)) * e;
            elseif strcmp(hmm.train.covtype,'uniquediag')
                hmm.Omega.Gam_shape = hmm.Omega.Gam_shape + 0.5 * sum(Gamma(ind,k));
                hmm.Omega.Gam_rate = hmm.Omega.Gam_rate + 0.5 * sum( repmat(Gamma(ind,k),1,ndim) .* e.^2 );
            else
                hmm.Omega.Gam_shape = hmm.Omega.Gam_shape + sum(Gamma(ind,k));
                hmm.Omega.Gam_rate = hmm.Omega.Gam_rate + (e' .* repmat(Gamma(ind,k)',ndim,1)) * e;
            end
        end
    end
    % recompute the inverse rates from the new rates, as the matrix-input path
    % does; otherwise Gam_irate values copied from hmm0 would be left stale
    if strcmp(hmm.train.covtype,'full')
        for k=1:K
            hmm.state(k).Omega.Gam_irate = inv(hmm.state(k).Omega.Gam_rate);
        end
    elseif strcmp(hmm.train.covtype,'uniquefull')
        hmm.Omega.Gam_irate = inv(hmm.Omega.Gam_rate);
    end
else
    S = hmm.train.S==1; regressed = sum(S,1)>0;
    Sind = formindexes(hmm.train.orders,hmm.train.S); hmm.train.Sind = Sind;
    if ~hmm.train.zeromean, Sind = [true(1,size(X,2)); Sind]; end
    Y = getresiduals(X,T,Sind,hmm.train.maxorder,hmm.train.order,hmm.train.orderoffset,...
        hmm.train.timelag,hmm.train.exptimelag,hmm.train.zeromean);
    pred = zeros(size(Y));
    setxx; % build XX
    for k=1:K
        setstateoptions;
        % weighted least squares: scale both design matrix and targets by sqrt(Gamma)
        sqg = sqrt(Gamma(:,k));
        % (uniqueAR is rejected at the top of this function, so no branch for it)
        if all(S(:)==1)
            hmm.state(k).W.Mu_W = pinv(XX{kk} .* repmat(sqg,1,size(XX{kk},2))) ...
                * ( repmat(sqg,1,size(Y,2)) .* Y);
        else
            % rows of Mu_W index regressors (columns of XX); each regressed
            % channel n is fit only on its allowed regressors Sind(:,n)
            hmm.state(k).W.Mu_W = zeros(size(XX{kk},2),ndim);
            for n=1:ndim
                if ~regressed(n), continue; end
                hmm.state(k).W.Mu_W(Sind(:,n),n) = pinv(XX{kk}(:,Sind(:,n)) .* ...
                    repmat(sqg,1,sum(Sind(:,n)))) * (sqg .* Y(:,n));
            end
        end
        predk = XX{kk} * hmm.state(k).W.Mu_W;
        pred = pred + repmat(Gamma(:,k),1,ndim) .* predk;
        e = Y(:,regressed) - predk(:,regressed);
        if strcmp(hmm.train.covtype,'diag')
            hmm.state(k).Omega.Gam_shape = 0.5 * sum(Gamma(:,k));
            hmm.state(k).Omega.Gam_rate = zeros(1,ndim);
            hmm.state(k).Omega.Gam_rate(regressed) = 0.5 * ...
                sum( repmat(Gamma(:,k),1,sum(regressed)) .* e.^2 );
        elseif strcmp(hmm.train.covtype,'full')
            hmm.state(k).Omega.Gam_shape = sum(Gamma(:,k));
            hmm.state(k).Omega.Gam_rate = zeros(ndim);
            hmm.state(k).Omega.Gam_rate(regressed,regressed) = ...
                (e' .* repmat(Gamma(:,k)',sum(regressed),1)) * e;
            hmm.state(k).Omega.Gam_irate(regressed,regressed) = ...
                inv(hmm.state(k).Omega.Gam_rate(regressed,regressed));
        end
    end
    % if length(hmm.state) > K
    %     state = hmm.state(1:K);
    %     hmm = rmfield(hmm,'state');
    %     hmm.state = state;
    % end
    % shared noise model: residual of the Gamma-mixed prediction. The shape counts
    % the modelled time points (rows of Y), matching the per-state branches above
    % and the cell-input path.
    e = Y(:,regressed) - pred(:,regressed);
    if strcmp(hmm.train.covtype,'uniquediag')
        hmm.Omega.Gam_shape = 0.5 * size(Y,1);
        hmm.Omega.Gam_rate = zeros(1,ndim);
        hmm.Omega.Gam_rate(regressed) = 0.5 * sum( e.^2 );
    elseif strcmp(hmm.train.covtype,'uniquefull')
        hmm.Omega.Gam_shape = size(Y,1);
        hmm.Omega.Gam_rate = zeros(ndim);
        hmm.Omega.Gam_rate(regressed,regressed) = (e' * e);
        hmm.Omega.Gam_irate(regressed,regressed) = inv(hmm.Omega.Gam_rate(regressed,regressed));
    end
end
end
Computing file changes ...