https://git.exeter.ac.uk/mv286/kndy-parameter-inference
Raw File
Tip revision: f5585f5f09d62a2569d21def31dc779774bc2476 authored by mv286 on 15 July 2021, 12:46:14 UTC
Deleting README.md
Tip revision: f5585f5
my_abc_SMC.m
% Function implementing the ABC SMC algorithm of Toni et al. (2009)
% (https://doi.org/10.1098/rsif.2008.0172), with a locally adaptive
% perturbation kernel whose covariance is estimated from each particle's
% K nearest neighbours.
%
% Parameters:
% drawProposalFun - function handle drawProposalFun(parent, SMCparams, ti)
%                   returning a candidate particle. It is called with
%                   parent = [] in the first generation (draw from the
%                   prior). In later generations SMCparams.Sigma holds the
%                   kernel covariance of the resampled parent.
%                   NOTE(review): the weights below assume it returns a
%                   column vector drawn from N(parent, Sigma) - confirm.
% SMCparams       - structure with fields: modelparams (passed to modelFun),
%                   paramN (number of inferred parameters), particles
%                   (population size), epsilons (one row of tolerances per
%                   generation), minV / maxV (parameter bounds), Knn
%                   (neighbours used for the kernel covariance) and prefix
%                   (prefix of the per-generation .mat files)
% modelFun        - function handle Xvals = modelFun(modelparams); the
%                   first row of Xvals is compared with the data
% discrepancyFun  - function handle d = discrepancyFun(dataX, modelX,
%                   modelparams); a candidate is accepted when every entry
%                   of d is below the current row of epsilons
% data            - structure holding the data in field X
%
% Returns:
% particles - particlesN x generations x paramN array of accepted particles
% weights   - particlesN x generations array of normalised weights
%
% After every generation the whole workspace is saved to
% [SMCparams.prefix '_' num2str(ti) '.mat'].
function [particles, weights] = ...
    my_abc_SMC(drawProposalFun, SMCparams, ...
    modelFun, discrepancyFun, data)

    %model parameters
    modelparams = SMCparams.modelparams;
    %number of parameters to be inferred
    paramsN = SMCparams.paramN;
    %number of particles
    particlesN = SMCparams.particles;
    %series of epsilons (rows = generations, columns = discrepancy components)
    epsilons = SMCparams.epsilons;
    generationsN = size(epsilons, 1);

    datasize = length(data.X);
    %output / bookkeeping arrays (all of them end up in the saved .mat files)
    particles    = nan(particlesN, generationsN, paramsN);
    weights      = nan(particlesN, generationsN);
    model_output = nan(particlesN, generationsN, datasize);
    errors       = nan(particlesN, generationsN, size(epsilons, 2));
    covs         = nan(paramsN, paramsN, particlesN, generationsN);

    for ti = 1:generationsN

        if ti > 1
            % previous population: constant while this generation is filled
            prevParticles = squeeze(particles(:, ti-1, :));
            prevWeights   = weights(:, ti-1);
            prevCovs      = covs(:, :, :, ti-1);
        end

        % 'pIdx' (not 'pi') so the built-in constant is not shadowed
        pIdx = 1;
        while pIdx <= particlesN
            if ti == 1
                %sample from prior
                newParticle = drawProposalFun([], SMCparams, ti);
            else
                % Resample a parent according to the previous weights and
                % perturb it using that parent's own kernel covariance
                parentsInd = randsample(particlesN, 1, true, prevWeights);
                particleTmp = squeeze(particles(parentsInd, ti-1, :));
                SMCparams.Sigma = prevCovs(:, :, parentsInd);
                newParticle = drawProposalFun(particleTmp, SMCparams, ti);
            end

            % Prior support (hard-coded for this model): parameters 1:4,
            % the sums of parameters 1:4 and 5:8, and parameter 9 must lie
            % within the corresponding minV/maxV bounds. Other parameters
            % are not bounds-checked here.
            base    = newParticle(1:4);
            shifted = newParticle(1:4) + newParticle(5:8);
            inSupport = ~any(base < SMCparams.minV(1:4) | base > SMCparams.maxV(1:4)) ...
                && ~any(shifted < SMCparams.minV(1:4) | shifted > SMCparams.maxV(1:4)) ...
                && ~any(newParticle(9) < SMCparams.minV(9) | newParticle(9) > SMCparams.maxV(9));

            if inSupport
                % - run model
                modelparams.fitparams = newParticle;
                Xvals = modelFun(modelparams);

                % - calculate discrepancy
                d = discrepancyFun(data.X, Xvals(1,:), modelparams);

                %accept only if every discrepancy component is within tolerance
                if all(d < epsilons(ti,:))
                    model_output(pIdx, ti, :) = Xvals(1,:);
                    % - update particle
                    particles(pIdx, ti, :) = newParticle;
                    % - calculate new weight
                    if ti == 1
                        weights(pIdx, ti) = 1;
                    else
                        % Prior density is taken as constant (uniform on the
                        % support checked above). Previously sketched, unused
                        % alternatives: normal priors centred on
                        % modelparams.p2, .KD, .KN, .Kr1 and .p3Basal with a
                        % standard deviation of 10% of the centre.
                        priorDensity = 1;
                        % w = prior / sum_j w_j * K_j(new | theta_j), where K_j
                        % uses particle j's OWN covariance prevCovs(:,:,j).
                        % mvnpdf evaluates row j of X with page j of Sigma;
                        % the Gaussian is symmetric, so
                        % N(new; theta_j, S_j) = N(theta_j; new, S_j).
                        KK = mvnpdf(prevParticles, newParticle', prevCovs);
                        weights(pIdx, ti) = priorDensity ./ sum(prevWeights .* KK);
                    end

                    %store discrepancy value
                    errors(pIdx, ti, :) = d;
                    %advance particle index
                    pIdx = pIdx + 1;
                end
            end
        end

        %normalise weights
        weights(:, ti) = weights(:, ti) ./ sum(weights(:, ti));

        % update perturbation kernel for each particle using knn estimate
        all_current_p = squeeze(particles(:, ti, :));
        for pIdx = 1:particlesN
            pp = all_current_p(pIdx, :);
            pp_neigh = knnsearch(all_current_p, pp, 'K', SMCparams.Knn);
            % small ridge keeps the covariance numerically positive definite
            covs(:, :, pIdx, ti) = cov(all_current_p(pp_neigh, :)) + 1e-12 * eye(paramsN);
        end

        % save results (entire workspace) after every generation
        save([SMCparams.prefix '_' num2str(ti) '.mat'])
    end
back to top