% https://github.com/mullerlab/spindlecnn
% Raw File
% Tip revision: 0c503d103e4a0cf041e43903a896bb25b0c66b9b authored by Lyle Muller on 16 June 2022, 01:07:31 UTC
% Create LICENSE
% Tip revision: 0c503d1
% fitCNN.m
function [net] = fitCNN(Xtrain, ytrain, Xval, yval, Hyperparams)

%
% FIT A CNN MODEL TO DETECT SPINDLES
% 22 MARCH 2022
%
% Builds a network of five (convolution -> max-pool -> ReLU) stages
% followed by four fully connected layers (128-64-32-2) with a softmax /
% classification output, and trains it with trainNetwork using
% (Xval, yval) as validation data.
%
% PARAMETERS
% Xtrain, ytrain - training set. The input layer is sized
%                  [1 size(Xtrain,2) 1]; NOTE(review): this presumes each
%                  sample is a 1-by-W single-channel image (e.g. Xtrain of
%                  size 1-by-W-by-1-by-N) - confirm against callers.
% Xval, yval     - validation set (passed as 'ValidationData')
% Hyperparams    - optional containers.Map of CNN hyperparameters. Missing
%                  keys use the defaults below; omitting the argument or
%                  passing [] uses all defaults.
%     cov_fsize                - convolution filter sizes, >= 5 elements
%                                (first 5 used); default [2,16,32,16,4]
%     cov_numf                 - number of filters per convolution layer,
%                                >= 5 elements (first 5 used);
%                                default [32,64,128,192,256]
%     pool_size                - max-pooling window size; default 4
%     opt_solver               - solver name for trainingOptions;
%                                default 'sgdm'
%     opt_learning_rate        - initial learning rate; default 0.001
%     opt_max_epoch            - maximum number of training epochs;
%                                default 15
%     opt_validation_frequency - validation frequency in iterations;
%                                default 200
%     opt_plots                - trainingOptions 'Plots' value, e.g.
%                                'training-progress' or 'none';
%                                default 'training-progress'
%
% OUTPUTS
% net - trained CNN model returned by trainNetwork
%
% ERRORS
% fitCNN:badHyperparam - cov_fsize or cov_numf has fewer than 5 elements
%
% NOTE: rng('default') is called just before training, which resets the
% global random number generator for the whole MATLAB session.
%

% read hyperparameters, falling back to defaults for absent keys
if nargin < 5 || isempty(Hyperparams)
    Hyperparams = containers.Map();
end
cov_fsize                = getHyperparam(Hyperparams, 'cov_fsize', [2,16,32,16,4]);
cov_numf                 = getHyperparam(Hyperparams, 'cov_numf', [32,64,128,192,256]);
pool_size                = getHyperparam(Hyperparams, 'pool_size', 4);
opt_solver               = getHyperparam(Hyperparams, 'opt_solver', 'sgdm');
opt_learning_rate        = getHyperparam(Hyperparams, 'opt_learning_rate', 0.001);
opt_max_epoch            = getHyperparam(Hyperparams, 'opt_max_epoch', 15);
opt_validation_frequency = getHyperparam(Hyperparams, 'opt_validation_frequency', 200);
opt_plots                = getHyperparam(Hyperparams, 'opt_plots', 'training-progress');

% the architecture has exactly five convolution stages; check up front so
% a short vector gives a clear message instead of an index error later
numConvStages = 5;
if numel(cov_fsize) < numConvStages || numel(cov_numf) < numConvStages
    error('fitCNN:badHyperparam', ...
        'cov_fsize and cov_numf must each have at least %d elements.', ...
        numConvStages);
end

% input init
inputSize = [1 size(Xtrain,2) 1]; numClasses = 2;

%%%% Create the Array of Layers
layers = imageInputLayer(inputSize,"Name","imageinput");
for k = 1:numConvStages
    if k == 1
        % first stage uses a 1-by-cov_fsize(1) filter (one row high);
        % later stages use square filters.
        % NOTE(review): an earlier comment here read "4 is the best
        % number", but the default cov_fsize(1) is 2 - confirm.
        filterSize = [1 cov_fsize(1)];
    else
        filterSize = [cov_fsize(k) cov_fsize(k)];
    end
    % stride [1 2] halves the width at every stage
    layers = [layers
        convolution2dLayer(filterSize,cov_numf(k),"Name",sprintf('conv_%d',k),"Padding","same")
        maxPooling2dLayer([pool_size pool_size],"Name",sprintf('maxpool_%d',k),"Padding","same","Stride",[1 2])
        reluLayer("Name",sprintf('relu_%d',k))]; %#ok<AGROW>
end
layers = [layers
    fullyConnectedLayer(128,"Name","fc_1")
    reluLayer("Name","relu_6")
    fullyConnectedLayer(64,"Name","fc_2")
    reluLayer("Name","relu_7")
    fullyConnectedLayer(32,"Name","fc_3")
    reluLayer("Name","relu_8")
    fullyConnectedLayer(numClasses,"Name","fc_4")
    softmaxLayer("Name","softmax")
    classificationLayer("Name","classoutput")];

% set options
options = trainingOptions(opt_solver, 'InitialLearnRate',opt_learning_rate, ...
    'MaxEpochs',opt_max_epoch, 'Shuffle','every-epoch', ...
    'ValidationData',{Xval,yval}, 'ValidationFrequency',opt_validation_frequency, ...
    'Verbose',false, 'Plots',opt_plots);

% train cnn model (reset the global RNG first so runs start from the same seed)
rng('default'); net = trainNetwork(Xtrain, ytrain, layers, options);

end

function value = getHyperparam(params, name, default)
% Return params(name) when the containers.Map params has key name,
% otherwise return default.
if isKey(params, name)
    value = params(name);
else
    value = default;
end
end
% back to top