https://github.com/jdiedrichsen/pcm_toolbox
Raw File
Tip revision: 4e290a8b2c0d0820f868b7bcb60a3da7bb30e6ee authored by Jörn Diedrichsen on 26 April 2023, 01:59:24 UTC
Update pcm_estimateRegression.m
Tip revision: 4e290a8
pcm_fitModelIndividCrossval.m
function [T,DD,theta_hat,theta0]=pcm_fitModelIndividCrossval(Y,M,partitionVec,conditionVec,varargin);
% function [T,DD,theta_hat,theta0]=pcm_fitModelIndividCrossval(Y,M,partitionVec,conditionVec,varargin);
% Fits pattern component model(s) specified in M to data from one (or more)
% subjects individually, using crossvalidation across partitions within
% each subject (leave-one-partition-out by default).
%==========================================================================
% INPUT:
%        Y: [#Observations x #Voxels]
%            Observed/estimated beta regressors from each subject. Each row
%            is one observation; its partition and condition are given by
%            the corresponding rows of partitionVec and conditionVec.
%            Preferably multivariate noise-normalized beta regressors.
%            If it is a cell array, it is assumed to contain multiple
%            subjects, each cell containing the data from one subject.
%
%        M: {#Models} Cell array with structures that define the model(s).
%        Each may contain the fields
%              .type:        Type of the model to be fitted
%                             'fixed':     Fixed structure without parameter (except scale for each subject)
%                             'component': G is a sum of linear components
%                             'feature':   G=A*A', with A a linear sum of weighted feature components
%                             'nonlinear': Nonlinear model with own function to return derivatives
%                             'freechol':  Free model in Cholesky form
%              .numGparams:  Scalar that defines the number of parameters
%                             included in model.
%              .theta0:      Vector of starting theta parameters to calculate predicted
%                             model G (the first numGparams entries are used).
%         for more fields see the manual for model specification.
%
%   partitionVec: Partition assignment vector
%                   Can be a cell array of vectors (one per subject) if the
%                   partition structure differs between subjects.
%                   Rows of partitionVec{subj} define the partition assignment of rows of Y{subj}.
%                   Commonly these are the scanning run #s for beta
%                   regressors.
%                   If a single vector is provided, it is assumed to be the
%                   same for all subjects.
%                   The runs are assumed to be labeled 1-numRuns.
%
%   conditionVec:  Condition assignment vector
%                   for each subject. Rows of conditionVec{subj} define
%                   condition assignment of rows of Y{subj}.
%                   If a single vector is provided, it is assumed to be the
%                   same for all subjects.
%                   If the (elements of) conditionVec are matrices, they are
%                   assumed to be the design matrix Z, allowing the
%                   specification of individualized models.
%--------------------------------------------------------------------------
% OPTION:
%   'crossvalScheme': Crossvalidation scheme on the different partitions
%                  'leaveOneOut': Leave one partition at a time out (default)
%                  'leaveTwoOut': Leave two consecutive partitions out
%                                 (with an odd number of partitions the
%                                 last partition is never left out)
%                  'oddEven' / 'oddeven': Split half by odd and even runs
%                  Alternatively, a cell array with one cell per fold, each
%                  holding the partition labels left out in that fold.
%   'evaluation':   List of evaluation criteria
%                   'likelihood': Using log predictive probability (pseudo
%                                 likelihood)
%                   'R2':    Crossvalidated R^2 (sums-of-squares pooled over folds)
%                   'R' :    Correlation between observed and predicted,
%                            computed per fold and averaged over folds
%                   'Rpool': Correlation computed from sums-of-squares
%                            pooled over folds
%   'runEffect': How to deal with effects that may be specific to different
%                imaging runs:
%                  'random': Models variance of the run effect for each subject
%                            as a separate random effects parameter.
%                  'fixed': Consider run effect a fixed effect, will be removed
%                            implicitly using ReML.
%   'MaxIteration': Number of max minimization iterations. Default is 1000.
%                   (Only passed to the 'minimize' algorithm.)
%
%   'S',S         : (Cell array of) NxN noise covariance matrices -
%                   otherwise independence is assumed
%   'fitAlgorithm': Either 'NR' or 'minimize' - provides over-write for
%                   model specific algorithms
%   'verbose':      Optional flag to show display message in the command
%                   line. Default is 1. Setting to 2 gives more detailed
%                   feedback on pcm_NR
%   'theta0':       Cell array of starting values; theta0{m} is a
%                   [#params x #subjects] matrix (same format as theta_hat{m})
%--------------------------------------------------------------------------
% OUTPUT:
%   T:      Summary crossvalidation results (1 row per subject, 1 column per model)
%       SN:                 Subject number
%       noise:              Noise parameter (averaged over folds)
%       run:                Run parameter (if runEffect = 'random', averaged over folds)
%       iterations:         Number of iterations for model fit (summed over folds)
%       time:               Elapsed fitting time in sec (summed over folds)
%       likelihood:         Crossvalidated predictive log-likelihood (summed over folds)
%       R2, R, Rpool:       Predictive R^2 / correlations (see 'evaluation')
%       SS1, SS2, SSC:      Pooled sums of squares (only for 'Rpool')
%   DD: Detailed crossvalidation results (1 row per subject and fold)
%       SN:                 Subject number
%       fold:               Crossvalidation fold
%       plus per-fold fit statistics and evaluation sums-of-squares
%
%  theta_hat{model}:  Estimated parameters (model + scaling/noise parameters)
%                       for individual subjects, averaged over folds
%  theta0{model}:     Starting values that were used for each subject

runEffect       = 'random';
MaxIteration    = 1000;
verbose         = 1;
S               = [];
fitAlgorithm    = [];
crossvalScheme  = 'leaveOneOut';
evaluation      = {'likelihood','R2','R'};
theta0          = {};
pcm_vararginoptions(varargin,{'crossvalScheme','fitAlgorithm','runEffect',...
            'MaxIteration','verbose','S','evaluation','theta0'});

DD=[];
% Number of subjects: wrap single-subject input so all subjects are handled alike
if (~iscell(Y))
    Y={Y};
end;
numSubj     = numel(Y);

% Number of evaluation criteria
if (~iscell(evaluation))
    evaluation={evaluation};
end;
numEval = numel(evaluation);

% Preallocate output structure
T.SN = [1:numSubj]';

% Get the number of models
if (~iscell(M))
    M={M};
end;
numModels   = numel(M);

% Determine optimal algorithm for each of the models (user override first)
if (~isempty(fitAlgorithm))
    for m=1:numModels
        M{m}.fitAlgorithm = fitAlgorithm;
    end;
end;
M = pcm_optimalAlgorithm(M);

% Set up all parameters for the upcoming fit
[Z,B,X,YY,S,N,P,G_hat,noise0,run0]=...
    pcm_setUpFit(Y,partitionVec,conditionVec,'runEffect',runEffect,'S',S);

% Loop over subjects and provide individual fits
for s = 1:numSubj

    % Fold-wise results for this subject. Reset per subject so rows from a
    % previous subject (e.g. one with more folds) cannot leak into this one.
    D = struct();

    % Set up crossvalidation scheme
    if iscell(partitionVec)
        pV=partitionVec{s};
    else
        pV=partitionVec;
    end;
    partI    = makeFolds(pV,crossvalScheme);
    numFolds = numel(partI);

    % Now loop over models
    for m = 1:numModels
        if (verbose)
            if isfield(M{m},'name')
                fprintf('Fitting Subj: %d model:%s\n',s,M{m}.name);
            else
                fprintf('Fitting Subj: %d model:%d\n',s,m);
            end;
        end;

        % Set options
        % NOTE(review): S is passed as returned by pcm_setUpFit, i.e. not
        % reduced to the training rows - confirm pcm_likelihoodIndivid
        % expects this.
        if (~isempty(S))
            OPT.S = S;
        end;

        % Get starting guess for theta if not provided
        if (numel(theta0)<m || size(theta0{m},2)<s)
            if (isfield(M{m},'theta0'))
                th0m = M{m}.theta0(1:M{m}.numGparams);
            else
                th0m = pcm_getStartingval(M{m},G_hat(:,:,s));
            end;
            th0m = th0m(:);     % column, so it can be stacked with noise/run

            if (strcmp(runEffect,'random'))
                theta0{m}(:,s) = [th0m;noise0(s);run0(s)];
            else
                theta0{m}(:,s) = [th0m;noise0(s)];
            end;
        end;

        % Set starting values
        x0 = theta0{m}(:,s);
        th = nan(size(x0,1),numFolds);

        % Now perform the cross-validation across different partitions
        for p=1:numFolds
            tic;                            % time each fold separately
            testIdx  = ismember(pV,partI{p});
            trainIdx = ~testIdx;

            % Get the data and design matrices for training set
            Ytrain=Y{s}(trainIdx,:);
            Xtrain=reduce(X{s},trainIdx);
            Ztrain=Z{s}(trainIdx,:);
            OPT.runEffect = reduce(B{s},trainIdx);

            % Get test data and design matrices for the test set
            Ytest=Y{s}(testIdx,:);
            Xtest=reduce(X{s},testIdx);
            Ztest=Z{s}(testIdx,:);
            Btest=reduce(B{s},testIdx);
            Ntest = size(Ztest,1);

            % Fit the model to the training data (maximum likelihood)
            fcn = @(x) pcm_likelihoodIndivid(x,Ytrain*Ytrain',M{m},Ztrain,Xtrain,P(s),OPT);
            switch (M{m}.fitAlgorithm)
                case 'minimize'
                    [th(:,p),fX,D.iterations(p,m)] =  minimize(x0, fcn, MaxIteration);
                    D.likelihood_fit(p,m)=-fX(end);
                case 'NR'
                    [th(:,p),D.likelihood_fit(p,m),D.iterations(p,m)]= pcm_NR(x0,fcn,'verbose',verbose);
                otherwise
                    error('Unknown fitAlgorithm: %s',M{m}.fitAlgorithm);
            end;
            % Use this fold's estimate as starting value for the next fold
            x0 = th(:,p);

            % Record the stats from fitting (noise/run stored as log-parameters)
            D.SN(p,1)         = s;
            D.fold(p,1)       = p;
            D.noise(p,m)      =  exp(th(M{m}.numGparams+1,p));
            if strcmp(runEffect,'random')
                D.run(p,m)      =  exp(th(M{m}.numGparams+2,p));
            end;
            D.time(p,m)       = toc;

            % Calculate prediction on the test set
            [estU,varU] = pcm_estimateU(M{m},th(:,p),Ytrain,Ztrain,Xtrain,'runEffect',OPT.runEffect);
            Ypred  = Ztest*estU;
            % Remove the part explained by the test-set fixed effects
            projX  = Xtest*pinv(Xtest);
            Ypredx = Ypred-projX*Ypred;
            Ytestx = Ytest-projX*Ytest;
            for c = 1:numEval
                switch (evaluation{c})
                    case 'likelihood' % Log probability under the predictive distribution (under development)
                        R = Ytest - Ypred;          % Residuals after subtracting prediction
                        V = Ztest * varU * Ztest'+eye(Ntest)*D.noise(p,m);
                        if strcmp(runEffect,'random')
                            V = V + Btest*Btest'*D.run(p,m);
                        end;
                        iV = inv(V);
                        if (~isempty(Xtest))
                            % Residual-forming inverse for ReML
                            iVX   = iV * Xtest;
                            iVr   = iV - iVX*((Xtest'*iVX)\iVX');
                        else
                            iVr   = iV;
                        end;

                        % Computation of (restricted) likelihood
                        ldet  = -2* sum(log(diag(chol(iV))));        % log det(V) via Cholesky (code from D. Lu)
                        l     = -P(s)/2*(ldet)-0.5*traceABtrans(iVr,R*R');
                        if (~isempty(Xtest)) % Correct for ReML estimates
                            l = l - P(s)*sum(log(diag(chol(Xtest'*iV*Xtest))));  % - P/2 log(det(X'V^-1*X));
                        end;
                        D.likelihood(p,m) = l;
                    case 'R2'              % Predictive R2
                        D.TSS(p,m)= sum(sum(Ytestx.*Ytestx));
                        D.RSS(p,m)= sum(sum((Ytestx-Ypredx).^2));
                    case {'R','Rpool'}     % Predictive correlation
                        D.SS1(p,m) = sum(sum(Ytestx.*Ytestx));
                        D.SS2(p,m) = sum(sum(Ypredx.*Ypredx));
                        D.SSC(p,m) = sum(sum(Ypredx.*Ytestx));
                end;
            end;
        end;                % For each fold
        theta_hat{m}(:,s)=mean(th,2);
    end;                    % For each model
    DD=addstruct(DD,D);

    % Summarize results across folds for each subject. The explicit
    % dimension keeps one value per model even when there is a single fold.
    T.noise(s,:)=mean(D.noise,1);
    if strcmp(runEffect,'random')
        T.run(s,:) =  mean(D.run,1);
    end;
    T.time(s,:)          =  sum(D.time,1);
    T.iterations(s,:)    =  sum(D.iterations,1);
    for c = 1:numEval
        switch (evaluation{c})
            case 'likelihood'
                T.likelihood(s,:)    =  sum(D.likelihood,1);
            case 'R2'
                TSS = sum(D.TSS,1);
                RSS = sum(D.RSS,1);
                T.R2(s,:)    = 1-RSS./TSS;
            case 'Rpool' % Pool sums-of-squares first, then calculate correlations:
                % Warning: This measure is slightly negatively biased for
                % noise data, but can perform more stably in model
                % comparison
                SSC = sum(D.SSC,1);
                SS1 = sum(D.SS1,1);
                SS2 = sum(D.SS2,1);
                T.SS1(s,:)=SS1;
                T.SS2(s,:)=SS2;
                T.SSC(s,:)=SSC;
                T.Rpool(s,:)    = SSC./sqrt(SS1.*SS2);
            case 'R'  % Predictive correlation for each fold - then average over folds
                T.R(s,:)    = mean(D.SSC./sqrt(D.SS1.*D.SS2),1);
        end;
    end;
end; % for each subject


function partI=makeFolds(pV,crossvalScheme);
% Builds the crossvalidation folds: returns a cell array with one cell per
% fold, holding the partition labels that are left out (tested) in that fold.
% pV is the partition vector; crossvalScheme is a scheme name or a cell
% array specifying the folds directly.
if ~ischar(crossvalScheme)
    partI = crossvalScheme;         % Direct specification
else
    part    = unique(pV)';
    numPart = numel(part);
    partI   = {};
    switch (crossvalScheme)
        case 'leaveOneOut'
            for i=1:numPart
                partI{i}=part(i);
            end;
        case 'leaveTwoOut'
            % Consecutive pairs; with an odd count the last partition is unused
            for i=1:floor(numPart/2)
                partI{i}=part((i-1)*2+[1:2]);
            end;
        case {'oddEven','oddeven'}
            % Fold 1 leaves out odd labels, fold 2 leaves out even labels
            for i=1:2
                partI{i}=part(mod(part+i,2)==0);
            end;
        otherwise
            error('Unknown crossvalScheme: %s',crossvalScheme);
    end;
end;
if isempty(partI)
    error('The crossvalidation scheme yields no folds');
end;

function Xt=reduce(X,index);
% Selects the rows of design matrix X given by index and keeps only the
% columns that are non-zero in at least one selected row (e.g. indicator
% columns of runs outside the selection are dropped).
% An empty X yields a [#selected rows x 0] matrix, so that matrix products
% with the result downstream keep consistent dimensions.
if isempty(X)
    if islogical(index)
        Xt = zeros(nnz(index),0);
    else
        Xt = zeros(numel(index),0);
    end;
    return;
end;
Xt=X(index,:);
% Sum along dimension 1 explicitly: with a single selected row, sum() would
% otherwise sum across columns and collapse the column mask to a scalar.
Xt=Xt(:,sum(abs(Xt),1)>0);
back to top