https://github.com/rany-abend/threat_learning_eLife
Tip revision: 221919cc02c4b5ae00fc21764095b8b4d2a14201 authored by Rany Abend on 19 April 2022, 18:29:47 UTC
Update fit14SCRModels_eLife_V2.m
Update fit14SCRModels_eLife_V2.m
Tip revision: 221919c
fitSCRModels_Elife.m
%% These functions fit the models (variants of Rescorla-Wagner (RW) and hybrid Pearce-Hall (PH) models) to SCR data for Abend et al., 2022 (eLife)
% Below is an example of how one would fit RW models to a dataset:
% For the 8 RW models, a matrix was used to define the presence of
% additional model features. A 1 indicates that the model feature is
% included and a 0 indicates that it is not included
% Column 1 = habituation params
% Column 2 = bayesian learning rate
% Column 3 = learning inertia
% model = [0 0 0;
% 0 0 1;
% 0 1 0;
% 1 0 0;
% 0 1 1;
% 1 0 1;
% 1 1 0;
% 1 1 1];
% A single model was chosen for each run of the fit. The function fitSCRGen
% takes the model information as an input (e.g.: modelinfo=model(1,:) would
% be the simplest RW model, model 1)
% Then use fminsearch to minimize the negative log-likelihood of the fit to
% the model (set options for limits on fminsearch):
% [fitparams, ll] = fminsearch(@(params) fitSCRGen(params, scr, us, csType, vi, modelinfo, n, mu), params, options);
%inputs:
%params: starting values for parameters
%scr: skin conductance response to conditioned stimulus
%us: skin conductance response to unconditioned stimulus
%vi: initial value estimate
%modelinfo: 3 element vector defining which model should be fit
%outputs:
%fitparams: values of the parameters from the particular fit
%ll: log likelihood of fit
%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%% Model Functions %%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%Wrapper function for fitSCRfGen, for compatibility with fminsearch :
function [ll] = fitSCRGen(params, scr, us, csType, vi, modelV, n, mu)
[ll, ~] = fitSCRfGen(params, scr, us, csType, vi, modelV, n, mu);
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%For fitting 8 RW models to acquisition data:
function [ll, v] = fitSCRfGen(params, scr, us, csType, vi, modelV, n, mu)
v = zeros(2, length(scr) + 1);
v(1:2, 1) = vi;
rpe = zeros(length(scr), 1);
for trial = 1 : length(scr)
rpe(trial) = us(trial) - v(csType(trial), trial);
if modelV(2) == 1
lrp1 = params(1)/(trial^1/2);
lrp2 = params(2)/(trial^1/2);
else
lrp1 = params(1);
lrp2 = params(2);
end
if modelV(3) == 1
ftrial = min(trial-1,2);
if trial > 1
srpe = sum(rpe(trial:-1:trial-ftrial));
else
srpe = rpe(trial);
end
else
srpe = rpe(trial);
end
if csType(trial) == 1
v(1, trial+1) = v(1, trial) + lrp1*srpe + mu(trial);
v(2, trial+1) = v(2, trial) + lrp2*srpe + mu(trial);
else
v(1, trial+1) = v(1, trial) + mu(trial);
v(2, trial+1) = v(2, trial) + mu(trial);
end
end
z1 = find(csType == 1);
z2 = find(csType == 2);
if modelV(1) == 1
exp1 = params(3);
exp2 = params(4);
if isnan(n)
vec = [0 0];
n = 2;
else
vec = zeros(n,1)';
end
v(1, z1) = v(1, z1).*exp(-exp1*([vec 1:(length(z1) - n)]));
v(2, z2) = v(2, z2).*exp(-exp2*([vec 1:(length(z2) - n)]));
end
ll = 0;
for trial = 1 : length(scr)
ll = ll + abs(v(csType(trial), trial) - scr(trial)).^2;
end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%Wrapper function for fitSCRfGenExt, for compatibility with fminsearch :
function [ll] = fitSCRGenExt(params, scr, us, csType, vi, modelV, mu)
[ll, ~] = fitSCRfGenExt(params, scr, us, csType, vi, modelV, mu);
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%For fitting 8 RW models to extinction data:
function [ll, v] = fitSCRfGenExt(params, scr, us, csType, vi, modelV, mu)
v = zeros(2, length(scr) + 1);
v(1:2, 1) = vi;
rpe = zeros(length(scr), 1);
for trial = 1 : length(scr)
rpe(trial) = us(trial) - v(csType(trial), trial);
if modelV(2) == 1
lrp1 = params(1)/(trial^1/2);
lrp2 = params(2)/(trial^1/2);
else
lrp1 = params(1);
lrp2 = params(2);
end
if modelV(3) == 1
ftrial = min(trial-1,1);
if trial > 1
srpe = sum(rpe(trial:-1:trial-ftrial));
else
srpe = rpe(trial);
end
else
srpe = rpe(trial);
end
csn_seq = find(csType == 2);
first_csn = csn_seq(1);
if trial < first_csn || first_csn == 1
if csType(trial) == 1
v(1, trial+1) = v(1, trial) + lrp1*srpe + mu(trial);
v(2, trial+1) = v(2, trial) + mu(trial);
else
v(1, trial+1) = v(1, trial) + mu(trial);
v(2, trial+1) = v(2, trial) + lrp2*srpe + mu(trial);
end
else
if csType(trial) == 1
v(1, trial+1) = v(1, trial) + lrp1*srpe + mu(trial);
v(2, trial+1) = v(2, trial) + lrp2*srpe + mu(trial);
else
v(1, trial+1) = v(1, trial) + mu(trial);
v(2, trial+1) = v(2, trial) + mu(trial);
end
end
end
z1 = find(csType == 1);
z2 = find(csType == 2);
if modelV(1) == 1
exp1 = params(3);
exp2 = params(4);
v(1, z1) = v(1, z1).*exp(-exp1*([0 1:(length(z1)-1)]));
v(2, z2) = v(2, z2).*exp(-exp2*([0 1:(length(z2)-1)]));
end
ll = 0;
for trial = 1 : length(scr)
ll = ll + abs(v(csType(trial), trial) - scr(trial)).^2;
end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%For fitting Pearce-Hall model to data:
function [ll, v] = fitSCRfPHG(params, scr, us, csType, vi, mu)
v(1:2,1) = vi;
alphpos = 1;
alphneg = 1;
for trial = 1 : length(scr)
rpe = us(trial) - v(csType(trial), trial);
if csType(trial) == 1
v(1, trial+1) = v(1, trial) + params(1)*alphpos*(rpe) + mu(trial);
v(2, trial+1) = v(2, trial) + params(2)*alphneg*(rpe) + mu(trial);
alphpos = (1-params(3))*alphpos + abs(rpe)*params(3);
alphneg = (1-params(4))*alphneg + abs(rpe)*params(4);
else
v(1, trial+1) = v(1, trial) + mu(trial);
v(2, trial+1) = v(2, trial) + mu(trial);
end
end
ll = 0;
for trial = 1 : length(scr)
ll = ll + abs(v(csType(trial), trial) - scr(trial)).^2;
end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%For fitting Li et al., models
function [ll, v, alpha_all, vals_all, a_pos, a_neg, warningTracker, b] = fitSCRfPHG2(params, scr, us, csType, vi, mu, modelinfo)
%saves associability parameters trial by trial 9.27.21. DB
%modified to match Li et al., 2011 models 02.02.2022
v(1:2,1) = vi;
alphpos = 1;
alphneg = 1;
a_pos=zeros(length(scr),1);
a_neg=zeros(length(scr),1);
alpha_all=zeros(length(scr),1);
for trial = 1 : length(scr)
%RPE(trial) = us(trial) - v(csType(trial), trial); %added for troubleshooting
%rpe = us(trial) - v(csType(trial), trial); %original
%making us binary 02.15.2022 for Li et al. models:
if us(trial)>0
r=1;
else
r=0;
end
rpe = r - v(csType(trial), trial);
if csType(trial) == 1
v(1, trial+1) = v(1, trial) + params(1)*alphpos*(rpe) + mu(trial); %params(1)=kappa for cs+
v(2, trial+1) = v(2, trial) + params(2)*alphneg*(rpe) + mu(trial); %params(2)=kappa for cs-
alphpos = (1-params(3))*alphpos + abs(rpe)*params(3); %params(3)=gamma cs+
alphneg = (1-params(4))*alphneg + abs(rpe)*params(4); %params(4)=gamma cs-
else
v(1, trial+1) = v(1, trial) + mu(trial);
v(2, trial+1) = v(2, trial) + mu(trial);
end
a_pos(trial,1)=alphpos; %trial by trial update of each alpha for cs+ and cs-
a_neg(trial,1)=alphneg;
end
ll = 0;
%Interleaving Vs and Alphas for CS+ and CS- trials to do just 1 regression :
for trial = 1 : length(scr)
if csType(trial)==2
alpha_all(trial,1)=a_neg(trial);
elseif csType(trial)==1
alpha_all(trial,1)=a_pos(trial);
end
vals_all(trial,1)=v(csType(trial), trial);
end
%with regression for both (use one of these at a time):
if modelinfo==1
[b, bint, r]=regress(scr, [ones(length(scr),1), vals_all]); %V
%return betas separately
elseif modelinfo==2
[b, bint, r]=regress(scr, [ones(length(scr),1), alpha_all]); %a
elseif modelinfo==3
[b, bint, r]=regress(scr, [ones(length(scr),1), alpha_all, vals_all]); %V+a
end
ll=sum(r.^2); %minimize the square of the residuals if using regression in function
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%For fitting Tzovara et al., models
%wrapper function for compatibility with fminsearch:
function [ll] = fitSCRUnc2(params, scr, us, ucData, csType, vi, habituationFlag)
[ll, ~] = fitSCRfUnc2(params, scr, us, ucData, csType, vi, habituationFlag);
end
function [ll, yhat] = fitSCRfUnc2(params, scr, scr_all, ucData, csType, vi, habituationFlag)
%scr is only CS+ trials (length 10)
%scr_all is CS+ and CS- trials (length 20)
%csType is length 20 for all trials 1=CS+ 2=CS-
%note that ucData is the same as the input us in the original version of
%this function above.
if csType(1) == 3
decay = params(3);
end
if habituationFlag==1
exp1=params(5);
end
%CS+ params
v_CSP = zeros(1, length(scr) + 1);
E_theta_CSP = zeros(1,length(scr) + 1);
v_CSP(1) = 0;
beta_one_CSP = params(1);
beta_knot_CSP = params(2);
non_us = ucData == 0; %ucData is binary so for non_us 1=CS+ reinforced with UCS 0=not reinforced 80/20 schedule
yhat_CSP = zeros(length(scr)+1,1);
alpha_CSP = zeros(length(scr),1); alpha_CSP(1) = 1;
beta_CSP = zeros(length(scr),1); beta_CSP(1) = 1;
%CS- params
v_CSN = zeros(1, length(scr) + 1);
E_theta_CSN = zeros(1,length(scr) + 1);
v_CSN(1) = 0;
beta_one_CSN = params(3);
beta_knot_CSN = params(4);
non_stim=ones(1,length(ucData)); %cs- never reinforced
yhat_CSN = zeros(length(scr)+1,1);
alpha_CSN = zeros(length(scr),1); alpha_CSN(1) = 1; %not sure if these should be initialized at 1 or 0 (Bayesian learner in Tzovara 2018)
beta_CSN = zeros(length(scr),1); beta_CSN(1) = 1;
%CS+ (csType=1), compute alpha and beta:
for trial = 1 : length(scr)
alpha_CSP(trial+1) = alpha_CSP(trial) + ucData(trial); %goes up every time ucs was delivered
beta_CSP(trial+1) = beta_CSP(trial) + non_us(trial); %beta doesn't change when ucs delivered, goes up by 1 otherwise
E_theta_CSP(trial + 1) = (alpha_CSP(trial))/(alpha_CSP(trial) + beta_CSP(trial));
v_CSP(trial+1) = -log(alpha_CSP(trial) + beta_CSP(trial));
end
%for CS- only increment beta, it was never reinforced with ucs
%alpha stays at 1 but beta increments up
%CS- (csType=2)
for trial = 1 : length(scr)
alpha_CSN(trial+1) = 1; %cs- never reinforced
beta_CSN(trial+1) = beta_CSN(trial) + non_stim(trial); %just increments every trial
E_theta_CSN(trial + 1) = (alpha_CSN(trial))/(alpha_CSN(trial) + beta_CSN(trial));
v_CSN(trial+1) = -log(alpha_CSN(trial) + beta_CSN(trial));
end
if habituationFlag==1
%add habituation to CS+
vec = [0 0]; n = 2; %don't habituate on first 2 trials?
%vec = zeros(n,1)'; %other option for this
v_CSP = v_CSP.*exp(-exp1*([vec 1:(length(v_CSP) - n)])); %habituation
%v=v*exp(alpha(k-ko)^t) ko=2. ko not fit to individual subjects
end
ll = 0;
%compute yhat for CS+ and CS- combined by interleaving yhat from CSP and CS-
for trial = 1:length(scr)
z_CSP(trial) = v_CSP(trial) + E_theta_CSP(trial);
z_CSN(trial) = v_CSN(trial) + E_theta_CSN(trial);
yhat_CSP(trial) = beta_one_CSP * z_CSP(trial) + beta_knot_CSP;
yhat_CSN(trial) = beta_one_CSN * z_CSN(trial) + beta_knot_CSN;
end
yhat_CSP=yhat_CSP(1:10);
yhat_CSN=yhat_CSN(1:10);
%combine CSP and CSN yhats into one yhat, properly interleaved:
csp_idx=find(csType==1);
csn_idx=find(csType==2);
yhat(csp_idx,1)=yhat_CSP;
yhat(csn_idx,1)=yhat_CSN;
for trial=1:length(scr_all)
%calculate ll
ll = ll + ((yhat(trial) - scr_all(trial)).^2);
end
yhat(1) = vi;
end