https://github.com/hasselmonians/mle_rhythmicity
Tip revision: 7c50ba73fa34dbe2586fb0f9e0204bd1e28f3209 authored by jrclimer on 11 July 2016, 16:29:52 UTC
Resolves #7
Resolves #7
Tip revision: 7c50ba7
rhythmicity_covar.m
function [ params, confidence, stats, everything ] = rhythmicity_covar( data, session_duration, covars, varargin )
% RHYTHMICITY_COVAR Analyze rhythmicity as a linear function of covariates
% Analyze rhythmicity allowing parameters (e.g. frequency) to vary
% linearly with an arbitrarity number of covariates (e.g. speed).
% INPUT:
% data: Can either be spike timestamps or a cell array of lags
% session_duration: Duration of session
% covars: Either a nXm matrix, where n is the number of spikes and m is
% the number of covariates, or a cell array corresponding to each lag
%
% PARAMETERS:
% In addition to the parameters in mle_rhythmicity:
%
% plotit (true): If true, plots histogram and distribution estimates. Can
% also be [false true], in which case we don't plot the static fit.
% plot_axis (1): The axis (singular) to plot. To plot additional axes,
% save output everything and see plot_rhythmicity_covar
% $_covar (0): $ is replaced by a parameter name (a, tau, b, c, f, s). A
% vector of the covariates that will modulate that parameter, with 0
% corresponding with a constant and higher integers corresponding
% with the columns of covars.
% post_hocs (struct): struct with fields a, tau, b, c, f or s. Each field
% contains the indices of the covariates for the leave-one-out
% analysis. For leave multiple out analysis, use a cell array. For
% example, to test f against columns 1 and 2+3, use
% struct('f',{{1, [2 3]}}).
% display ('verbose'): If 'verbose', prints information about the ongoing
% analysis
% covar_labels ({''}) - Labels for the covariate axis
%
% Copywrite (c) 2015,2016 Trustees of Boston University
% All rights reserved.
%
% This file is part of mle_rhythmicity revision 2.0. The last committed
% version of the previous revision is the SHA starting with 93862ac...
%
% This code has been freely distributed by the authors under the BSD
% license (http://opensource.org/licenses/BSD2-Clause). If used or
% modified, we would appreciate if you cited our paper:
%
% Climer JR, DiTullio R, Newman EL, Hasselmo ME, Eden UT. (2014),
% Examination of rhythmicity of extracellularly recorded neurons in the
% entorhinal cortex. Hippocampus, 25:460-473. doi: 10.1002/hipo.22383.
%% Parse input
PARAMS = {'tau','b','c','f','s','r'};% List of parameters - used for easier coding
display = [];
ip = inputParser;
warning('off','stats:mle:EvalLimit');
% mle_rhythmicity params
ip.addParamValue('max_lag', 0.6);
ip.addParamValue('epochs',[]);
ip.addParamValue('epochmode','lead');
ip.addParamValue('noskip',false);
ip.addParamValue('f_range',[1 13]);
ip.addParamValue('alpha',0.05);
ip.addParamValue('t_axis',[ ]);
ip.addParamValue('runs',3);
% Covar specific params
ip.addParamValue('a_covar',NaN);
ip.addParamValue('post_hocs',struct);
ip.addParamValue('display','verbose');
ip.addParamValue('plotit',[true true]);
ip.addParamValue('plot_axis',1);
ip.addParamValue('covar_labels',{''});
% Pull out covariates
for i=1:numel(PARAMS)
ip.addParamValue([PARAMS{i} '_covar'],0);
end
% Parse
ip.parse(varargin{:});
% Dump everything other than covariates to workspace
i = fields(ip.Results);
for i = i(cellfun(@(x)isempty(findstr('_covar',x)),i))'
eval([i{1} '=ip.Results.' i{1} ';']);
end
% If noskip, drop s from PARAMS
if noskip
PARAMS = PARAMS(~ismember(PARAMS,'s'));
end
cov = struct;% Struct of covariate indicies
if isscalar(plotit)
plotit = repmat(plotit,[1 2]);
end
% Load structure with covariate indicies
for i=1:numel(PARAMS)
cov.(PARAMS{i}) = sort(unique(ip.Results.([PARAMS{i} '_covar'])));
if ~isequal(cov.(PARAMS{i}),ip.Results.([PARAMS{i} '_covar']))
warning([PARAMS{i} '_covar not unique/sorted. Output may not be shaped the same as the input.']);
end
end
if ~isequal(ip.Results.a_covar,0)
cov.r = ip.Results.a_covar;
end
for i=find(~ismember(PARAMS,fields(post_hocs)))
post_hocs.(PARAMS{i}) = [];
end
for i=1:numel(PARAMS)
if ~iscell(post_hocs.(PARAMS{i}))
if isempty(post_hocs.(PARAMS{i})), post_hocs.(PARAMS{i}) = {};
else post_hocs.(PARAMS{i}) = {post_hocs.(PARAMS{i})}; end;
end
end
PARAMS = [PARAMS {'a'}];
for i=1:numel(PARAMS)
for j=1:numel(post_hocs.(PARAMS{i}))
if ~iscell(post_hocs.(PARAMS{i}))
post_hocs.(PARAMS{i})= {post_hocs.(PARAMS{i})};
end
if ~isempty(post_hocs.(PARAMS{i}){j})&&~isequal(post_hocs.(PARAMS{i}){j},sort(unique(post_hocs.(PARAMS{i}){j})))
post_hocs.(PARAMS{i}){j}=sort(unique(post_hocs.(PARAMS{i}){j}));
warning(['post_hocs.' PARAMS{i} ' not unique/sorted. Output may not be the same shape as the input']);
end
end
end
PARAMS = PARAMS(1:end-1);
if ismember('a',fields(post_hocs))
post_hocs.r = post_hocs.a;
end
% Handle r-a conversion
if ~isnan(ip.Results.a_covar)
r_covar = ip.Results.a_covar;
end
if ismember('a',fields(cov))
cov.r = cov.a;
elseif isnan(cov.r)
cov.r = 0;
end
if all(cellfun(@(x)isequal(cov.(x),0)||any(isnan(cov.(x))),PARAMS))
exception = MException('mle_rhythmicity:rhythmicity_covar:NoCov',...
['No covariates defined in input. For non-covariate rhythmicity '...
'analysis, please see mle_rhythmicity. See ']);
throw(exception);
end
%%
if ~all(cellfun(@(x)ismember(0,cov.(x)),PARAMS))
warning('Covarying parameters were used without a constant offset, if this was not intended it may results in bad fits and covariates cannot be shifted back from zero-mean.');
end
if isequal(display,'verbose')
fprintf('\nrhythmicity_covar on %i events. Covariates are:',numel(data));
for i=1:numel(PARAMS)
fprintf('\n\t%s',PARAMS{i});
if isequal(cov.(PARAMS{i}),0)
fprintf(' is constant');
else
fprintf(' as a linear function of column');
if sum(cov.(PARAMS{i})~=0)>1
fprintf(['s ' strjoin(arrayfun(@num2str,cov.(PARAMS{i})(cov.(PARAMS{i})~=0),'UniformOutput',false),',')]);
else
fprintf(' %i',cov.(PARAMS{i})(cov.(PARAMS{i})~=0));
end
fprintf(' and ');
if ~ismember(0,cov.(PARAMS{i}))
fprintf('NO ');
end
fprintf('constant.');
end
end
fprintf('\nRunning static fit...')
end
% Static fit
if plotit(1)% Static plotting
if any(cellfun(@(x)isequal(x,'plotit'),varargin))
varargin = varargin([1:find(cellfun(@(x)isequal(x,'plotit'),varargin))-1 find(cellfun(@(x)isequal(x,'plotit'),varargin))+2:end]);
end
pos = get(gca,'position');
HOLD = ishold;
if ~HOLD
cla;
axis off
end
if plotit(2)
subplot('position',pos.*[1 1 19/44 1]);
end
end
[~,~,stats0,everything0] = mle_rhythmicity(data,session_duration,'plotit',plotit,varargin{:});
if plotit(1)&&plotit(2)
subplot('position',pos.*[1 1 19/44 1]+[pos(3)*25/44 0 0 0]);
end
% Pull out needed indices
inds = everything0.inds;
epochs = everything0.epochs;
lags = everything0.lags;
lags_list = everything0.lags_list;
n = numel(lags_list);
if isequal(display,'verbose')
fprintf('done, %i lags, cell is ',numel(lags_list));
if stats0.p_rhythmic>0.05
fprintf('NOT ');
end
fprintf('rhythmic.');
end
%% Format covariates
% Align covars with lags
if ~iscell(covars)
covars_lags = cellfun(@(x)covars(x,:),inds,'UniformOutput',false);
else
covars_lags = covars;
end
covars_list = cat(1,covars_lags{:});
% Shift the covariates to 0 mean - this makes the fit easier
shift = mean(covars_list);
covars_list = covars_list-repmat(shift,[size(covars_list,1) 1]);
covars_list = [ones(size(covars_list,1),1) covars_list];
shift = [0 shift];
% Log-likelihood wrapper function
LL_fun = @(cif_fun,cif_int,varargin)infbnd(sum(log(cif_fun(lags_list,varargin{:}))-log(cif_int(max_lag,varargin{:}))));
% Generate CIFs and CIF integral functions
if noskip
[cif_fun, cif_int] = cif_generator('noskip');
else
[cif_fun, cif_int] = cif_generator('full');
end
% Make fit bounds
lowerbound = [-inf 0 -inf 0 0];
upperbound = [inf 1 inf everything0.phat(5)*2 1];
if ~noskip
lowerbound = [lowerbound 0];
upperbound = [upperbound 1];
end
%%
% Make initial guess & large bounds
x0 = zeros(1,sum(cellfun(@(x)numel(cov.(x)),PARAMS)));
phat = everything0.phat;
phat = [phat(2:end) phat(1)/(1-phat(2))];
% clc
lbnd = -inf(1,numel(x0));
ubnd = inf(1,numel(x0));
for i=1:numel(PARAMS)
lbnd(sum(arrayfun(@(j)numel(cov.(PARAMS{j})),1:i-1))+1)=lowerbound(i);
ubnd(sum(arrayfun(@(j)numel(cov.(PARAMS{j})),1:i-1))+1)=upperbound(i);
end
%
for i=1:numel(PARAMS)
x0(sum(arrayfun(@(j)numel(cov.(PARAMS{j})),1:i-1))+1)=phat(i);
end
%% Initial fits: allow covariates one parameter at a time
if isequal(display,'verbose')
fprintf('\nFinding initial guess...');
end
for k=1:3% Iterate 3 times
for i=find(~cellfun(@(x)isequal(cov.(x),0),PARAMS))% For each non-static parameter
if isinf(upperbound(i))% The parameter is unbounded
x0(sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i-1)))+(1:numel(cov.(PARAMS{i})))) = ... Set the initial value
fminsearch(... Use fminsearch (unbounded)
... BEGIN FUN
@(phat)...
-passall(@(varargin)LL_fun(cif_fun,cif_int,varargin{:}),...log-likelihood function
covar_wrapper(... Wrapper for parameters with covariates
[x0(1:sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i-1))))... Keep other parameters fixed
phat... The current covarying parameter
x0(sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i)))+1:end)... Keep other parameters fixed
]...
,cov,covars_list)... END covar_wrapper
)...
... END FUN
, x0(sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i-1)))+(1:numel(cov.(PARAMS{i}))))... x0
, optimset(... options
'TolX',1e-10...
,'display','none'...
));
else % Bounded covariates
x0(sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i-1)))+(1:numel(cov.(PARAMS{i})))) = ... Set the initial value
fmincon(... Use fmincon (bounded)
... BEGIN FUN
@(phat)...
-passall(@(varargin)LL_fun(cif_fun,cif_int,varargin{:}),...log-likelihood function
covar_wrapper(...Wrapper for parameters with covariates
[x0(1:sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i-1))))...Keep other parameters fixed
phat...The current covarying parameter
x0(sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i)))+1:end)...Keep other parameters fixed
]...
,cov,covars_list)...END covar_wrapper
)...
... END FUN
, x0(sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i-1)))+(1:numel(cov.(PARAMS{i}))))... x0
, [covars_list(:,cov.(PARAMS{i})+1);-covars_list(:,cov.(PARAMS{i})+1)] ... A (linear constraints)
, [ones(numel(lags_list),1)*upperbound(i);ones(numel(lags_list),1)*-lowerbound(i)] ... b (linear constraints)
, [] ... Aeq
, [] ... beq
, [] ... ub
, [] ... lb
, [] ... nlcon
, optimset(... options
'algorithm','sqp'...
,'TolX',1e-10...
,'display','none'...
));
end
end
end
%% Final fits
% Make linear constraints for the final fits
A = ...
arrayfun(@(k)... For each*
[...
zeros(2*numel(lags_list),sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:k-1))))...% Zeros for irrelevant parameters
[covars_list(:,cov.(PARAMS{i})+1);-covars_list(:,cov.(PARAMS{i})+1)]...% The covariates for this (k) bounded parameter (+ and -)
zeros(2*numel(lags_list),sum(cellfun(@(x)numel(cov.(x)),PARAMS(k+1:end))))...% Zeros for irrelevant parameters
]...
,find(~cellfun(@(x)isequal(cov.(x),0),PARAMS)&~isinf(upperbound))...*constrained, covarying parameter
,'UniformOutput',false);
A = cat(1,A{:});
B = ...
arrayfun(@(k)...For each*
[ones(numel(lags_list),1)*upperbound(k);ones(numel(lags_list),1)*-lowerbound(k)]...% Bounds for this parameter (+upper, -lower)
,find(~cellfun(@(x)isequal(cov.(x),0),PARAMS)&~isinf(upperbound))...*constrained, covarying parameter
,'UniformOutput',false);
B = cat(1,B{:});
if isequal(display,'verbose')
fprintf('done.\nFinal convergance...');
end
phat = fmincon(...
... BEGIN FUN
@(phat)...
-passall(@(varargin)LL_fun(cif_fun,cif_int,varargin{:})...log-likelihood function
,covar_wrapper(phat,cov,covars_list))...Wrapper for parameters with covariates
... END FUN
,x0 ... x0
, A ... A
, B ... B
, [] ... Aeq
, [] ... beq
, lbnd ... lbnd
, ubnd ... ubnd
, [] ... nlcon
, optimset(...
'algorithm','sqp'...
,'TolX',1e-10...
,'display','none'...
));
LL = passall(@(varargin)LL_fun(cif_fun,cif_int,varargin{:})...
,covar_wrapper(phat,cov,covars_list));
if isequal(display,'verbose')
fprintf('done.');
end
%% Post-hoc tests
post_hoc_results = struct;
if any(~cellfun(@(x)isempty(post_hocs.(x)),PARAMS))
if isequal(display,'verbose')
fprintf('\nRunning post hoc tests...');
end
for i=1:numel(PARAMS)
for k=1:numel(post_hocs.(PARAMS{i}))
j=post_hocs.(PARAMS{i}){k};
if isequal(j,0)||~all(ismember(j,cov.(PARAMS{i})))
warning('Covariates #%s not under %s, skipping.',strjoin(arrayfun(@num2str,j,'UniformOutput',false),', '),PARAMS{i});
else
post_hoc_results.(sprintf('%s%s',PARAMS{i},strjoin(arrayfun(@num2str,j,'UniformOutput',false),'_'))) = struct;
cov2 = cov;
cov2.(PARAMS{i}) = cov2.(PARAMS{i})(~ismember(cov2.(PARAMS{i}),j));
post_hoc_results.(sprintf('%s%s',PARAMS{i},strjoin(arrayfun(@num2str,j,'UniformOutput',false),'_'))).phat = ...
mle(...
1,'logpdf',@(~,varargin)...
passall(@(varargin)LL_fun(cif_fun,cif_int,varargin{:})...
,covar_wrapper([varargin{:}],cov2,covars_list))...
,'start'...
,phat([1:sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i-1)))...
sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i-1)))+find(~ismember(cov.(PARAMS{i}),j))...
(sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i)))+1):end]));
post_hoc_results.(sprintf('%s%s',PARAMS{i},strjoin(arrayfun(@num2str,j,'UniformOutput',false),'_'))).LL = passall(@(varargin)LL_fun(cif_fun,cif_int,varargin{:})...
,covar_wrapper(post_hoc_results.(sprintf('%s%s',PARAMS{i},strjoin(arrayfun(@num2str,j,'UniformOutput',false),'_'))).phat,cov2,covars_list));
post_hoc_results.(sprintf('%s%s',PARAMS{i},strjoin(arrayfun(@num2str,j,'UniformOutput',false),'_'))).deviance = ...
2*(LL-...
post_hoc_results.(sprintf('%s%s',PARAMS{i},strjoin(arrayfun(@num2str,j,'UniformOutput',false),'_'))).LL);
post_hoc_results.(sprintf('%s%s',PARAMS{i},strjoin(arrayfun(@num2str,j,'UniformOutput',false),'_'))).p = ...
1-chi2cdf(post_hoc_results.(sprintf('%s%s',PARAMS{i},strjoin(arrayfun(@num2str,j,'UniformOutput',false),'_'))).deviance,...
numel(j));
end
end
end
end
%% Shift everything back
covars_list = covars_list+ones(size(lags_list))*shift;
if all(cellfun(@(x)ismember(0,cov.(x)),PARAMS(~cellfun(@(x)isequal(cov.(x),0),PARAMS))))
for i=1:numel(PARAMS)
phat(sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i-1)))+1) = ...
phat(sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i-1)))+1) - ...
sum(shift(cov.(PARAMS{i})+1).*...
phat(sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i-1)))+(1:numel(cov.(PARAMS{i})))));
for k=1:numel(post_hocs.(PARAMS{i}))
j=post_hocs.(PARAMS{i}){k};
cov2 = cov;
cov2.(PARAMS{i}) = cov2.(PARAMS{i})(~ismember(cov.(PARAMS{i}),j));
post_hoc_results.(sprintf('%s%s',PARAMS{i},strjoin(arrayfun(@num2str,j,'UniformOutput',false),'_'))).phat(...
sum(cellfun(@(x)numel(cov2.(x)),PARAMS(1:i-1)))+1 ...
) = ...
post_hoc_results.(sprintf('%s%s',PARAMS{i},strjoin(arrayfun(@num2str,j,'UniformOutput',false),'_'))).phat(...
sum(cellfun(@(x)numel(cov2.(x)),PARAMS(1:i-1)))+1 ...
) - ...
sum(shift(cov2.(PARAMS{i})+1).*...
post_hoc_results.(sprintf('%s%s',PARAMS{i},strjoin(arrayfun(@num2str,j,'UniformOutput',false),'_'))).phat(...
sum(cellfun(@(x)numel(cov2.(x)),PARAMS(1:i-1)))+(1:numel(cov2.(PARAMS{i})))));
end
end
end
if isequal(display,'verbose')
fprintf('done.');
end
%% Packaging output
params = struct;
confidence = struct;
everything = struct;
for i=1:numel(PARAMS)% Pull out parameters going with each main parameter
params.(PARAMS{i}) = ...
phat(sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i-1)))+(1:numel(cov.(PARAMS{i}))));
end
if isequal(cov.b,0)% We can do a conversion
PARAMS2 = [{'a'} PARAMS(~ismember(PARAMS,'r'))];
params.a = (1-params.b)*params.r;% Add a to params
% Do a conversions
everything.phat = cellfun(@(x)params.(x),PARAMS2,'UniformOutput',false);
everything.phat = cat(2,everything.phat{:});
% Calculate a se
everything.se = (diag(mlecov(phat,1,'logpdf',@(~,varargin)passall(@(varargin)...
LL_fun(cif_fun,cif_int,varargin{:})...
,covar_wrapper([varargin{numel(cov.r)+1:end} [varargin{1:numel(cov.r)}]/(1-params.b)]...
,cov,covars_list)))))';
PARAMS2 = [{'r'} PARAMS(~ismember(PARAMS,'r'))];
for i=1:numel(PARAMS2)
confidence.(PARAMS2{i}) = ...
[-1;1]*everything.se(sum(cellfun(@(x)numel(cov.(x)),PARAMS2(1:i-1)))+(1:numel(cov.(PARAMS2{i}))))+repmat(everything.phat(sum(cellfun(@(x)numel(cov.(x)),PARAMS2(1:i-1)))+(1:numel(cov.(PARAMS2{i})))),[2 1]);
end
else % b is static
warning('Cannot do a conversion when b is not static.');
everything.phat = phat;
% everything.se = diff(ci)/2/norminv(1-alpha);
for i=1:numel(PARAMS)
confidence.(PARAMS{i}) = ...
[-1;1]*everything.se(sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i-1)))+(1:numel(cov.(PARAMS{i}))))+repmat(everything.phat(sum(cellfun(@(x)numel(cov.(x)),PARAMS(1:i-1)))+(1:numel(cov.(PARAMS{i})))),[2 1]);
end
end
% Copy to everything
everything.ids = everything0.inds;
everything.epochs = everything0.epochs;
everything.phat0 = everything0.phat;
everything.lags = everything0.lags;
everything.lags_list = lags_list;
everything.covars_list = covars_list;
everything.max_lag = max_lag;
everything.noskip = noskip;
everything.session_duration = session_duration;
everything.LL0 = everything0.LL;
everything.cov = cov;
everything.shift = shift;
everything.shifted_back = ...
all(cellfun(@(x)ismember(0,cov.(x)),PARAMS(~cellfun(@(x)isequal(cov.(x),0),PARAMS))));
% Calculate LL
everything.LL = LL;
everything.covar_labels = covar_labels;
% Do stats
stats.deviance = 2*(everything.LL-everything.LL0);
stats.df = numel(everything.phat)-numel(everything.phat0);
stats.p = 1-chi2cdf(stats.deviance,stats.df);
stats.post_hoc = post_hoc_results;
%% Plotit
if HOLD
hold on;
else
hold off;
end
save working.mat;
if plotit(2)
plot_rhythmicity_covar( everything, plot_axis, 'covar_label', covar_labels{plot_axis} );
end
end