https://github.com/OHBA-analysis/HMM-MAR
Tip revision: 83788633fd00d62e0bbc0fa9cc814242df944e4f authored by Diego Vidaurre on 04 July 2023, 08:54:57 UTC
Merge pull request #107 from sonsolesalonsomartinez/master
Merge pull request #107 from sonsolesalonsomartinez/master
Tip revision: 8378863
obsinit.m
function [hmm, residuals, W0] = obsinit (data,T,hmm,Gamma)
%
% Initialise observation model in HMM.
% Unless the model is an HMM-PCA (hmm.train.lowrank > 0), the residuals and
% a global MAR estimate are computed first; then the priors and the
% posteriors of the observation model are initialised.
%
% INPUT
% data       observations - a struct with X (time series) and C (classes)
% T          length of series
% hmm        hmm data structure
% Gamma      p(state given X); optional, defaults to []
%
% OUTPUT
% hmm        estimated HMMMAR model
% residuals  in case we train on residuals, the value of those ([] for HMM-PCA)
% W0         global MAR estimation ([] for HMM-PCA)
%
% Author: Diego Vidaurre, OHBA, University of Oxford

if nargin < 4
    Gamma = [];
end

% values kept when the model is an HMM-PCA (no residuals are computed)
residuals = [];
W0 = [];

opts = hmm.train;
if ~(opts.lowrank > 0)
    [residuals,W0] = getresiduals(data.X,T,opts.S,opts.maxorder,opts.order,...
        opts.orderoffset,opts.timelag,opts.exptimelag,opts.zeromean);
end

hmm = initpriors(data.X,T,hmm,residuals);
hmm = initpost(data.X,T,hmm,residuals,Gamma);

end
function hmm = initpriors(X,T,hmm,residuals)
% Define the default priors of the observation model and merge them into hmm.
%
% INPUT
% X          data matrix (ndim = size(X,2) columns); also passed to rangeerror
% T          length of series
% hmm        hmm data structure; hmm.train holds the options, and
%            hmm.state(k).prior may already hold user-specified priors
% residuals  residuals from getresiduals ([] when hmm.train.lowrank > 0)
%
% OUTPUT
% hmm        with hmm.state(k).prior set for k = 1..hmm.K. Prior fields the
%            user already specified are kept; only missing fields are filled
%            in with the defaults. For 'uniquefull'/'sharedfull',
%            'uniquediag'/'shareddiag' and HMM-PCA, hmm.prior.Omega is also set.

ndim = size(X,2);
rangresiduals2 = (range(residuals)/2).^2; % squared half-range per residual channel
sharedcovmat = strcmpi(hmm.train.covtype,'uniquediag') || ...
    strcmpi(hmm.train.covtype,'uniquefull') || ...
    strcmpi(hmm.train.covtype,'shareddiag') || ...
    strcmpi(hmm.train.covtype,'sharedfull');
if isfield(hmm.train,'B')
    Q = size(hmm.train.B,2);
else
    Q = ndim;
end
pcapred = hmm.train.pcapred>0;
if pcapred, M = hmm.train.pcapred; end
p = hmm.train.lowrank; do_HMM_pca = (p > 0);
isGaussian = ~isfield(hmm.train,'distribution') || ...
    strcmp(hmm.train.distribution,'Gaussian');
rangeK = 1:hmm.K;
% (an episodic variant would use rangeK = 1:hmm.K+1)

% The prior rate of the noise precision does not depend on the state,
% so it is computed once rather than once per state.
if isGaussian
    if isempty(hmm.train.priorcov_rate)
        priorcov_rate = rangeerror(X,T,residuals,hmm.train.orders,hmm.train);
    else
        priorcov_rate = hmm.train.priorcov_rate * ones(1,ndim);
    end
end

for k = rangeK
    train = hmm.train;
    orders = train.orders;
    % the set of prior fields depends on the distribution and covtype
    if isGaussian
        if (strcmp(train.covtype,'diag') || strcmp(train.covtype,'full')) && pcapred
            st = struct('beta',[],'Omega',[],'Mean',[]);
        elseif do_HMM_pca
            st = struct('Omega',[]); %'beta',[]
        elseif (strcmp(train.covtype,'diag') || strcmp(train.covtype,'full'))
            st = struct('sigma',[],'alpha',[],'Omega',[],'Mean',[]);
        elseif sharedcovmat && pcapred
            st = struct('beta',[],'Mean',[]);
        else
            st = struct('sigma',[],'alpha',[],'Mean',[]);
        end
    elseif any(strcmp(train.distribution,{'logistic','poisson','bernoulli'}))
        st = struct('alpha',[]);
    end
    if hmm.train.episodic
        st.P = []; st.Pi = []; st.Dir_alpha = []; st.Dir2d_alpha = [];
    end
    defstateprior(k) = st;
    if isGaussian
        if do_HMM_pca
            % no prior on beta is currently set for HMM-PCA:
            %defstateprior(k).beta = struct('Gam_shape',[],'Gam_rate',[]);
            %defstateprior(k).beta.Gam_shape = 0.1;
            %defstateprior(k).beta.Gam_rate = 0.1 * ones(1,p);
        elseif pcapred
            defstateprior(k).beta = struct('Gam_shape',[],'Gam_rate',[]);
            defstateprior(k).beta.Gam_shape = 0.1;
            defstateprior(k).beta.Gam_rate = 0.1 * ones(M,ndim);
        else
            if ~train.uniqueAR && isempty(train.prior)
                defstateprior(k).sigma = struct('Gam_shape',[],'Gam_rate',[]);
                defstateprior(k).sigma.Gam_shape = 0.1*ones(Q,ndim);
                defstateprior(k).sigma.Gam_rate = 0.1*ones(Q,ndim);
            end
            if ~isempty(orders) && isempty(train.prior)
                defstateprior(k).alpha = struct('Gam_shape',[],'Gam_rate',[]);
                defstateprior(k).alpha.Gam_shape = 0.1;
                defstateprior(k).alpha.Gam_rate = 0.1*ones(1,length(orders));
            end
        end
        if ~train.zeromean
            defstateprior(k).Mean = struct('Mu',[],'iS',[]);
            defstateprior(k).Mean.Mu = zeros(ndim,1);
            defstateprior(k).Mean.S = rangresiduals2';
            defstateprior(k).Mean.iS = 1./rangresiduals2';
        end
        if strcmp(train.covtype,'full')
            defstateprior(k).Omega.Gam_rate = diag(priorcov_rate);
            defstateprior(k).Omega.Gam_shape = ndim+0.1-1;
        elseif strcmp(train.covtype,'diag')
            defstateprior(k).Omega.Gam_rate = 0.5 * priorcov_rate;
            defstateprior(k).Omega.Gam_shape = 0.5 * (ndim+0.1-1);
        end
    elseif any(strcmp(train.distribution,{'logistic','poisson'}))
        % conjugate prior is Gamma distribution with shape and rate parameters:
        defstateprior(k).alpha = struct('Gam_shape',[],'Gam_rate',[]);
        defstateprior(k).alpha.Gam_shape = 0.1;
        defstateprior(k).alpha.Gam_rate = 0.1;
    elseif strcmp(train.distribution,'bernoulli')
        % conjugate prior is Beta distribution with parameters a and b:
        defstateprior(k).alpha = struct('a',[],'b',[]);
        defstateprior(k).alpha.a = 1;
        defstateprior(k).alpha.b = 1;
    end
    if hmm.train.episodic && k < hmm.K+1
        % Dirichlet priors must already be present in hmm.state(k).prior
        defstateprior(k).Dir_alpha = hmm.state(k).prior.Dir_alpha;
        defstateprior(k).Dir2d_alpha = hmm.state(k).prior.Dir2d_alpha;
    end
end

% priors on the noise precision that is shared across states
if isGaussian
    if strcmp(hmm.train.covtype,'uniquefull') || strcmp(hmm.train.covtype,'sharedfull')
        hmm.prior.Omega.Gam_shape = ndim+0.1-1;
        hmm.prior.Omega.Gam_rate = diag(priorcov_rate);
    elseif do_HMM_pca
        hmm.prior.Omega.Gam_shape = 0.5 * (ndim+0.1-1);
        hmm.prior.Omega.Gam_rate = 0.5 * median(priorcov_rate);
    elseif strcmp(hmm.train.covtype,'uniquediag') || strcmp(hmm.train.covtype,'shareddiag')
        hmm.prior.Omega.Gam_shape = 0.5 * (ndim+0.1-1);
        hmm.prior.Omega.Gam_rate = 0.5 * priorcov_rate;
    end
end

% assigning default priors for observation models
if ~isfield(hmm,'state') || ~isfield(hmm.state,'prior')
    for k = rangeK
        hmm.state(k).prior = defstateprior(k);
    end
else
    for k = rangeK
        if k > numel(hmm.state) || isempty(hmm.state(k).prior)
            % nothing specified for this state: take all the defaults
            hmm.state(k).prior = defstateprior(k);
            continue
        end
        % prior fields not specified by the user are set to the default
        statepriorlist = fieldnames(defstateprior(k));
        fldname = fieldnames(hmm.state(k).prior);
        misfldname = find(~ismember(statepriorlist,fldname));
        for i = 1:length(misfldname)
            % index through misfldname so that only the missing fields are
            % filled in and user-specified ones are never overwritten
            fld = statepriorlist{misfldname(i)};
            if k==hmm.K+1 && (strcmp(fld,'Dir2d_alpha') || strcmp(fld,'Dir_alpha'))
                continue;
            end
            hmm.state(k).prior.(fld) = defstateprior(k).(fld);
        end
    end
end
end
function hmm = initpost(X,T,hmm,residuals,Gamma)
% Initialising the posteriors of the observation model.
%
% Sets hmm.state(k).W and the noise precision Omega, given initial state
% probabilities Gamma. Omega is per state (hmm.state(k).Omega) for
% 'diag'/'full', and global (hmm.Omega) for 'uniquediag'/'shareddiag',
% 'uniquefull'/'sharedfull' and HMM-PCA. It then initialises the
% posteriors of the coefficient priors: alpha/sigma, or beta with pcapred.
%
% INPUT
% X          data matrix; ndim = size(X,2)
% T          length of series
% hmm        hmm data structure whose priors were set by initpriors
% residuals  residuals from getresiduals ([] when hmm.train.lowrank > 0)
% Gamma      p(state given X), one column per state
%
% OUTPUT
% hmm        with the observation-model posteriors initialised
%
% NOTE: setxx and setstateoptions are scripts that create variables in this
% workspace. XX, XXGXX, Gamma_unwrapped, Sind, regressed and train are used
% below and are presumably defined by them (TODO confirm). setstateoptions
% is re-run inside a loop over k and presumably depends on k. Local
% variable names here should therefore not be changed lightly.

% total number of time points with a response (maxorder lost per segment)
Tres = sum(T) - length(T)*hmm.train.maxorder;
ndim = size(X,2);
K = hmm.K;
hmm.train.active = ones(1,K);
pcapred = hmm.train.pcapred>0;
p = hmm.train.lowrank; do_HMM_pca = (p > 0);
episodic = hmm.train.episodic;
setxx; % build XX and get orders
setstateoptions;
if length(hmm.train.embeddedlags_batched) > 1
    Gamma = Gamma_unwrapped;
end
Gammasum = sum(Gamma); % if episodic, Gammasum doesn't sum up to T

if ~isfield(train,'distribution') || any(strcmp(train.distribution,{'Gaussian','logistic'}))
    % W
    if episodic
        hmm = initW_ehmm(hmm,XX,XXGXX,residuals,Gamma,Sind);
    else
        hmm = initW_hmm(hmm,XX,XXGXX,residuals,Gamma);
    end
    % Omega
    if (strcmp(hmm.train.covtype,'uniquediag') || strcmp(hmm.train.covtype,'shareddiag')) ...
            && episodic
        if hmm.train.uniqueAR, error('Not yet implemented'); end
        hmm.Omega.Gam_shape = hmm.prior.Omega.Gam_shape + Tres / 2;
        e = residuals(:,regressed) - computeStateResponses(XX,hmm,Gamma);
        hmm.Omega.Gam_rate = zeros(1,ndim);
        hmm.Omega.Gam_rate(regressed) = hmm.prior.Omega.Gam_rate(regressed) + 0.5 * sum(e.^2);
    elseif (strcmp(hmm.train.covtype,'uniquefull') || strcmp(hmm.train.covtype,'sharedfull')) ...
            && episodic
        hmm.Omega.Gam_shape = hmm.prior.Omega.Gam_shape + Tres;
        e = residuals(:,regressed) - computeStateResponses(XX,hmm,Gamma);
        hmm.Omega.Gam_rate = zeros(ndim,ndim); hmm.Omega.Gam_irate = zeros(ndim,ndim);
        hmm.Omega.Gam_rate(regressed,regressed) = hmm.prior.Omega.Gam_rate(regressed,regressed) + e' * e;
        hmm.Omega.Gam_irate(regressed,regressed) = inv(hmm.Omega.Gam_rate(regressed,regressed));
    elseif (strcmp(hmm.train.covtype,'uniquediag') || strcmp(hmm.train.covtype,'shareddiag')) ...
            && hmm.train.uniqueAR
        % one AR coefficient vector shared by all channels
        hmm.Omega.Gam_rate = hmm.prior.Omega.Gam_rate;
        for k = 1:K
            XW = zeros(size(XX,1),ndim);
            for n = 1:ndim
                ind = n:ndim:size(XX,2);
                XW(:,n) = XX(:,ind) * hmm.state(k).W.Mu_W;
            end
            e = (residuals - XW).^2;
            hmm.Omega.Gam_rate = hmm.Omega.Gam_rate + ...
                0.5 * sum( repmat(Gamma(:,k),1,ndim) .* e );
        end
        hmm.Omega.Gam_shape = hmm.prior.Omega.Gam_shape + Tres / 2;
    elseif do_HMM_pca
        hmm.Omega.Gam_rate = hmm.prior.Omega.Gam_rate;
        hmm.Omega.Gam_shape = hmm.prior.Omega.Gam_shape;
        v = hmm.Omega.Gam_rate / hmm.Omega.Gam_shape;
        for k = 1:K
            W = hmm.state(k).W.Mu_W;
            M = W' * W + v * eye(p); % posterior dist of the precision matrix
            omega_i = mean(diag(XXGXX{k} - XXGXX{k} * W * (M \ W')));
            %e = sum(repmat(Gamma(:,k),1,ndim) .* (XX - XX * W * W').^2,2);
            hmm.Omega.Gam_rate_state(k) = 0.5 * omega_i; %sum(e);
            hmm.Omega.Gam_rate = hmm.Omega.Gam_rate + hmm.Omega.Gam_rate_state(k);
            hmm.Omega.Gam_shape_state(k) = 0.5 * Gammasum(k);% * ndim;
            hmm.Omega.Gam_shape = hmm.Omega.Gam_shape + hmm.Omega.Gam_shape_state(k);
        end
    elseif strcmp(hmm.train.covtype,'uniquediag') || strcmp(hmm.train.covtype,'shareddiag')
        hmm.Omega.Gam_shape = hmm.prior.Omega.Gam_shape + Tres / 2;
        hmm.Omega.Gam_rate = zeros(1,ndim);
        hmm.Omega.Gam_rate(regressed) = hmm.prior.Omega.Gam_rate(regressed);
        for k = 1:K
            if ~isempty(hmm.state(k).W.Mu_W(:,regressed))
                e = residuals(:,regressed) - XX * hmm.state(k).W.Mu_W(:,regressed);
            else
                e = residuals(:,regressed);
            end
            hmm.Omega.Gam_rate(regressed) = hmm.Omega.Gam_rate(regressed) + ...
                0.5 * sum( repmat(Gamma(:,k),1,sum(regressed)) .* e.^2 );
        end
    elseif strcmp(hmm.train.covtype,'uniquefull') || strcmp(hmm.train.covtype,'sharedfull')
        hmm.Omega.Gam_shape = hmm.prior.Omega.Gam_shape + Tres;
        hmm.Omega.Gam_rate = zeros(ndim,ndim); hmm.Omega.Gam_irate = zeros(ndim,ndim);
        hmm.Omega.Gam_rate(regressed,regressed) = hmm.prior.Omega.Gam_rate(regressed,regressed);
        for k = 1:K
            if ~isempty(hmm.state(k).W.Mu_W(:,regressed))
                e = residuals(:,regressed) - XX * hmm.state(k).W.Mu_W(:,regressed);
            else
                e = residuals(:,regressed);
            end
            hmm.Omega.Gam_rate(regressed,regressed) = hmm.Omega.Gam_rate(regressed,regressed) + ...
                (e' .* repmat(Gamma(:,k)',sum(regressed),1)) * e;
        end
        hmm.Omega.Gam_irate(regressed,regressed) = inv(hmm.Omega.Gam_rate(regressed,regressed));
    elseif ~isfield(hmm.train,'distribution') || ~strcmp(hmm.train.distribution,'logistic') % state dependent
        for k = 1:K
            setstateoptions;
            if train.uniqueAR
                XW = zeros(size(XX,1),ndim);
                for n=1:ndim
                    ind = n:ndim:size(XX,2);
                    XW(:,n) = XX(:,ind) * hmm.state(k).W.Mu_W;
                end
                e = (residuals - XW).^2;
                hmm.state(k).Omega.Gam_rate = hmm.state(k).prior.Omega.Gam_rate + ...
                    0.5* sum( repmat(Gamma(:,k),1,ndim) .* e );
                hmm.state(k).Omega.Gam_shape = hmm.state(k).prior.Omega.Gam_shape + Gammasum(k) / 2;
            elseif strcmp(train.covtype,'diag')
                if ~isempty(hmm.state(k).W.Mu_W)
                    e = (residuals(:,regressed) - XX * hmm.state(k).W.Mu_W(:,regressed)).^2;
                else
                    e = residuals(:,regressed).^2;
                end
                hmm.state(k).Omega.Gam_rate = zeros(1,ndim);
                hmm.state(k).Omega.Gam_rate(regressed) = hmm.state(k).prior.Omega.Gam_rate(regressed) + ...
                    sum( repmat(Gamma(:,k),1,sum(regressed)) .* e ) / 2;
                hmm.state(k).Omega.Gam_shape = hmm.state(k).prior.Omega.Gam_shape + Gammasum(k) / 2;
            else % full
                if ~isempty(hmm.state(k).W.Mu_W)
                    e = residuals(:,regressed) - XX * hmm.state(k).W.Mu_W(:,regressed);
                else
                    e = residuals(:,regressed);
                end
                hmm.state(k).Omega.Gam_shape = hmm.state(k).prior.Omega.Gam_shape + Gammasum(k);
                hmm.state(k).Omega.Gam_rate = zeros(ndim,ndim); hmm.state(k).Omega.Gam_irate = zeros(ndim,ndim);
                hmm.state(k).Omega.Gam_rate(regressed,regressed) = ...
                    hmm.state(k).prior.Omega.Gam_rate(regressed,regressed) + ...
                    (e' .* repmat(Gamma(:,k)',sum(regressed),1)) * e;
                hmm.state(k).Omega.Gam_irate(regressed,regressed) = ...
                    inv(hmm.state(k).Omega.Gam_rate(regressed,regressed));
            end
        end
    end
else
    % distributions other than Gaussian/logistic (e.g. poisson, bernoulli):
    % start every state from empty coefficients and estimate them with updateW.
    % Each state k must be reset; resetting only state 1 would leave
    % hmm.state(k).W == [] for k > 1.
    for k=1:K
        hmm.state(k).W.Mu_W = [];
    end
    if ~strcmp(train.distribution,'logistic')
        [hmm,XW] = updateW(hmm,Gamma,residuals,XX,XXGXX);
    else
        % NOTE(review): unreachable - this else-branch is only entered when
        % the distribution is neither Gaussian nor logistic
        for k = 1:K
            hmm.state(k).alpha.Gam_rate = 0.1*ones(1,ndim);
            hmm.state(k).alpha.Gam_shape = 0.1;
        end
        [hmm,XW] = updateW(hmm,Gamma,residuals(:,end),XX,XXGXX);
    end
end

% Priors over the parameters
if ~pcapred && ~do_HMM_pca
    for k = 1:K
        if hmm.train.order>0 && isempty(hmm.train.prior)
            hmm.state(k).alpha.Gam_shape = hmm.state(k).prior.alpha.Gam_shape;
            hmm.state(k).alpha.Gam_rate = hmm.state(k).prior.alpha.Gam_rate;
        end
    end
    if episodic
        %%% sigma - channel x channel coefficients
        hmm = updateSigma_ehmm(hmm);
        %%% alpha - one per order
        hmm = updateAlpha_ehmm(hmm);
    else
        if ~isfield(train,'distribution') || strcmp(train.distribution,'Gaussian')
            %%% sigma - channel x channel coefficients
            hmm = updateSigma(hmm);
        end
        %%% alpha - one per order
        hmm = updateAlpha(hmm);
    end
    if isfield(train,'distribution') && strcmp(train.distribution,'logistic')
        if train.logisticYdim>1
            for k = 1:K
                hmm.state(k).alpha.Gam_rate = ...
                    repmat(hmm.state(k).alpha.Gam_rate(~regressed,end),1,train.logisticYdim);
            end
        end
    end
elseif pcapred
    for k = 1:K
        hmm.state(k).beta.Gam_shape = hmm.state(k).prior.beta.Gam_shape + 0.5 * ndim;
        hmm.state(k).beta.Gam_rate = hmm.state(k).prior.beta.Gam_rate + ...
            sum(hmm.state(k).W.Mu_W.^2);
    end
end
end
%
% function hmm = initW_ehmm(hmm,XX,residuals,Gamma,Sind)
% % all at once
%
% K = size(Gamma,2);
% np = size(XX,2); ndim = size(residuals,2);
% Gamma = [Gamma (K-sum(Gamma,2)) ];
% X = zeros(size(XX,1),np * (K+1));
% Xs = zeros(size(XX,1),np * (K+1));
%
% for k = 1:K+1
% X(:,(1:np) + (k-1)*np) = bsxfun(@times, XX, Gamma(:,k));
% Xs(:,(1:np) + (k-1)*np) = bsxfun(@times, XX, sqrt(Gamma(:,k)));
% end
% gram = Xs' * Xs;
% for n = 1:ndim
% Sind_all = [];
% for k = 1:K+1, Sind_all = [Sind_all; Sind(:,n)]; end
% Sind_all = Sind_all == 1;
% iS_W = gram(Sind_all,Sind_all);
% iS_W = (iS_W + iS_W') / 2 + 1e-6 * eye(size(iS_W,1));
% S_W = inv(iS_W);
% Mu_W = S_W * X(:,Sind_all)' * residuals(:,n);
% hmm.state_shared(n).iS_W = zeros(size(gram));
% hmm.state_shared(n).S_W = zeros(size(gram));
% hmm.state_shared(n).Mu_W = zeros(size(X,2),1);
% hmm.state_shared(n).iS_W(Sind_all,Sind_all) = iS_W;
% hmm.state_shared(n).S_W(Sind_all,Sind_all) = S_W;
% hmm.state_shared(n).Mu_W(Sind_all) = Mu_W;
% end
% for k = 1:K+1
% for n = 1:ndim
% ind = (1:np) + (k-1)*np;
% hmm.state(k).W.Mu_W(:,n) = hmm.state_shared(n).Mu_W(ind);
% hmm.state(k).W.iS_W(n,:,:) = hmm.state_shared(n).iS_W(ind,ind);
% hmm.state(k).W.S_W(n,:,:) = hmm.state_shared(n).S_W(ind,ind);
% end
% end
%
% end
% Options for the baseline W (state K+1):
% 1. Estimate W during an initial baseline period only: bad if you intend
%    to re-estimate Gamma.
% 2. Estimate W over the entire data set.
% 3. Estimate W over the entire data set, subtract it from X, and then set it to 0.
%    I think option 3 is more appropriate, as the priors of the states will
%    then push towards the baseline, not towards zero. This way we pick up
%    whatever is above and beyond the baseline.
function ehmm = initW_ehmm(ehmm,XX,XXGXX,residuals,Gamma,Sind)
% Initialise the observation coefficients W of an episodic HMM.
%
% The K states in Gamma are complemented with a baseline state K+1, whose W
% is obtained in one of three ways:
%   - taken from ehmm.train.ehmm_baseline_w (field then removed);
%   - computed by computeBaseline from ehmm.train.ehmm_baseline_data
%     (field then removed);
%   - otherwise, ridge regression of residuals on the Sind-selected columns
%     of XX, with penalty ehmm.train.ehmm_regularisation_baseline.
% The coefficients of states 1..K are then initialised through
% updateW_ehmm, using the same regularisation.
%
% INPUT
% ehmm       hmm data structure (episodic)
% XX         design matrix with np = size(XX,2) columns
% XXGXX      passed through to updateW_ehmm
% residuals  regression targets with ndim = size(residuals,2) columns
% Gamma      state time courses, K = size(Gamma,2) columns (no baseline column)
% Sind       mask over the np columns of XX; either a vector of length np or
%            an np x ndim matrix - TODO confirm against setstateoptions
%
% OUTPUT
% ehmm       with state(1..K+1).W and state_shared(1..ndim) initialised

np = size(XX,2); ndim = size(residuals,2); K = size(Gamma,2);
% per-channel coefficients of all K+1 states stacked together
ehmm.state_shared = struct();
for n = 1:ndim
    ehmm.state_shared(n).iS_W = zeros((K+1)*np,(K+1)*np);
    ehmm.state_shared(n).S_W = zeros((K+1)*np,(K+1)*np);
    ehmm.state_shared(n).Mu_W = zeros((K+1)*np,1);
end
for k = 1:K
    ehmm.state(k).W.iS_W = zeros(ndim,np,np);
    ehmm.state(k).W.S_W = zeros(ndim,np,np);
    ehmm.state(k).W.Mu_W = zeros(np,ndim);
end
% baseline
if isfield(ehmm.train,'ehmm_baseline_w')
    ehmm.state(K+1).W.Mu_W = ehmm.train.ehmm_baseline_w;
    ehmm.train = rmfield(ehmm.train,'ehmm_baseline_w');
elseif isfield(ehmm.train,'ehmm_baseline_data')
    ehmm.state(K+1).W = computeBaseline(ehmm.train);
    ehmm.train = rmfield(ehmm.train,'ehmm_baseline_data');
else
    % columns of XX (rows of W) that are estimated; a column selected for
    % any channel is used, so a vector and an np x ndim mask both work
    sel = any(reshape(Sind,np,[]),2);
    ehmm.state(K+1).W.Mu_W = zeros(np,ndim);
    ehmm.state(K+1).W.iS_W = zeros(ndim,np,np);
    ehmm.state(K+1).W.S_W = zeros(ndim,np,np);
    lambda = ehmm.train.ehmm_regularisation_baseline;
    gram = (XX(:,sel)' * XX(:,sel));
    gram = (gram + gram') / 2 ; % enforce exact symmetry
    gram = gram + lambda * eye(size(gram,2)); % ridge penalty
    igram = inv(gram);
    % the same column selection as in gram must be used on the right-hand
    % side, otherwise the dimensions only agree when every column is selected
    ehmm.state(K+1).W.Mu_W(sel,:) = igram * (XX(:,sel)' * residuals);
    for n = 1:ndim
        ehmm.state(K+1).W.iS_W(n,sel,sel) = gram;
        ehmm.state(K+1).W.S_W(n,sel,sel) = igram;
    end
end
% with a single channel, drop the leading singleton dimension; the
% precomputed-baseline branch does not create iS_W/S_W, so check first
if ndim == 1 && isfield(ehmm.state(K+1).W,'iS_W') && isfield(ehmm.state(K+1).W,'S_W')
    ehmm.state(K+1).W.iS_W = squeeze(ehmm.state(K+1).W.iS_W);
    ehmm.state(K+1).W.S_W = squeeze(ehmm.state(K+1).W.S_W);
end
% copy the baseline into the last np entries of the stacked coefficients
Sind_all = false(np*(K+1),1); Sind_all(np*K + (1:np)) = true;
for n = 1:ndim
    ehmm.state_shared(n).Mu_W(Sind_all) = ehmm.state(K+1).W.Mu_W(:,n);
end
% states
ehmm = updateW_ehmm(ehmm,Gamma,residuals,XX,XXGXX,1,...
    ehmm.train.ehmm_regularisation_baseline); % using same regularisation
end
function hmm = initW_hmm(hmm,XX,XXGXX,residuals,Gamma)
% Initialise the posterior of the observation coefficients W for each state
% of a (non-episodic) HMM. In most branches this is a Gamma-weighted
% regression of residuals on XX, with a 0.01 ridge term on the precision.
%
% INPUT
% hmm        hmm data structure
% XX         design matrix, one row per row of residuals - presumably built
%            by setxx; TODO confirm
% XXGXX      cell array; XXGXX{k} is used as the Gamma(:,k)-weighted Gram
%            matrix of XX (presumably XX' * diag(Gamma(:,k)) * XX; TODO confirm)
% residuals  regression targets, ndim = size(residuals,2) channels
% Gamma      state time courses, K = size(Gamma,2) states
%
% OUTPUT
% hmm        with hmm.state(k).W.Mu_W and S_W (and iS_W in most branches) set
%
% NOTE: setstateoptions is a script. It is assumed to define train, orders,
% order, Sind and S in this workspace (they are used below) - TODO confirm.
ndim = size(residuals,2);
if isfield(hmm.train,'B'), B = hmm.train.B; Q = size(B,2);
else Q = ndim;
end
pcapred = hmm.train.pcapred>0;
p = hmm.train.lowrank; do_HMM_pca = (p > 0);
setstateoptions;
K = size(Gamma,2);
for k = 1:K
% number of predictors (excluding the mean term, if any)
if pcapred, npred = hmm.train.pcapred;
else npred = Q*length(orders);
end
hmm.state(k).W = struct('Mu_W',[],'S_W',[]);
% W is only estimated when there is something to regress on: HMM-PCA,
% autoregressive lags, a non-zero mean, or a logistic model
if do_HMM_pca || order>0 || ~train.zeromean || ...
(isfield(train,'distribution') && strcmp(train.distribution,'logistic'))
if train.uniqueAR || ndim==1 % it is assumed that order>0 and cov matrix is diagonal
% a single coefficient vector: accumulate the Gram matrix and the
% cross-products over channels, taking every ndim-th column of XX
XY = zeros(npred+(~train.zeromean),1);
XGX = zeros(npred+(~train.zeromean));
for n = 1:ndim
ind = n:ndim:size(XX,2);
XGX = XGX + XXGXX{k}(ind,ind);
XY = XY + (XX(:,ind)' .* repmat(Gamma(:,k)',length(ind),1)) * residuals(:,n);
end
if ~isempty(train.prior)
% user-supplied Gaussian prior with precision iS and iSMu (= iS * mean, presumably)
hmm.state(k).W.S_W = inv(train.prior.iS + XGX);
hmm.state(k).W.Mu_W = hmm.state(k).W.S_W * (XY + train.prior.iSMu); % order by 1
else
%hmm.state(k).W.S_W = inv(0.1 * mean(trace(XGX)) * eye(length(orders)) + XGX);
hmm.state(k).W.S_W = inv(0.01 * eye(npred+(~train.zeromean)) + XGX);
hmm.state(k).W.Mu_W = hmm.state(k).W.S_W * XY; % order by 1
end
elseif do_HMM_pca
% Gamma-weighted, uncentred PCA loadings; zero weights are replaced by eps
weights = Gamma(:,k); weights(weights==0) = eps;
hmm.state(k).W.Mu_W = pca(XX,'NumComponents',p,'Weights',weights,'Centered',false);
Xpca = XX * hmm.state(k).W.Mu_W;
hmm.state(k).W.iS_W = zeros(ndim,p,p);
hmm.state(k).W.S_W = zeros(ndim,p,p);
for n = 1:ndim
hmm.state(k).W.iS_W(n,:,:) = (Xpca' .* repmat(Gamma(:,k)',p,1)) * Xpca + 0.01*eye(p);
% covariance approximated by inverting only the diagonal of the precision
hmm.state(k).W.S_W(n,:,:) = diag(1 ./ diag(permute(hmm.state(k).W.iS_W(n,:,:),[2 3 1])));
end
elseif strcmp(train.covtype,'uniquediag') || strcmp(train.covtype,'diag') || ...
strcmp(train.covtype,'shareddiag') || ...
(isfield(train,'distribution') && strcmp(train.distribution,'logistic'))
% diagonal noise: one independent regression per channel n, restricted to
% the predictors selected by Sind(:,n); channels with no predictors are skipped
hmm.state(k).W.Mu_W = zeros((~train.zeromean)+npred,ndim);
hmm.state(k).W.iS_W = zeros(ndim,(~train.zeromean)+npred,(~train.zeromean)+npred);
hmm.state(k).W.S_W = zeros(ndim,(~train.zeromean)+npred,(~train.zeromean)+npred);
for n = 1:ndim
ndim_n = sum(Sind(:,n)>0);
if ndim_n==0, continue; end
hmm.state(k).W.iS_W(n,Sind(:,n),Sind(:,n)) = ...
XXGXX{k}(Sind(:,n),Sind(:,n)) + 0.01*eye(sum(Sind(:,n))) ;
hmm.state(k).W.S_W(n,Sind(:,n),Sind(:,n)) = ...
inv(permute(hmm.state(k).W.iS_W(n,Sind(:,n),Sind(:,n)),[2 3 1]));
hmm.state(k).W.Mu_W(Sind(:,n),n) = ...
(( permute(hmm.state(k).W.S_W(n,Sind(:,n),Sind(:,n)),[2 3 1]) ...
* XX(:,Sind(:,n))') .* repmat(Gamma(:,k)',sum(Sind(:,n)),1)) * residuals(:,n);
end
else
% full noise covariance
% (all(S(:))==1 is equivalent to all(S(:)): every connection is modelled)
if all(S(:))==1
gram = kron(XXGXX{k},eye(ndim));
hmm.state(k).W.iS_W = gram + 0.01*eye(size(gram,1));
hmm.state(k).W.S_W = inv( hmm.state(k).W.iS_W );
% note: the mean uses XXGXX{k} without the 0.01 ridge term
hmm.state(k).W.Mu_W = (( XXGXX{k} \ XX' ) .* ...
repmat(Gamma(:,k)',(~train.zeromean)+npred,1)) * residuals;
else
regressed = sum(S,1)>0; % dependent variables, Y
index_iv = sum(S,2)>0; % independent variables, X
% note that XXGXX is invalid if any S==0:
hmm.state(k).W.iS_W = zeros(length(S(:)));
hmm.state(k).W.S_W = zeros(length(S(:)));
temp1 = bsxfun(@times,XX(:,index_iv),Gamma(:,k));
gram = kron(eye(sum(regressed)),temp1'*XX(:,index_iv));
hmm.state(k).W.iS_W(S(:),S(:)) = gram + 0.01*eye(size(gram,1));
hmm.state(k).W.S_W(S(:),S(:)) = inv( hmm.state(k).W.iS_W(S(:),S(:)) );
hmm.state(k).W.Mu_W = zeros(size(S));
% intialise to OLS estimate:
% (fitted on the time points where Gamma(:,k) > 0.5)
% NOTE(review): the predictors here are residuals(:,index_iv), whereas the
% Gram matrix above uses XX(:,index_iv) - confirm this is intended
hmm.state(k).W.Mu_W(S) = ...
pinv(residuals(Gamma(:,k)>0.5,index_iv))*residuals(Gamma(:,k)>0.5,regressed);
end
end
end
end
end