https://github.com/Klimmasch/AEC
Tip revision: 96e9ae2336937469a8f1602c178ea5e0cb8564b6 authored by Lukas Klimmasch on 13 August 2021, 14:16:04 UTC
Merge branch 'alternateRearing' of https://github.com/Klimmasch/AEC into alternateRearing
Merge branch 'alternateRearing' of https://github.com/Klimmasch/AEC into alternateRearing
Tip revision: 96e9ae2
SparseCoding2.m
classdef SparseCoding2 < handle
    % SPARSECODING2 Sparse coder over a dictionary of (binocular) basis functions.
    %
    %   Holds a dictionary `basis` (basisSize x nBasis, unit-norm columns),
    %   encodes batches of image patches (one patch per column) with one of
    %   three encoders (softmaxEncode, sparseEncode, fullEncode) and adapts the
    %   dictionary from the last encoding result via stepTrain.
    properties
        nBasis;         % total number of basis functions
        nBasisUsed;     % number of basis functions used to encode each patch in sparse mode
        basisSize;      % size of each (binocular) basis vector: patchSize * patchSize * 2 (left + right eye)
        eta;            % learning rate
        temperature;    % temperature in softmax
        basis;          % all basis functions, one per column (basisSize x nBasis)
        basisHist;      % basis functions history (stacked along 3rd dimension)
        currentCoef;    % current coefficient matrix (nBasis x number of patches)
        currentError;   % current reconstruction error (basisSize x number of patches)
        sizeBatch;      % image batch size's 2nd dimension (number of patches per batch)
        selectedBasis;  % indicates for each basis how often it has been selected (normalized)
        BFinit;         % describes the initialization of basis functions
        BFfitFreq;      % fit frequencies to BFs or wavelengths?
        BFIdent;        % helps identifying the scale inside the sparse coder
    end

    methods
        % Constructor
        % PARAM = [nBasis, nBasisUsed, basisSize, eta, temperature, ...
        %          BFinit, BFfitFreq, BFIdent, sizeBatch]
        %
        % BFinit selects the dictionary initialization:
        %   1          white noise, columns normalized to unit length
        %   2          non-aligned Gabor wavelets (via BaseGenerator)
        %   3          monocular Gabor wavelets (half of the bases have one
        %              half of their rows zeroed, the other half the other)
        %   4, 41..45  basis functions loaded from a previously saved model
        % Any other value raises an error.
        function obj = SparseCoding2(PARAM)
            if numel(PARAM) < 9
                error('SparseCoding2:badParam', ...
                      'PARAM must contain at least 9 entries, got %d.', numel(PARAM));
            end

            obj.nBasis = PARAM(1);
            obj.selectedBasis = zeros(PARAM(1), 1);
            obj.nBasisUsed = PARAM(2);
            obj.basisSize = PARAM(3);
            obj.eta = PARAM(4);
            obj.temperature = PARAM(5);
            obj.BFinit = PARAM(6);
            obj.BFfitFreq = PARAM(7);
            obj.BFIdent = PARAM(8);
            obj.sizeBatch = PARAM(9); % batch size is read from the 9th entry of the input array

            switch obj.BFinit
                case 1 % white noise BF init
                    obj.basis = rand(obj.basisSize, obj.nBasis) - 0.5;
                    obj.basis = bsxfun(@rdivide, obj.basis, sqrt(sum(obj.basis .^ 2, 1)));

                case 2 % non-aligned Gabor wavelets
                    obj.basis = BaseGenerator(0, 0, sqrt(PARAM(3) / 2), PARAM(1));

                case 3 % monocular Gabor wavelets
                    if mod(obj.nBasis, 2) ~= 0
                        error('SparseCoding2:oddNBasis', ...
                              'Monocular initialization (BFinit == 3) requires an even nBasis, got %d.', ...
                              obj.nBasis);
                    end
                    obj.basis = BaseGenerator(0, 0, sqrt(PARAM(3) / 2), PARAM(1));
                    x = randperm(obj.nBasis);
                    half = obj.nBasis / 2;
                    obj.basis(1 : end / 2, x(1 : half)) = 0;            % monocular right
                    obj.basis(end / 2 + 1 : end, x(half + 1 : end)) = 0; % monocular left
                    obj.basis = bsxfun(@rdivide, obj.basis, sqrt(sum(obj.basis .^ 2, 1)));

                case {4, 41, 42, 43, 44, 45} % take preloaded BFs
                    obj.basis = SparseCoding2.loadPretrainedBasis(obj.BFinit, obj.BFIdent);

                otherwise
                    error('Unrecognized basis function initialization')
            end

            obj.basisHist = [];
            obj.currentCoef = zeros(obj.nBasis, obj.sizeBatch);     % e.g. 288x81
            obj.currentError = zeros(obj.basisSize, obj.sizeBatch); % e.g. 128x81
            %TODO maybe reimplement reloading basis functions
        end

        %%% Encode the input images according to softmax distribution
        % For each patch, nBasisUsed times in a row: sample one basis with
        % probability softmax(|basis' * residual| / temperature), add its
        % projection onto the residual to the coefficients and recompute the
        % residual. Results are stored in currentCoef and currentError.
        % Relies on SOFTMAX being available on the path (column-wise softmax).
        % @param imageBatch: input image patches batch (basisSize x number of patches)
        function softmaxEncode(this, imageBatch)
            nPatches = size(imageBatch, 2);
            % fresh zeros instead of "currentCoef * 0": the latter would keep NaN/Inf entries
            this.currentCoef = zeros(this.nBasis, nPatches);
            residual = imageBatch;
            for count = 1 : this.nBasisUsed
                projection = this.basis' * residual;
                logits = abs(projection) / this.temperature;
                % subtract the column-wise maximum before the softmax
                logits = bsxfun(@minus, logits, max(logits, [], 1));
                % column-wise cumulative distribution over the basis functions
                cdf = cumsum(softmax(logits), 1);
                draws = rand(1, nPatches);
                % inverse-CDF sampling: first basis whose cumulative probability reaches the draw
                index = sum(bsxfun(@lt, cdf, draws), 1) + 1;
                % if rounding leaves the last cdf entry below the draw, fall back to the last basis
                index = min(index, this.nBasis);
                linearIndex = sub2ind(size(projection), index, 1 : nPatches);
                this.currentCoef(linearIndex) = this.currentCoef(linearIndex) + projection(linearIndex);
                residual = imageBatch - this.basis * this.currentCoef;
            end
            this.currentError = residual;
        end

        %%% Encode the input images with the best matched basis
        % Greedy selection: per patch, nBasisUsed times pick the basis with the
        % largest absolute correlation, add that correlation to its coefficient
        % and update the correlations through the basis-basis correlation
        % matrix instead of recomputing the residual (see Yu's doc).
        % Results are stored in currentCoef and currentError.
        % @param imageBatch: input image patches batch (basisSize x number of patches)
        function sparseEncode(this, imageBatch)
            nPatches = size(imageBatch, 2);
            % fresh zeros instead of "currentCoef * 0": the latter would keep NaN/Inf entries
            this.currentCoef = zeros(this.nBasis, nPatches);
            corrl = this.basis' * imageBatch;   % correlation of each basis with each patch
            corrBB = this.basis' * this.basis;  % correlation between basis functions
            for count = 1 : this.nBasisUsed
                [~, index] = max(abs(corrl), [], 1);                    % best matching basis per patch
                linearIndex = sub2ind(size(corrl), index, 1 : nPatches); % its position in the correlation matrix
                pCorr = corrl(linearIndex);                              % correlation of the chosen basis per patch
                this.currentCoef(linearIndex) = this.currentCoef(linearIndex) + pCorr;
                % remove the contribution of the chosen basis from all correlations
                corrl = corrl - bsxfun(@times, corrBB(:, index), pCorr);
            end
            this.currentError = imageBatch - this.basis * this.currentCoef; % e.g. 128x81 = 128x81 - 128x288 * 288x81
        end

        %%% Calculate the correlation between input image and the basis
        % Uses every basis function: coefficients are the plain projections.
        % @param imageBatch: input image patches batch (basisSize x number of patches)
        function fullEncode(this, imageBatch)
            this.currentCoef = this.basis' * imageBatch;
            this.currentError = imageBatch - this.basis * this.currentCoef;
        end

        %%% Update basis functions
        % Moves the basis along currentError * currentCoef' (averaged over the
        % batch), renormalizes every basis to unit length and updates the
        % normalized usage statistics in selectedBasis.
        function stepTrain(this)
            deltaBases = this.currentError * this.currentCoef' / size(this.currentError, 2);
            this.basis = this.basis + this.eta * deltaBases;
            this.basis = bsxfun(@rdivide, this.basis, sqrt(sum(this.basis .^ 2, 1)));

            % also update the selected basis functions
            usedBasis = sum(this.currentCoef ~= 0, 2);
            nSelections = sum(usedBasis);
            % without any non-zero coefficient the usage fractions would be 0/0 and
            % turn selectedBasis into NaN for good, so the statistics are left as is
            if nSelections > 0
                this.selectedBasis = this.selectedBasis + (usedBasis ./ nSelections);
                this.selectedBasis = this.selectedBasis ./ sum(this.selectedBasis);
            end
        end

        %%% Track the evolution of all basis functions over time
        function saveBasis(this)
            this.basisHist = cat(3, this.basisHist, this.basis);
        end

        %%% Display the Basis functions (Zhao Yu code) at iteration t
        % The bases are drawn in a fixed 16 x 18 arrangement, so nBasis must be 288.
        % Uses padarray, blockproc and imshow.
        % @param t: iteration number, shown as the figure title
        function displayBasis(this, t)
            % how to arrange the basis (rows, col)
            nGridRows = 16;
            nGridCols = 18;

            [di, num] = size(this.basis);
            if num ~= nGridRows * nGridCols
                error('SparseCoding2:displayLayout', ...
                      'displayBasis arranges the bases in a %d x %d grid and needs %d bases, got %d.', ...
                      nGridRows, nGridCols, nGridRows * nGridCols, num);
            end

            % order the bases by how far the energy in the first half of their
            % rows (presumably the left-eye part) is from 0.5
            firstHalf = this.basis(1 : end / 2, :);
            leftEnergy = abs(sum(firstHalf .^ 2, 1) - 0.5);
            [~, order] = sort(leftEnergy);

            subplot(1, 1, 1);

            % turns one basis vector into a padded tile holding its two halves as square patches
            side = sqrt(di / 2);
            fun1 = @(blc_struct) padarray(padarray(reshape(permute(padarray(reshape(blc_struct.data, side, ...
                side, 2), [1, 1], 'pre'), [1, 3, 2]), (side + 1) * 2, side + 1), ...
                [1, 1], 'post') - 1, [1 1], 'pre') + 1;

            sortedBasis = this.basis(:, order);
            B = reshape(sortedBasis, di * nGridRows, nGridCols);
            B = B / max(max(abs(B))) + 0.5;
            img = padarray(padarray(blockproc(B, [di, 1], fun1) - 1, [1 1], 'post') + 1, [2, 2]);
            imshow(img);
            title(num2str(t));
            drawnow;
        end
    end

    methods (Static, Access = private)
        %%% Load the basis functions of a previously trained model
        % @param initId:  BFinit code selecting the model file (4, 41, ..., 45)
        % @param BFIdent: index into the stored model's scModel cell array
        % @return basis:  the basis matrix stored at model.scModel{BFIdent}.basis
        function basis = loadPretrainedBasis(initId, BFIdent)
            switch initId
                case {4, 41}
                    % alternatives formerly used for BFinit == 4:
                    % fixed 3 deg strabism
                    % '/home/aecgroup/aecdata/Results/eLifePaper/strabism/19-02-18_500000iter_2_AllfixAt6m_filtB_29_strabAngle_3_seed2/model.mat'
                    % laplacian 3 deg strabism
                    % '/home/aecgroup/aecdata/Results/eLifePaper/inducedStrabism/20-09-22_500000iter_2_inducedStrab_3deg_lapSig02_od05-1m/model.mat'
                    % monocular deprivation
                    modelFile = '/home/aecgroup/aecdata/Results/eLifePaper/explFilterSizes/19-06-04_500000iter_1_fsize6std_filtBoth_45_prob_1_seed1/model.mat';
                case 42
                    modelFile = '/home/aecgroup/aecdata/Results/eLifePaper/explFilterSizes/18-10-18_500000iter_2_fsize6std_filtBoth_45_prob_1_seed2/model.mat';
                case 43
                    modelFile = '/home/aecgroup/aecdata/Results/eLifePaper/explFilterSizes/18-10-19_500000iter_3_fsize6std_filtBoth_45_prob_1_seed3/model.mat';
                case 44
                    modelFile = '/home/aecgroup/aecdata/Results/eLifePaper/explFilterSizes/18-10-18_500000iter_4_fsize6std_filtBoth_45_prob_1_seed4/model.mat';
                case 45
                    modelFile = '/home/aecgroup/aecdata/Results/eLifePaper/explFilterSizes/19-06-05_500000iter_5_fsize6std_filtBoth_45_prob_1_seed5/model.mat';
                otherwise
                    error('Unrecognized basis function initialization')
            end
            stored = load(modelFile);
            model = stored.model;
            basis = model.scModel{BFIdent}.basis;
        end
    end
end