https://github.com/OHBA-analysis/HMM-MAR
Tip revision: 83788633fd00d62e0bbc0fa9cc814242df944e4f authored by Diego Vidaurre on 04 July 2023, 08:54:57 UTC
Merge pull request #107 from sonsolesalonsomartinez/master
Tip revision: 8378863
obsupdate.m
function hmm = obsupdate(Gamma,hmm,residuals,XX,XXGXX,Tfactor)
%
% Update observation model
%
% INPUT
% Gamma      p(state given X)
% hmm        hmm data structure
% residuals  in case we train on residuals, the value of those.
%            (for the logistic model, column iY is used for output iY)
% XX         regressor matrix forwarded to updateW / updateOmega; for the
%            logistic model the columns [1:xdim, xdim+iY] are used for
%            output iY (presumably lagged inputs - TODO confirm)
% XXGXX      forwarded unchanged to updateW / updateOmega / updatePCAparam
%            (presumably state-weighted cross-products of XX - TODO confirm)
% Tfactor    (optional) factor forwarded to updateW / updateOmega /
%            updatePCAparam; defaults to 1 when omitted or empty
%            (presumably a data-size scaling for stochastic inference -
%            TODO confirm)
%
% OUTPUT
% hmm        estimated HMMMAR model
%
% Author: Diego Vidaurre, OHBA, University of Oxford

K = hmm.K;
obs_tol = 0.00005;
obs_maxit = 1; %20;
mean_change = Inf;
obs_it = 1;
p = hmm.train.lowrank; do_HMM_pca = (p > 0);
% This function declares 6 inputs, so the default must be applied when
% nargin<6. The former test (nargin<7) was always true and silently
% discarded any Tfactor supplied by the caller.
if nargin<6 || isempty(Tfactor), Tfactor = 1; end

if ~isfield(hmm.train,'distribution') || strcmp(hmm.train.distribution,'Gaussian')
    while mean_change>obs_tol && obs_it<=obs_maxit
        last_state = hmm.state;
        if do_HMM_pca
            % low-rank (PCA) observation model
            hmm = updatePCAparam (hmm,sum(Gamma),XXGXX,Tfactor);
        else
            %%% W
            [hmm,XW] = updateW(hmm,Gamma,residuals,XX,XXGXX,Tfactor);
            %%% Omega (uses the XW returned by updateW)
            hmm = updateOmega(hmm,Gamma,residuals,XX,XXGXX,XW,Tfactor);
            %disp(num2str(hmm.Omega.Gam_rate / hmm.Omega.Gam_shape))
            %%% autoregression coefficient priors
            if (isfield(hmm.train,'V') && ~isempty(hmm.train.V))
                %%% beta - one per regression coefficient
                hmm = updateBeta(hmm);
            else
                %%% sigma - channel x channel coefficients
                hmm = updateSigma(hmm);
                %%% alpha - one per order
                hmm = updateAlpha(hmm);
            end
        end
        %%% termination conditions
        obs_it = obs_it + 1;
        mean_change = mean_W_change(last_state,hmm.state,K);
    end
elseif strcmp(hmm.train.distribution,'logistic')
    if isfield(hmm,'psi')
        hmm = rmfield(hmm,'psi');
    end
    while mean_change>obs_tol && obs_it<=obs_maxit
        last_state = hmm.state;
        % each output dimension is updated separately on a marginalised
        % model, then merged back into hmm
        for iY = 1:hmm.train.logisticYdim
            hmm_marginalised = logisticMarginaliseHMM(hmm,iY);
            xdim = hmm_marginalised.train.ndim-1;
            %%% W
            [hmm_temp,~] = ...
                updateW(hmm_marginalised,Gamma,residuals(:,iY),XX(:,[1:xdim,xdim+iY]),XXGXX,Tfactor);
            %%% and hyperparameters alpha
            hmm_temp = updateAlpha(hmm_temp);
            hmm = logisticMergeHMM(hmm_temp,hmm,iY);
        end
        %%% termination conditions
        mean_change = mean_W_change(last_state,hmm.state,K);
        obs_it = obs_it + 1;
    end
else
    % other observation distributions: only W is updated; Tfactor is
    % forwarded for consistency with the branches above
    hmm = updateW(hmm,Gamma,residuals,XX,XXGXX,Tfactor);
end
end


function d = mean_W_change(old_state,new_state,K)
% Mean absolute change of W.Mu_W between two state arrays. Each state's
% contribution is averaged over its coefficients, and the K contributions
% are averaged over states. A state with an empty Mu_W contributes 0
% (instead of 0/0 = NaN). Works for Mu_W of any rank.
d = 0;
for k = 1:K
    w_new = new_state(k).W.Mu_W;
    if isempty(w_new), continue; end
    w_old = old_state(k).W.Mu_W;
    d = d + sum(abs(w_old(:) - w_new(:))) / numel(w_new) / K;
end
end