https://github.com/higham/logsumexp-softmax-tests
Tip revision: a6e66ff1ff1d30800406481af2cf8720c6221691 authored by Nick Higham on 26 August 2020, 14:24:16 UTC
Updated reference in readme.
Updated reference in readme.
Tip revision: a6e66ff
chop_test_softmax.m
% CHOP_TEST_SOFTMAX.M
%
% Compares half precision (fp16) implementations of the basic and shifted
% log-sum-exp (LSE) and softmax algorithms with single precision
% reference values, and plots the computed errors against the
% leading-order terms of the error bounds.
%
% USES CHOP TO LOWER THE PRECISION
% FROM THE PACKAGE:
% Simulating Low Precision Floating-Point Arithmetic,
% N. J. Higham and S. Pranesh,
% MIMS EPrint 2019.4, March 2019; revised July 2019,
% to appear in SIAM J. Sci. Comput.
%
% Data came from a LeCun-style network that was trained on MNIST data
% (GENERATED BY PRESOFTMAX.M)
%
% Here softmat is 10 by 2500; each column holds the ten numbers
% that would go into the final softmax layer.
% Note: softmat is a single precision matrix.
%
% May 2019
% Latest version June 2019, used for manuscript
% This one July 2019, now Aug 2019, now Sep 2019
%
% This version uses fp16 for the lower precision calculations.
load('data_presoftmax.mat') % expected to provide softmat
% softmat is the single precision data matrix;
% each column is a data vector x of length n.
us = 2^-24; % unit roundoff for single precision fp32 (about 5.96e-8)
uh = 2^-11  % unit roundoff for half precision fp16 (about 4.88e-4)
options.format = 'h'; % to half precision
chop([],options); % set up chop; from now on x = chop(x) rounds to fp16
% Take the sizes from the data rather than hard-coding them.
n = size(softmat,1); % dimension of data (10 for this data set)
L = size(softmat,2)  % number of test vectors used (2500 for this data set); reduce for a quick run
% Some test vectors are going to produce Infs--deal with that at the end.
% Preallocate storage: one entry per test vector.
lsechopval = zeros(L,1); lsechopshiftval = zeros(L,1);
lsebd1 = zeros(L,1); errlse1 = zeros(L,1);
% Named lseshiftbd1 to match the array filled in the main loop.
lseshiftbd1 = zeros(L,1); errlseshift1 = zeros(L,1);
smbd1 = zeros(L,1); errsm1 = zeros(L,1);
smbd1b = zeros(L,1); errsm1b = zeros(L,1);
smbdshift = zeros(L,1); errsmshift = zeros(L,1);
smbdaltshift = zeros(L,1); errsmaltshift = zeros(L,1);
sumchop = zeros(L,1); sumchopb = zeros(L,1);
sumhigh = zeros(L,1);
sumchopshift = zeros(L,1); sumaltshift = zeros(L,1);
FS = 12; % FontSize for the manuscript figures
% Main loop: for each data vector, run the fp16 algorithms and the single
% precision reference. Record errors (in units of uh) and the
% leading-order error-bound terms.
for k = 1:L
    k % progress display
    xvec = softmat(:,k); % kth data point
    xchop = chop(xvec); % kth data point rounded to fp16
    %%%%%%%% test on this vector %%%%%%%
    [lsechop, smchop, smchopb] = lse_chop(xchop); % chopped: basic lse and two versions of softmax
    %%% Tried the version below first: uses full xvec to get reference solution
    %[lsehigh, smhigh] = lse_basichigh(xvec); %single (original) prec alternative (taken to be exact solution)
    %%% Changed to the version below:
    %%% uses chopped vector (but works in the higher, single, precision)
    [lsehigh, smhigh] = lse_basichigh(xchop); % single (original) prec alternative (taken to be exact solution)
    [lsechopshift, smchopshift, smaltshift] = lse_chopshift(xchop); % chopped version of shifted alg
    lsechopval(k) = lsechop; % store in order to test for Infs
    lsechopshiftval(k) = lsechopshift; % store in order to test for Infs
    % Every algorithm, and the reference, is applied to xchop.
    % The bounds below are therefore evaluated at xchop as well.
    % LSE: leading order term in error bound from Thm 3.2
    lsebd1(k) = 1 + (n+1)/abs(lsehigh);
    errlse1(k) = abs(lsechop-lsehigh)/(uh*abs(lsehigh));
    % LSE: leading order term from Thm 4.2
    lseshiftbd1(k) = abs(lsehigh + n - min(xchop))/abs(lsehigh);
    errlseshift1(k) = abs(lsechopshift-lsehigh)/(uh*abs(lsehigh));
    % SOFTMAX errors: max-norm relative error divided by uh.
    % SOFTMAX: leading order term in error bound from Thm 3.3
    smbd1(k) = (n+3);
    errsm1(k) = max(abs(smchop - smhigh))/(uh*max(abs(smhigh)));
    % SOFTMAX: leading order term in error bound from Thm 3.4 for
    % alternative alg
    smbd1b(k) = (abs(lsehigh) + max(abs(xchop - lsehigh)) + n + 2);
    errsm1b(k) = max(abs(smchopb - smhigh))/(uh*max(abs(smhigh)));
    % SOFTMAX: leading order term in error bound from Thm 4.3 for
    % shifted alg
    smbdshift(k) = n + 2 + 2*(max(xchop) - min(xchop));
    errsmshift(k) = max(abs(smchopshift - smhigh))/(uh*max(abs(smhigh)));
    % SOFTMAX: leading order term in error bound from Thm 4.4 for
    % alternative version of shifted alg
    smbdaltshift(k) = 1 + max(abs(xchop - lsehigh)) + abs(lsehigh + n - min(xchop));
    errsmaltshift(k) = max(abs(smaltshift - smhigh))/(uh*max(abs(smhigh)));
    % Sum of softmax entries in each case (exactly 1 in exact arithmetic)
    sumchop(k) = sum(smchop);
    sumchopb(k) = sum(smchopb);
    sumhigh(k) = sum(smhigh);
    sumchopshift(k) = sum(smchopshift);
    sumaltshift(k) = sum(smaltshift);
end
% IGNORE THE CASES WHERE A CHOPPED LSE EVALUATED TO INF (OR NAN)
% Keep only data points where both the basic and the shifted fp16 LSE
% returned a finite value.
a = find(isfinite(lsechopval) & isfinite(lsechopshiftval));
a2 = find(lsechopval == Inf);
numInfs = length(a2) % overflows in half prec basic LSE: gives 475
a3 = find(lsechopshiftval == Inf);
numInfs_shift = length(a3) % overflows in half prec shifted LSE: gives 0
% So we will only plot info for data vectors indexed by a
% (i.e., for input where all algorithms avoided overflow)
%% FIGURE 1 PRODUCES A PICTURE FOR THE MANUSCRIPT
% Top row: computed error (in units of uh) against the leading-order bound
% for the basic and shifted LSE, with the line y = x for reference.
% Bottom: basic vs shifted LSE errors. Only retained points (index a) are used.
figure(1) %%% LSE basic alg
clf
subplot(2,2,1)
plot(lsebd1(a),errlse1(a),'*')
hold on
ylim([0 4]), yticks(0:4)
xlim([1.9 6]), xticks(2:6)
lb1 = min(lsebd1(a)) % check: 1.9 should be a lower bound on the smallest x value plotted
plot([min(lsebd1(a)),max(lsebd1(a))],[min(lsebd1(a)),max(lsebd1(a))],'r-','LineWidth',2) % reference line y = x
title('Basic LSE')
ylabel('f.p. error')
xlabel('error bound')
set(gca,'FontSize',FS)
axis square
set(gca,'TickLength',[0.02 0.025])
subplot(2,2,2)
plot(lseshiftbd1(a),errlseshift1(a),'*')
hold on
ylim([0 4]), yticks(0:4)
xlim([1.9 6]), xticks(2:6)
lb2 = min(lseshiftbd1(a)) % check: 1.9 should be a lower bound on the smallest x value plotted
plot([min(lseshiftbd1(a)),max(lseshiftbd1(a))],[min(lseshiftbd1(a)),max(lseshiftbd1(a))],'r-','LineWidth',2) % reference line y = x
title('Shifted LSE')
ylabel('f.p. error')
xlabel('error bound')
set(gca,'FontSize',FS)
axis square
set(gca,'TickLength',[0.02 0.025])
subplot(2,2,[3 4])
% subplot(2,2,[2 4])
plot(errlse1(a),errlseshift1(a),'*')
hold on
%plot([min(errlse1(a)),max(errlse1(a))],[min(errlse1(a)),max(errlse1(a))],'r-','LineWidth',2) % ref slope
xlabel('basic LSE')
ylabel('shifted LSE')
ylim([0 1.25]), yticks(0:0.25:1.25)
xlim([0 1.25]), xticks(0:0.5:1.25)
% axis equal
axis square
title('Basic vs Shifted')
set(gca,'FontSize',FS)
set(gca,'TickLength',[0.02 0.025])
print -dpng lsepic.png
% Ratio of basic-LSE error to shifted-LSE error at each retained point.
lse_ratio = errlse1(a)./errlseshift1(a);
% A zero shifted-LSE error gives Inf (or NaN for 0/0). Such points are
% excluded so that the summary statistics stay finite.
ok = isfinite(lse_ratio);
num_nonfinite_ratios = nnz(~ok)
lse_rat_min = min(lse_ratio(ok))
lse_rat_max = max(lse_ratio(ok))
lse_rat_mean = mean(lse_ratio(ok))
% NOTE: despite its name, this is the standard error of the mean
% (standard deviation divided by sqrt of the sample size).
lse_rat_std = std(lse_ratio(ok))/sqrt(nnz(ok))
%%%%% FIGURE(2) IS FOR INFO, NOT FOR MANUSCRIPT
figure(2) % another comparison of the errors in the two methods
clf
%histogram(log10(lse_ratio))
rat_sort = sort(log10(lse_ratio),'ascend');
plot(rat_sort,'*-') % non-finite values are simply not drawn
title('Log ratio of errors in LSE shifted LSE')
print -dpng lseratiopic.png
% seem to be lots of ones
rat = lse_ratio; % many of these are seen to be exactly 1 in f.p.
% Count exact ties; nnz avoids overwriting a2 (the Inf indices) as [a1,a2] = find(...) did.
numones = nnz(rat == 1)
ratio = numones/length(rat)
%%%%% FIGURE(3) IS FOR INFO, NOT FOR MANUSCRIPT
% Histogram of basic-softmax errors (in units of uh) over the retained
% points. The leading-order bound (n+3) does not depend on the data.
figure(3), clf %%% SOFTMAX basic alg
histogram(errsm1(a))
title('Basic SOFTMAX')
set(gca,'FontSize',24,'FontWeight','Bold')
print('-dpng','sf1.png')
%%%%% FIGURE(4) IS FOR INFO, NOT FOR MANUSCRIPT
% Softmax errors (in units of uh) for the retained points (index a):
% bound vs error for the alternative and shifted algorithms, and pairwise
% error comparisons. Red lines are y = x.
% lo/hi are used instead of lower/upper so as not to shadow the built-in
% functions lower and upper in the workspace.
figure(4) %%%% SOFTMAX alternative alg
clf
subplot(2,2,1)
plot(smbd1b(a),errsm1b(a),'*')
hold on
ylim([0 50])
plot([min(smbd1b(a)),max(smbd1b(a))],[min(smbd1b(a)),max(smbd1b(a))],'r-','LineWidth',2) % reference line y = x
title('alt.')
xlabel('error bound')
ylabel('f.p. error')
set(gca,'FontSize',18,'FontWeight','Bold')
subplot(2,2,2)
plot(smbdshift(a),errsmshift(a),'*')
hold on
ylim([0 80])
plot([min(smbdshift(a)),max(smbdshift(a))],[min(smbdshift(a)),max(smbdshift(a))],'r-','LineWidth',2) % reference line y = x
title('shifted')
xlabel('error bound')
ylabel('f.p. error')
set(gca,'FontSize',18,'FontWeight','Bold')
subplot(2,2,4)
plot(errsmshift(a),errsm1b(a),'*')
hold on
ylim([0 15])
lo = min(min(errsmshift(a)),min(errsm1b(a)));
hi = max(max(errsmshift(a)),max(errsm1b(a)));
plot([lo,hi],[lo,hi],'r-','LineWidth',2) % reference line y = x
title('shifted v alt.')
xlabel('shifted')
ylabel('alt')
set(gca,'FontSize',18,'FontWeight','Bold')
subplot(2,2,3)
plot(errsmshift(a),errsm1(a),'*')
hold on
ylim([0 10])
lo = min(min(errsmshift(a)),min(errsm1(a)));
hi = max(max(errsmshift(a)),max(errsm1(a)));
plot([lo,hi],[lo,hi],'r-','LineWidth',2) % reference line y = x
title('shifted v basic')
xlabel('shifted')
ylabel('basic')
set(gca,'FontSize',18,'FontWeight','Bold')
print -dpng sfall4.png
%%%%% FIGURE(5) IS FOR INFO, NOT FOR MANUSCRIPT
% Check alt shifted softmax: error (in units of uh) against its own
% leading-order bound (Thm 4.4), with y = x drawn over that bound's range.
figure(5)
clf
%subplot(2,2,2)
plot(smbdaltshift(a),errsmaltshift(a),'*')
hold on
ylim([0 80])
% Reference line spans the alt-shifted bound actually plotted
% (previously it used the shifted bound smbdshift by mistake).
plot([min(smbdaltshift(a)),max(smbdaltshift(a))],[min(smbdaltshift(a)),max(smbdaltshift(a))],'r-','LineWidth',2) % reference line y = x
title('alt shifted')
xlabel('error bound')
ylabel('f.p. error')
set(gca,'FontSize',18,'FontWeight','Bold')
%%%%% FIGURE(6) IS FOR MANUSCRIPT
% Softmax errors (in units of uh) of the basic, alternative and alternative
% shifted algorithms compared with the shifted algorithm at the retained
% points (index a). Red lines are y = x.
% lo/hi are used instead of lower/upper so as not to shadow the built-in
% functions lower and upper in the workspace.
figure(6) %%%% SOFTMAX comparisons for manuscript
clf
subplot(2,2,1)
plot(errsmshift(a),errsm1(a),'*')
hold on
lo = min(min(errsmshift(a)),min(errsm1(a)));
hi = max(max(errsmshift(a)),max(errsm1(a)));
plot([lo,hi],[lo,hi],'r-','LineWidth',2) % reference line y = x
title('Shifted v Basic')
xlabel('shifted')
ylabel('basic')
set(gca,'FontSize',FS)
axis equal
axis square
ylim([0 10])
xlim([0 10])
set(gca,'TickLength',[0.02 0.025])
subplot(2,2,2)
plot(errsmshift(a),errsm1b(a),'*')
hold on
lo = min(min(errsmshift(a)),min(errsm1b(a)));
hi = max(max(errsmshift(a)),max(errsm1b(a)));
plot([lo,hi],[lo,hi],'r-','LineWidth',2) % reference line y = x
title('Shifted v Alt.')
xlabel('shifted')
ylabel('alt')
set(gca,'FontSize',FS)
axis equal
axis square
ylim([0 10])
xlim([0 10])
set(gca,'TickLength',[0.02 0.025])
subplot(2,2,[3 4])
plot(errsmshift(a),errsmaltshift(a),'*')
hold on
lo = min(min(errsmshift(a)),min(errsmaltshift(a)));
hi = max(max(errsmshift(a)),max(errsmaltshift(a)));
plot([lo,hi],[lo,hi],'r-','LineWidth',2) % reference line y = x
title('Shifted v Alt. Shifted')
xlabel('shifted')
ylabel('alt. shifted')
set(gca,'FontSize',FS)
axis equal
axis square
ylim([0 10])
xlim([0 10])
set(gca,'TickLength',[0.02 0.025])
print -dpng softpic.png
% Softmax sum test: at each retained point, plot the sum of the computed
% softmax entries (exactly 1 in exact arithmetic).
% The x-axis spans the retained points. This was previously hard-coded as
% 2025 (2500 minus 475 overflows). max(...,2) keeps the axis limits
% increasing even if fewer than two points remain.
npts = max(numel(a), 2);
%%% FIG 7 creates a figure for the manuscript %%%%%
figure(7)
clf
semilogy(sumchop(a),'ro');
hold on
semilogy(sumchopb(a),'bx')
%title('Thm 3.3 is red, Thm 3.4 is blue');
xlim([1 npts])
title('Softmax sum test: basic')
ylabel('sum(g)')
xlabel('Data points')
legend('basic','alternative','Location','NorthEast')
set(gca,'FontSize',16)
print -dpng sumpic2.png
%%% FIG 8 does not create a figure for the manuscript %%%%%
figure(8)
clf
semilogy(sumhigh(a),'ro')
title('single precision')
print -dpng sumpic1.png
%%% FIG 9 creates a figure for the manuscript %%%%%
figure(9)
clf
semilogy(sumchopshift(a),'ro')
hold on
semilogy(sumaltshift(a),'bx')
%title('Thm 4.3 is red, Thm 4.4 is blue');
xlim([1 npts])
title('Softmax sum test: shifted')
ylabel('sum(g)')
xlabel('Data points')
legend('shifted','alternative','Location','NorthEast')
set(gca,'FontSize',16)
print -dpng sumpic3.png