https://github.com/hasselmonians/mle_rhythmicity
Tip revision: 7c50ba73fa34dbe2586fb0f9e0204bd1e28f3209 authored by jrclimer on 11 July 2016, 16:29:52 UTC
Resolves #7
Resolves #7
Tip revision: 7c50ba7
mle_rhythmicity.m
function [params, confidence, stats, everything] = mle_rhythmicity( data, session_duration, varargin )
%MLE_RHYTHMICITY Maximum likelihood estimation for rhythmicity parameters
% Using a conditional intensity function for the distribution of lags in
% the autocorrelogram window, the set of parameters that maximize the
% likelihood of the data are identified.
%
% [...] = mle_rhythmicity(data, session_duration, [PARAMETER], [value]...);
%
% INPUT:
% data: Can either be spike timestamps or a cell array of lags. A cell
%   array is used as-is (no epoching is performed on it).
% session_duration: Duration of session
%
% PARAMETERS
% max_lag (0.6): Examination window (seconds)
% epochs ([]): Epochs handed to epoch_data when data are timestamps. If
%   empty, [0 session_duration] is used.
% epochmode ('lead'): Forwarded unchanged to epoch_data (see epoch_data).
% plotit (true): If true, plots histogram and distribution estimates
% noskip (false): If true, omits skipping from the model
% f_range ([1 13]): Range of possible frequencies for the rhythmicity.
% alpha (0.05): Cutoff for confidence intervals
% t_axis (linspace(0,max_lag,61)): Bin edges for plotting. Must be
%   increasing and evenly spaced, otherwise the default is used.
% runs (3): Number of runs, takes the best answer of all.
%
% OUTPUT
% params: A struct containing the parameter estimates
%   a - Relative magnitude of the rhythmicity, ranges from 0
%       (no rhythm) to 1 (maximally rhythmic)
%   tau (log10(sec)) - Exponential falloff of the distribution
%   b - Baseline
%   c (log10(sec)) - Exponential falloff of the rhythmicity magnitude
%   f (Hz) - Frequency of the rhythmicity
%   s - Relative heights of the secondary peaks. Not present if
%       noskip=true
% confidence: Struct containing confidence intervals of the parameters
% stats: Struct containing
%   deviance_rhythmic - Deviance for the rhythmicity test. If
%       noskip=false, this has 4 degrees of freedom, otherwise 3.
%   p_rhythmic - Significance of chi-squared test for rhythmicity
%   deviance_skipping - Deviance for the skipping test, with 1 degree
%       of freedom. Not present if noskip=true
%   p_skipping - Significance of chi-squared test for skipping. Not
%       present if noskip=true.
% everything: Struct containing
%   inds - Second output of epoch_data ([] if data was a cell array)
%   epochs - Epochs used
%   LL - Log likelihood for the fit
%   LL_flat - Log likelihood for the non-rhythmic fit
%   se - Standard errors for the parameters (NaN where the estimated
%       variance is negative)
%   phat - Parameter estimate as a vector
%   lags - Cell array of all the lags
%   lags_list - list of all the lags
%   max_lag - Maximum lag
%   noskip - Whether noskip=true
%   phat_flat - Parameter estimates for the non-rhythmic fit as a
%       vector
%   session_duration - Duration of the session
%   LL_noskip - Log likelihood of the non-skipping fit. Not present if
%       noskip=true.
%   phat_noskip - Parameter estimates for the non-skipping fit as a
%       vector. Not present if noskip=true.
%   se_noskip - Standard errors for non-skipping parameters as a
%       vector. Not present if noskip=true.
%
% See also cif_generator, epoch_data, rhythmicity_pdf, rhythmicity_covar
%
% Copyright 2015-2016 Trustees of Boston University
% All rights reserved.
%
% This file is part of mle_rhythmicity revision 2.0. The last committed
% version of the previous revision is the SHA starting with 93862ac...
%
% This code has been freely distributed by the authors under the BSD
% license (http://opensource.org/licenses/BSD2-Clause). If used or
% modified, we would appreciate if you cited our papers:
%
% Climer JR, DiTullio R, Newman EL, Hasselmo ME, Eden UT. (2014),
% Examination of rhythmicity of extracellularly recorded neurons in the
% entorhinal cortex. Hippocampus, 25:460-473. doi: 10.1002/hipo.22383.
%
% Hinman et al., Multiple Running Speed Signals in Medial Entorhinal
% Cortex, Neuron (2016). http://dx.doi.org/10.1016/j.neuron.2016.06.027

%% Warnings
% NOTE: this warning state is changed globally and is not restored on exit.
warning('off','stats:mle:IterLimit');

%% Parse inputs
ip = inputParser;
% This allows us to keep inputs if this was called by rhythmicity_covar
call_stack = dbstack;
if ismember('rhythmicity_covar',{call_stack.name})
    ip.KeepUnmatched = true;
end
ip.addParamValue('max_lag', 0.6);
ip.addParamValue('epochs',[]);
ip.addParamValue('epochmode','lead');
ip.addParamValue('plotit',true);
ip.addParamValue('noskip',false);
ip.addParamValue('f_range',[1 13]);
ip.addParamValue('alpha',0.05);
ip.addParamValue('t_axis',[ ]);
ip.addParamValue('runs',3);
ip.parse(varargin{:});

% Copy parsed inputs into local variables explicitly (no eval), so names
% that shadow functions on the path (e.g. alpha) are resolved statically.
opts = ip.Results;
max_lag   = opts.max_lag;
epochs    = opts.epochs;
epochmode = opts.epochmode;
plotit    = opts.plotit;
noskip    = opts.noskip;
f_range   = opts.f_range;
alpha     = opts.alpha;
t_axis    = opts.t_axis;
runs      = opts.runs;

if runs>1 % If we're doing multiple runs
    % NOTE: also changed globally and not restored on exit.
    warning off stats:mlecov:NonPosDefHessian;

    % Strip any user-supplied 'plotit'/'runs' name-value pairs so the
    % recursive calls can force runs=1 and plotit=false.
    keep = true(size(varargin));
    for name = {'plotit','runs'}
        if ~ismember(name{1},ip.UsingDefaults)
            k = find(cellfun(@(x)ischar(x)&&strcmpi(x,name{1}),varargin));
            keep([k, k+1]) = false;
        end
    end

    % Run the first time
    [params, confidence, stats, everything] = mle_rhythmicity( data, session_duration, varargin{keep}, 'runs', 1, 'plotit', false );

    % Resize for multiple runs
    params = repmat(params,[runs 1]);
    confidence = repmat(confidence,[runs 1]);
    stats = repmat(stats,[runs 1]);
    everything = repmat(everything,[runs 1]);

    % Do other runs
    for i=2:runs
        [params(i), confidence(i), stats(i), everything(i)] = mle_rhythmicity( data, session_duration, varargin{keep}, 'runs', 1, 'plotit', false );
    end

    % Find the best fit (highest log likelihood) and keep only that run
    [~,best] = max([everything.LL]);
    params = params(best);
    confidence = confidence(best);
    stats = stats(best);
    everything = everything(best);
else % A single run
    % Format and epoch the data
    if iscell(data)
        % Data are already lags; nothing to epoch.
        lags = data;
        inds = [];
    else
        if isempty(epochs)
            epochs = [0 session_duration];
        end
        [ lags, inds ] = epoch_data( data, epochs ,'epochmode', epochmode, 'max_lag', max_lag);
    end
    lags_list = cat(1,lags{:});

    %%
    % This abstract function takes in two handles for a CIF and its definite
    % integral and the parameters for the CIF, and returns the log-likelihood
    % of the set of data. infbnd is defined elsewhere in the project
    % (presumably it bounds infinite values -- confirm).
    LL_fun = @(cif_fun,cif_int,varargin)infbnd(sum(log(cif_fun(lags_list,varargin{:}))-log(cif_int(max_lag,varargin{:}))));

    % In every mle/mlecov call below the "data" argument is a dummy (1): the
    % logpdf handles ignore it and return the total log-likelihood of
    % lags_list instead.

    % Non-rhythmic fit
    [cif_fun, cif_int] = cif_generator('flat');% Non-rhythmic cif, see cif_generator
    phat_flat = mle(1,'logpdf',@(~,varargin)LL_fun(cif_fun,cif_int,varargin{:})...
        ,'start',[-log10(log(1/0.95)/max_lag) 0.1]...% Tau starts to complete 95% of decay by the maximum lag
        ,'lowerbound',[-inf 0]...
        ,'upperbound',[inf 1]);
    % passall is defined elsewhere in the project (presumably it expands the
    % vector into separate arguments of the handle -- confirm).
    LL_flat = passall(@(varargin)LL_fun(cif_fun,cif_int,varargin{:}),phat_flat);

    % Non-skipping rhythm
    % First, we use a pure sinusoid to find the best frequency
    [cif_fun, cif_int] = cif_generator('pure');% Pure sinusoid, see cif_generator
    f = linspace(f_range(1),f_range(2),200);% To start, check many frequencies across available range
    [~,j] = max(arrayfun(@(f)LL_fun(cif_fun, cif_int, f),f));% Find the best one
    f = f(j);
    % Use builtin optimizers to converge to final best frequency
    f = mle(1,'logpdf',@(~,varargin)LL_fun(cif_fun,cif_int,varargin{:})...
        ,'start',f...
        ,'lowerbound',f_range(1)...
        ,'upperbound',f_range(2));

    %% Seed a particle swarm around the frequency found above
    PopulationSize = 25;
    % Columns: tau, b, c, f, r
    InitialPopulation = [...
        unifrnd(-2,1,[PopulationSize,1]), ...% tau
        unifrnd(0,0.4,[PopulationSize,1]), ...% b
        unifrnd(-2,1,[PopulationSize,1]), ...% c
        unifrnd(max(f*0.9,f_range(1)),min(f*1.1,f_range(2)),[PopulationSize,1]), ...% f
        unifrnd(0.25,0.75,[PopulationSize,1]) ...% r
        ];
    [cif_fun, cif_int] = cif_generator('noskip');
    % Particle swarm fit (minimizes the negative log-likelihood)
    phat_noskip = pso(...
        @(phat)-passall(@(varargin)LL_fun(cif_fun,cif_int,varargin{:}),phat)...
        ,5 ...
        ,[],[],[],[]...
        ,[-inf 0 -inf f_range(1) 0]...
        ,[inf 0.5 inf f_range(2) 1]...
        ,[] ...
        ,psooptimset('Display','off','Generations',100,'InitialPopulation',InitialPopulation,'PopulationSize',PopulationSize,'ConstrBoundary','Reflect'...
        ...,'PlotFcns',{@psoplotswarm}...
        )...
        );
    % Final convergence with mle
    phat_noskip = mle(1,'logpdf',@(~,varargin)LL_fun(cif_fun,cif_int,varargin{:})...
        ,'start',phat_noskip...
        ,'lowerbound',[-inf 0 -inf f_range(1) 0]...
        ,'upperbound',[inf 1 inf f_range(2) 1]);

    %%
    % Convert [tau b c f r] to [a tau b c f] with a = (1-b)*r, then
    % calculate standard errors, CIs and log-likelihood.
    phat_noskip = [(1-phat_noskip(2))*phat_noskip(end) phat_noskip(1:end-1)];
    % The CIF takes r = a/(1-b) as its last argument (inverse of the
    % conversion above). mlecov returns a covariance matrix, so standard
    % errors are the square roots of its diagonal.
    var_noskip = (diag(mlecov(phat_noskip,1,'logpdf',@(~,a,tau,b,c,f)LL_fun(cif_fun,cif_int,tau,b,c,f,a/(1-b)))))';
    var_noskip(var_noskip<0) = NaN; % Non-positive-definite Hessian: SE not estimable
    se_noskip = sqrt(var_noskip);
    ci_noskip = [-1;1]*se_noskip*norminv(1-alpha/2)+repmat(phat_noskip,[2 1]);
    LL_noskip = passall(@(a,tau,b,c,f)LL_fun(cif_fun,cif_int,tau,b,c,f,a/(1-b)),phat_noskip);

    if ~noskip % Add skipping
        [cif_fun, cif_int] = cif_generator('full');
        % Start from the non-skipping fit with a small amount of skipping
        phat = mle(1,'logpdf',@(~,varargin)LL_fun(cif_fun,cif_int,varargin{:})...
            ,'start',[phat_noskip(2:end) 0.05 phat_noskip(1)/(1-phat_noskip(3))]...
            ,'lowerbound',[-inf 0 -inf f_range(1) 0 0]...
            ,'upperbound',[inf 1 inf f_range(2) 1 1]);
        LL = passall(@(varargin)LL_fun(cif_fun,cif_int,varargin{:}),phat);
        if phat_noskip(5)*2<=f_range(2)% If doubling the frequency is still in the frequency range
            % Fit using a doubled frequency and high skipping
            phat_half = mle(1,'logpdf',@(~,varargin)LL_fun(cif_fun,cif_int,varargin{:})...
                ,'start',[phat_noskip(2:4) phat_noskip(5)*2 0.95 phat_noskip(1)/(1-phat_noskip(3))]...
                ,'lowerbound',[-inf 0 -inf f_range(1) 0 0]...
                ,'upperbound',[inf 1 inf f_range(2) 1 1]);
            LL_half = passall(@(varargin)LL_fun(cif_fun,cif_int,varargin{:}),phat_half);
            if LL_half>LL % If doubling the frequency is a better fit, replace
                phat = phat_half;
            end
        end
        % Convert [tau b c f s r] to [a tau b c f s] and calculate standard
        % errors, CIs and log-likelihood (same conventions as above).
        phat = [(1-phat(2))*phat(end) phat(1:end-1)];
        var_full = (diag(mlecov(phat,1,'logpdf',@(~,a,tau,b,c,f,s)LL_fun(cif_fun,cif_int,tau,b,c,f,s,a/(1-b)))))';
        var_full(var_full<0) = NaN; % Non-positive-definite Hessian: SE not estimable
        se = sqrt(var_full);
        ci = [-1;1]*se*norminv(1-alpha/2)+repmat(phat,[2 1]);
        LL = passall(@(a,tau,b,c,f,s)LL_fun(cif_fun,cif_int,tau,b,c,f,s,a/(1-b)),phat);
        % Test for significant skipping
        stats.deviance_skipping = 2*(LL-LL_noskip);
        stats.p_skipping = 1-chi2cdf(stats.deviance_skipping,1);
        % Test for significant rhythm
        stats.deviance_rhythmic = 2*(LL-LL_flat);
        stats.p_rhythmic = 1-chi2cdf(stats.deviance_rhythmic,4);
    else % noskip=true
        % Use noskip fits
        phat = phat_noskip;
        se = se_noskip;
        ci = ci_noskip;
        LL = LL_noskip;
        % Test for significant rhythm
        stats.deviance_rhythmic = 2*(LL-LL_flat);
        stats.p_rhythmic = 1-chi2cdf(stats.deviance_rhythmic,3);
    end

    %% Package output
    PARAM = {'a','tau','b','c','f'};
    if ~noskip, PARAM = [PARAM {'s'}]; end
    for j=1:numel(PARAM)
        params.(PARAM{j}) = phat(j);
        confidence.(PARAM{j}) = ci(:,j);
    end

    % Field-by-field assignment (not struct(...)) so that cell-valued
    % fields such as lags do not expand into a struct array.
    everything.inds = inds;
    everything.epochs = epochs;
    everything.LL = LL;
    everything.LL_flat = LL_flat;
    everything.se = se;
    everything.phat = phat;
    everything.lags = lags;
    everything.lags_list = lags_list;
    everything.max_lag = max_lag;
    everything.noskip = noskip;
    everything.phat_flat = phat_flat;
    everything.session_duration = session_duration;
    if ~noskip
        everything.LL_noskip = LL_noskip;
        everything.phat_noskip = phat_noskip;
        everything.se_noskip = se_noskip;
    end
end

%% plotting
if plotit(1)
    % Only touch the current axes when we are actually going to plot
    if ~ishold
        cla
    end
    if isempty(t_axis)
        t_axis = linspace(0,max_lag,61);
    else
        spacing = diff(t_axis);
        if numel(t_axis)<2 || any(spacing<=0) || any(abs(spacing-spacing(1))>=0.0001)
            warning('Custom t_axis must be monotonically increasing and evenly spaced. Using default 61 bins')
            t_axis = linspace(0,max_lag,61);
        end
    end
    bin_width = mean(diff(t_axis));
    n_lags = numel(everything.lags_list);
    t_fine = linspace(0,max_lag,200);

    % Histogram of the lags; the last histc bin only counts values equal
    % to the final edge, so it is dropped.
    counts = histc(everything.lags_list,t_axis);
    bar(mean([t_axis(1:end-1);t_axis(2:end)]),counts(1:end-1),1);
    xlim([0 max_lag]);
    hold on;

    % Fitted distribution. cif/cif_int(max_lag) is the density over lags
    % (as used in the likelihood), so expected counts per bin are
    % density * number of lags * bin width.
    if noskip
        [cif_fun, cif_int] = cif_generator('noskip');
    else
        [cif_fun, cif_int] = cif_generator('full');
    end
    ps = passall(@(varargin)cif_fun(t_fine,varargin{2:end},varargin{1}/(1-varargin{3})),everything.phat)/...
        passall(@(varargin)cif_int(max_lag,varargin{2:end},varargin{1}/(1-varargin{3})),everything.phat);
    plot(t_fine,ps*n_lags*bin_width,'r','LineWidth',2);

    % Non-rhythmic (flat) fit
    [cif_fun, cif_int] = cif_generator('flat');
    ps = passall(@(varargin)cif_fun(t_fine,varargin{:}),everything.phat_flat)/...
        passall(@(varargin)cif_int(max_lag,varargin{:}),everything.phat_flat);
    plot(t_fine,ps*n_lags*bin_width,'c--','LineWidth',2);

    ttl = ['\hat{a}' sprintf('=%2.2g, p',everything.phat(1)) '_{rhyth}=' sprintf('%2.2g',stats.p_rhythmic)];
    lgnd = {'data','MLE','flat'};
    if ~noskip
        % Non-skipping fit for comparison
        [cif_fun, cif_int] = cif_generator('noskip');
        ps = passall(@(varargin)cif_fun(t_fine,varargin{2:end},varargin{1}/(1-varargin{3})),everything.phat_noskip)/...
            passall(@(varargin)cif_int(max_lag,varargin{2:end},varargin{1}/(1-varargin{3})),everything.phat_noskip);
        plot(t_fine,ps*n_lags*bin_width,'g--');
        ttl = [ttl ', s=' sprintf('%2.2g',everything.phat(end)) ', p_{skip}=' sprintf('%2.2g',stats.p_skipping)];
        lgnd = [lgnd {'no-skip'}];
    end
    hold off
    legend(lgnd{:});
    title(['$$' ttl '$$'],'Interpreter','latex');
    xlabel('Lag (s)');ylabel('Count');
end
end