https://github.com/OHBA-analysis/HMM-MAR
Raw File
Tip revision: 94b4f6ef8e7d15849a7714a0e5d360a8787f57b2 authored by dvidaurre@gmail.com on 21 April 2017, 14:47:34 UTC
Cleaned up and extended the data format accepted by the toolbox
Tip revision: 94b4f6e
hmmmar.m
function [hmm, Gamma, Xi, vpath, GammaInit, residuals, fehist, feterms, rho] = ...
    hmmmar (data,T,options)
% Main function to train the HMM-MAR model, compute the Viterbi path and,
% if requested, obtain the cross-validated sum of prediction quadratic errors.
%
% INPUT
% data          observations, either a struct with X (time series) and C (classes, optional)
%                             or just a matrix containing the time series.
%                             It can also be a cell (then T must be a cell too),
%                             each element being either a matrix or a file name
%                             (.mat files must contain a variable X; any other
%                             file is read with dlmread)
% T             length of series (a vector, or a cell of vectors if data is a cell)
% options       structure with the training options - see documentation in 
%                       https://github.com/OHBA-analysis/HMM-MAR/wiki
%                       Stochastic learning is used when options.BIGNbatch is
%                       positive and smaller than the total number of segments.
%
% OUTPUT
% hmm           estimated HMMMAR model
% Gamma         Time courses of the states probabilities given data
%               (empty under stochastic learning unless options.BIGcomputeGamma)
% Xi            joint probability of past and future states conditioned on data
% vpath         most likely state path of hard assignments (a vector of ones
%               when decoding is not requested; empty under stochastic learning
%               unless options.BIGdecodeGamma)
% GammaInit     Time courses used after initialisation.
% residuals     if the model is trained on the residuals, the value of those
%               residuals (empty under stochastic learning)
% fehist        history of the free energies across iterations
% feterms       free energy terms as returned by hmmstrain (stochastic learning
%               only; empty otherwise)
% rho           as returned by hmmstrain (stochastic learning only; empty otherwise)
%
% Author: Diego Vidaurre, OHBA, University of Oxford (2015)

% Normalise T: if it is a cell, make it and each of its elements a column.
% N is the total number of segments (trials) across all subjects.
if iscell(T)
    if size(T,1)==1, T = T'; end
    for i = 1:length(T)
        if size(T{i},1)==1, T{i} = T{i}'; end
    end
    N = numel(cell2mat(T));
else
    N = length(T);
end

stochastic_learn = isfield(options,'BIGNbatch') && ...
    (options.BIGNbatch < N && options.BIGNbatch > 0);
options = checkspelling(options);

% do some data checking and preparation
if xor(iscell(data),iscell(T)), error('X and T must be cells, either both or none of them.'); end
if stochastic_learn % data is a cell, either with strings or with matrices
    if ~iscell(data)
        % Split the matrix into one cell per segment. The total size is
        % checked up front rather than by repeatedly deleting rows, which
        % copied the whole matrix on every iteration
        if size(data,1) ~= sum(T)
            error('The dimension of data does not correspond to T');
        end
        dat = cell(N,1); TT = cell(N,1); t0 = 0;
        for i = 1:N
            dat{i} = data(t0 + (1:T(i)),:); TT{i} = T(i);
            t0 = t0 + T(i);
        end
        data = dat; T = TT; clear dat TT
    end
    options = checksoptions(options,T);
else % data can be a cell or a matrix 
    % T (if a cell) was already turned into a column cell of columns above
    if iscell(T), T = cell2mat(T); end
    if iscell(data) 
        if size(data,1)==1, data = data'; end
        if iscellstr(data) % if comes in file name format
            dfilenames = data; t0 = 0;
            % one file per subject: loop over the files, not over segments
            for i = 1:numel(dfilenames)
                if ~isempty(strfind(dfilenames{i},'.mat'))
                    % struct-output form, so a file without X fails loudly
                    % instead of silently reusing the previous file's X
                    loaded = load(dfilenames{i},'X'); X = loaded.X;
                else
                    X = dlmread(dfilenames{i});
                end
                if i==1, data = zeros(sum(T),size(X,2)); end
                try
                    data(t0 + (1:size(X,1)),:) = X; t0 = t0 + size(X,1); 
                catch
                    error('The dimension of data does not correspond to T'); 
                end
            end
            if t0~=sum(T), error('The dimension of data does not correspond to T'); end
        else
            try, data = cell2mat(data);
            catch, error('Subjects do not have the same number of channels');
            end
        end
    end
    [options,data] = checkoptions(options,data,T,0);
    % Per-segment check for constant channels, standardising if requested
    t0 = 0;
    for i = 1:N
        t = t0 + (1:T(i)); t0 = t0 + T(i);
        if options.standardise == 1
            data.X(t,:) = bsxfun(@minus,data.X(t,:),mean(data.X(t,:),1));
        end
        sdx = std(data.X(t,:),0,1);
        if any(sdx==0)
            error('At least one of the trials/segments/subjects has variance equal to zero');
        end
        if options.standardise == 1
            data.X(t,:) = bsxfun(@rdivide,data.X(t,:),sdx);
        end
    end
end

% gcp/parpool were introduced in R2013b (MATLAB 8.2); older releases need matlabpool
oldMatlab = verLessThan('matlab','8.2');

% set the matlab parallel computing environment
if options.useParallel==1 && usejava('jvm')
    if oldMatlab
        if matlabpool('size')==0
            matlabpool
        end
    else
        gcp;
    end
end

gatherStats = 0;
if isfield(options,'DirStats')
    profile on
    gatherStats = 1; 
    DirStats = options.DirStats;
    options = rmfield(options,'DirStats'); 
    % to avoid recurrent calls to hmmmar to do the same
end

if stochastic_learn
    
    % get PCA loadings 
    if length(options.pca) > 1 || options.pca > 0 
        if ~isfield(options,'A')
            options.A = highdim_pca(data,T,options.pca,...
                options.embeddedlags,options.standardise,options.onpower);
        end
        options.ndim = size(options.A,2);
        options.S = ones(options.ndim);
        orders = formorders(options.order,options.orderoffset,options.timelag,options.exptimelag); 
        options.Sind = formindexes(orders,options.S);
    end
    if options.pcamar > 0 && ~isfield(options,'B')
        options.B = pcamar_decomp(data,T,options);
    end
    if options.pcapred > 0 && ~isfield(options,'V')
        options.V = pcapred_decomp(data,T,options);
    end
    
    % initialise from scratch, from a previous hmm, or from a given Gamma
    if isempty(options.Gamma) && isempty(options.hmm)
        [hmm,info] = hmmsinit(data,T,options);
        GammaInit = []; 
    elseif isempty(options.Gamma) && ~isempty(options.hmm)
        hmm = versCompatibilityFix(options.hmm);
        GammaInit = [];
        [hmm,info] = hmmsinith(data,T,options,hmm);
    else % ~isempty(options.Gamma)
        GammaInit = options.Gamma;
        options = rmfield(options,'Gamma');
        [hmm,info] = hmmsinitg(data,T,options,GammaInit);
    end
    if options.BIGcyc>1
        [hmm,fehist,feterms,rho] = hmmstrain(data,T,hmm,info,options);
    else
        fehist = []; feterms = []; rho = [];
    end
    Gamma = []; Xi = []; vpath = []; residuals = [];
    if options.BIGcomputeGamma && nargout >= 2
       [Gamma,Xi] = hmmdecode(data,T,hmm,0,[],[]); 
    end
    if options.BIGdecodeGamma && nargout >= 4
       vpath = hmmdecode(data,T,hmm,1,[],[]); 
    end
    
else
    
    % Hilbert envelope
    if options.onpower
       data = rawsignal2power(data,T); 
    end
    % Embedding
    if length(options.embeddedlags) > 1  
        [data,T] = embeddata(data,T,options.embeddedlags);
    end
    % PCA transform
    if length(options.pca) > 1 || options.pca > 0  
        if isfield(options,'A')
            % centre every column with its global mean before projecting
            data.X = bsxfun(@minus,data.X,mean(data.X,1));
            data.X = data.X * options.A; 
        else
            [options.A,data.X] = highdim_pca(data.X,T,options.pca,0,0,0);
        end
        if options.standardise_pc == 1
            t0 = 0;
            for i = 1:length(T)
                t = t0 + (1:T(i)); t0 = t0 + T(i);
                data.X(t,:) = bsxfun(@minus,data.X(t,:),mean(data.X(t,:),1));
                data.X(t,:) = bsxfun(@rdivide,data.X(t,:),std(data.X(t,:),0,1));
            end
        end
        options.ndim = size(options.A,2);
        options.S = ones(options.ndim);
        orders = formorders(options.order,options.orderoffset,options.timelag,options.exptimelag);
        options.Sind = formindexes(orders,options.S);
    end
    if options.pcamar > 0 && ~isfield(options,'B')
        options.B = pcamar_decomp(data,T,options);
    end
    if options.pcapred > 0 && ~isfield(options,'V')
        options.V = pcapred_decomp(data,T,options);
    end    
    options.ndim = size(data.X,2);

    % Obtain the initial state time courses
    if isempty(options.Gamma) && isempty(options.hmm)
        if options.K > 1
            Sind = options.Sind;
            if options.initrep>0 && ...
                    (strcmpi(options.inittype,'HMM-MAR') || strcmpi(options.inittype,'HMMMAR'))
                options.Gamma = hmmmar_init(data,T,options,Sind);
            elseif options.initrep>0 &&  strcmpi(options.inittype,'EM')
                error('EM init is deprecated; use HMM-MAR initialisation instead')
                %options.nu = sum(T)/200;
                %options.Gamma = em_init(data,T,options,Sind);
            elseif options.initrep>0 && strcmpi(options.inittype,'GMM')
                error('GMM init is deprecated; use HMM-MAR initialisation instead')
                %options.Gamma = gmm_init(data,T,options);
            elseif strcmpi(options.inittype,'random') || options.initrep==0
                options.Gamma = initGamma_random(T-options.maxorder,options.K,options.DirichletDiag);
            else
                error('Unknown init method')
            end
        else
            options.Gamma = ones(sum(T)-length(T)*options.maxorder,1);
        end
        GammaInit = options.Gamma;
        options = rmfield(options,'Gamma');
    elseif isempty(options.Gamma) && ~isempty(options.hmm)
        GammaInit = [];
    else % ~isempty(options.Gamma)
        GammaInit = options.Gamma;
        options = rmfield(options,'Gamma');
    end

    % If initialization Gamma has fewer states than options.K, put those states back in
    % and renormalize
    if size(GammaInit,2) < options.K 
        % States were knocked out, but semisupervised in use, so put them back
        GammaInit = [GammaInit 0.0001*rand(size(GammaInit,1),options.K-size(GammaInit,2))];
        GammaInit = bsxfun(@rdivide,GammaInit,sum(GammaInit,2));
    end

    fehist = Inf;
    if isempty(options.hmm) % Initialisation of the hmm
        hmm_wr = struct('train',struct());
        hmm_wr.K = options.K;
        hmm_wr.train = options;
        %if options.whitening, hmm_wr.train.A = A; hmm_wr.train.iA = iA;  end
        hmm_wr = hmmhsinit(hmm_wr);
        [hmm_wr,residuals_wr] = obsinit(data,T,hmm_wr,GammaInit);
    else % using a warm restart from a previous run
        hmm_wr = versCompatibilityFix(options.hmm);
        options = rmfield(options,'hmm');
        hmm_wr.train = options;
        residuals_wr = getresiduals(data.X,T,hmm_wr.train.Sind,hmm_wr.train.maxorder,hmm_wr.train.order,...
            hmm_wr.train.orderoffset,hmm_wr.train.timelag,hmm_wr.train.exptimelag,hmm_wr.train.zeromean);
    end
    
    % Train several times from the same starting point; keep the run with
    % the lowest final free energy (or the last run if Gamma is not updated)
    for it=1:options.repetitions
        hmm0 = hmm_wr;
        residuals0 = residuals_wr;
        [hmm0,Gamma0,Xi0,fehist0] = hmmtrain(data,T,hmm0,GammaInit,residuals0,options.fehist);
        if options.updateGamma==1 && fehist0(end)<fehist(end)
            fehist = fehist0; hmm = hmm0;
            residuals = residuals0; Gamma = Gamma0; Xi = Xi0;
        elseif options.updateGamma==0
            fehist = []; hmm = hmm0;
            residuals = []; Gamma = GammaInit; Xi = [];
        end
    end
    
    if options.decodeGamma && nargout >= 4
        vpath = hmmdecode(data.X,T,hmm,1,residuals);
        if ~options.keepS_W
            for i=1:hmm.K
                hmm.state(i).W.S_W = [];
            end
        end
    else
        vpath = ones(size(Gamma,1),1);
    end
    hmm.train = rmfield(hmm.train,'Sind');
    
    feterms = []; rho = [];
    
end

% Heuristic check: no state ever clearly dominates and every state keeps a
% sizeable share everywhere. Gamma can be empty under stochastic learning,
% in which case all([]) would be true and the warning would always fire.
if ~isempty(Gamma) && all(max(Gamma)<0.6) && all(min(Gamma)>(1/hmm.train.K/2))
    warning(['It seems that the inference was trapped in a local minima; ' ...
        'you might want to increment DirichletDiag and rerun'])
end

if gatherStats==1
    hmm.train.DirStats = DirStats; 
    profile off
    profsave(profile('info'),hmm.train.DirStats)
end

% same PCA condition as used above, so vector-valued options.pca is handled
if length(options.pca) > 1 || options.pca > 0
    hmm.train.A = options.A; 
end
    
end
back to top