% Source: https://github.com/higham/logsumexp-softmax-tests
% Raw File
% Tip revision: a6e66ff1ff1d30800406481af2cf8720c6221691 authored by Nick Higham on 26 August 2020, 14:24:16 UTC
% Updated reference in readme.
% Tip revision: a6e66ff
% chop_test_softmax.m
% CHOP_TEST_SOFTMAX.M
%
% USES CHOP TO LOWER THE PRECISION 
% FROM THE PACKAGE: 
% Simulating Low Precision Floating-Point Arithmetic,
% N J Higham and S. Pranesh, 
% MIMS EPrint 2019.4, March 2019; revised July 2019,
% to appear in SIAM J. Sci. Comput
%
% Data came from a LeCun-style network that was trained on MNIST data
% (GENERATED BY PRESOFTMAX.M)
%
% Here softmat is 10 by 2500, each column is the ten numbers
% that would go into the final softmax layer.
% Note: softmat is a single precision matrix
%
% May 2019
% Latest version June 2019, used for manuscript
% This one July 2019, now Aug 2019, now Sep 2019
%
% This version uses fp16 for the lower precision calculations.

% Load the test data. The matrix softmat used below is expected to come
% from this file: it is the single precision data matrix, and each column
% is a vector x of length 10.
load('data_presoftmax.mat')

us = 5.96*1e-8;    % unit roundoff for single precision fp32 (not used below)
uh = 4.88*1e-4     % unit roundoff for half precision fp16 (echoed to the console)

options.format = 'h';    % to half precision
chop([],options);          % set up chop and now use x = chop(x);

n = 10;   % dimension of data

L = 2500  %  number of test vectors used: maximum possible is 2500
% Some test vectors are going to produce Infs--deal with that at the end
% Now initialize some storage: one entry per test vector, all L-by-1 columns.
% The names must match the ones assigned inside the main loop; in
% particular the shifted-LSE bound is stored in lseshiftbd1.
lsechopval = zeros(L,1); lsechopshiftval = zeros(L,1); 
lsebd1 = zeros(L,1); errlse1 = zeros(L,1);
lseshiftbd1 = zeros(L,1); errlseshift1 = zeros(L,1);
smbd1 = zeros(L,1); errsm1 = zeros(L,1);
smbd1b = zeros(L,1); errsm1b = zeros(L,1);
smbdshift = zeros(L,1); errsmshift = zeros(L,1);
smbdaltshift = zeros(L,1); errsmaltshift = zeros(L,1);
sumchop = zeros(L,1); sumchopb = zeros(L,1);
sumhigh = zeros(L,1);
sumchopshift = zeros(L,1); sumaltshift = zeros(L,1);

FS = 12; % FontSize used for the manuscript figures

for k = 1:L 
    k    % no semicolon: echoes the current index as a progress indicator

    % k-th data point and its chopped (lower precision) counterpart
    xvec = softmat(:,k);
    xchop = chop(xvec) ;

    %%%%%%%% run every algorithm on this vector %%%%%%%
    % chopped: basic lse and two versions of softmax
    [lsechop, smchop, smchopb] = lse_chop(xchop);
    % Reference ("exact") solution: works in the higher, single, precision
    % but on the chopped vector. An earlier version used the full xvec:
    %   [lsehigh, smhigh] = lse_basichigh(xvec);
    [lsehigh, smhigh] = lse_basichigh(xchop);
    % chopped version of the shifted algorithm
    [lsechopshift, smchopshift, smaltshift] = lse_chopshift(xchop);

    % Keep the chopped LSE values so that Infs can be detected afterwards
    lsechopval(k) = lsechop;
    lsechopshiftval(k) = lsechopshift;

    % Quantities shared by several of the expressions below
    lsemag = abs(lsehigh);                     % |reference lse|
    lsescale = uh*lsemag;                      % scaling of the LSE errors
    smscale = uh*max(abs(smhigh));             % scaling of the softmax errors
    maxdev = max(abs(xvec - lsehigh));         % max_i |x_i - lse|
    shiftterm = abs(lsehigh + n - min(xvec));  % |lse + n - min(x)|

    % LSE: leading order term in error bound from Thm3.2
    lsebd1(k) = 1 + (n+1)/lsemag;
    errlse1(k) = abs(lsechop-lsehigh)/lsescale;

    % LSE: leading order term from Thm 4.2
    lseshiftbd1(k) = shiftterm/lsemag;
    errlseshift1(k) = abs(lsechopshift-lsehigh)/lsescale;

    % SOFTMAX: leading order term in error bound from Thm3.3
    smbd1(k) = (n+3);
    errsm1(k) = max(abs(smchop - smhigh))/smscale;

    % SOFTMAX: leading order term in error bound from Thm3.4 for
    % alternative alg
    smbd1b(k) = (lsemag + maxdev + n + 2);
    errsm1b(k) = max(abs(smchopb - smhigh))/smscale;

    % SOFTMAX: leading order term in error bound from Thm 4.3 for
    % shifted alg
    smbdshift(k) = n + 2 + 2*(max(xvec) - min(xvec));
    errsmshift(k) = max(abs(smchopshift - smhigh))/smscale;

    % SOFTMAX: leading order term in error bound from Thm 4.4 for
    % alternative version of shifted alg
    smbdaltshift(k) = 1 + maxdev + shiftterm;
    errsmaltshift(k) = max(abs(smaltshift - smhigh))/smscale;

    % Sum of the softmax entries for each variant
    sumchop(k) = sum(smchop);
    sumchopb(k) = sum(smchopb);
    sumhigh(k) = sum(smhigh);
    sumchopshift(k) = sum(smchopshift);
    sumaltshift(k) = sum(smaltshift);

end

% IGNORE THE CASES WHERE CHOPPED LSE EVALUATED TO INF
% a  : indices where the chopped basic LSE is not Inf (used for all plots)
% a2 : indices where the chopped basic LSE is Inf
% a3 : indices where the chopped shifted LSE is Inf
overflowed = (lsechopval == Inf);
a = find(~overflowed);

a2 = find(overflowed);
numInfs = numel(a2)                   %overflows in half prec basic LSE: gives 475

a3 = find(lsechopshiftval == Inf);
numInfs_shift = numel(a3)           %overflows in half prec shifted LSE: gives 0

% From here on, only data vectors indexed by a are plotted
% (i.e., input where all algorithms avoided overflow)

%% FIGURE 1 PRODUCES A PICTURE FOR THE MANUSCRIPT
% Three panels: basic LSE error against its bound, shifted LSE error
% against its bound, and basic against shifted LSE error. Everything,
% including the red reference lines, is restricted to the vectors indexed
% by a. Saved as lsepic.png.
figure(1)      %%% LSE basic alg  
clf
subplot(2,2,1)
plot(lsebd1(a),errlse1(a),'*')
hold on
ylim([0 4]), yticks(0:4)
xlim([1.9 6]), xticks(2:6)
lb1 = min(lsebd1(a))    % echoed; 1.9 is a lower bound on the smallest x component that we plot 
plot([min(lsebd1(a)),max(lsebd1(a))],[min(lsebd1(a)),max(lsebd1(a))],'r-','LineWidth',2) % ref slope
title('Basic LSE')
ylabel('f.p. error')
xlabel('error bound')
set(gca,'FontSize',FS)
axis square
set(gca,'TickLength',[0.02 0.025])

subplot(2,2,2)
plot(lseshiftbd1(a),errlseshift1(a),'*')
hold on
ylim([0 4]), yticks(0:4)
xlim([1.9 6]), xticks(2:6)
% Use the same subset a as the plotted points, so that lb2 and the
% reference line describe the data actually shown.
lb2 = min(lseshiftbd1(a))  % echoed; 1.9 is a lower bound on the smallest x component that we plot 
plot([min(lseshiftbd1(a)),max(lseshiftbd1(a))],[min(lseshiftbd1(a)),max(lseshiftbd1(a))],'r-','LineWidth',2) % ref slope
title('Shifted LSE')
ylabel('f.p. error')
xlabel('error bound')
set(gca,'FontSize',FS)
axis square
set(gca,'TickLength',[0.02 0.025])

subplot(2,2,[3 4])
% subplot(2,2,[2 4])
plot(errlse1(a),errlseshift1(a),'*')
hold on
%plot([min(errlse1(a)),max(errlse1(a))],[min(errlse1(a)),max(errlse1(a))],'r-','LineWidth',2) % ref slope
xlabel('basic LSE')
ylabel('shifted LSE')
ylim([0 1.25]), yticks(0:0.25:1.25)
xlim([0 1.25]), xticks(0:0.5:1.25)
% axis equal
axis square
title('Basic vs Shifted')
set(gca,'FontSize',FS)
set(gca,'TickLength',[0.02 0.025])

print -dpng lsepic.png

% Ratio of the basic LSE error to the shifted LSE error over the vectors
% indexed by a, with summary statistics echoed to the console.
lse_ratio = errlse1(a)./errlseshift1(a);
lse_rat_min = min (lse_ratio)
lse_rat_max = max(lse_ratio)
lse_rat_mean = mean(lse_ratio)
lse_rat_std = std(lse_ratio)/sqrt(length(a))   % std divided by sqrt of the sample count
%%%%% FIGURE(2) IS FOR INFO, NOT FOR MANUSCRIPT
figure(2)   % another comparison of the errors in the two methods 
clf
%histogram(log10(lse_ratio))
rat_sort = sort(log10(lse_ratio),'ascend');
plot(rat_sort,'*-')
title('Log ratio of errors in LSE shifted LSE')
print -dpng lseratiopic.png

% seem to be lots of ones
% Reuse lse_ratio rather than recomputing it, and count with nnz so that
% the overflow index list a2 is not overwritten by the outputs of find.
rat = lse_ratio;   % many of these are seen to be exactly 1 in f.p. 
numones = nnz(rat == 1)
ratio = numones/length(rat)


%%%%% FIGURE(3) IS FOR INFO, NOT FOR MANUSCRIPT
% Histogram of the basic softmax errors over the vectors indexed by a.
% No bound is plotted: the upper bound for this algorithm is independent of x.
figure(3)
clf
histogram(errsm1(a))
title('Basic SOFTMAX')
set(gca, 'FontWeight','Bold', 'FontSize',24)
print('-dpng', 'sf1.png')

%%%%% FIGURE(4) IS FOR INFO, NOT FOR MANUSCRIPT
% Four panels for the softmax algorithms, all restricted to the vectors
% indexed by a: error against bound for the alternative and the shifted
% algorithms, then shifted against alt. and shifted against basic errors.
% Saved as sfall4.png.
figure(4)    %%%% SOFTMAX alternative alg
clf
subplot(2,2,1)
plot(smbd1b(a),errsm1b(a),'*')
hold on
ylim([0 50])
plot([min(smbd1b(a)),max(smbd1b(a))],[min(smbd1b(a)),max(smbd1b(a))],'r-','LineWidth',2) % ref slope
title('alt.')
xlabel('error bound')
ylabel('f.p. error')
set(gca,'FontSize',18,'FontWeight','Bold')

subplot(2,2,2)
plot(smbdshift(a),errsmshift(a),'*')
hold on
ylim([0 80])
% Reference line spans the plotted subset a, matching the alt. panel.
plot([min(smbdshift(a)),max(smbdshift(a))],[min(smbdshift(a)),max(smbdshift(a))],'r-','LineWidth',2) % ref slope
title('shifted')
xlabel('error bound')
ylabel('f.p. error')
set(gca,'FontSize',18,'FontWeight','Bold')

subplot(2,2,4)
plot(errsmshift(a),errsm1b(a),'*')
hold on
ylim([0 15])
% lower/upper: common range of both error sets, for the reference line
lower = min(min(errsmshift(a)),min(errsm1b(a)));
upper = max(max(errsmshift(a)),max(errsm1b(a)));
plot([lower,upper],[lower,upper],'r-','LineWidth',2) % ref slope
title('shifted v alt.')
xlabel('shifted')
ylabel('alt')
set(gca,'FontSize',18,'FontWeight','Bold')

subplot(2,2,3)
plot(errsmshift(a),errsm1(a),'*')
hold on
ylim([0 10])
lower = min(min(errsmshift(a)),min(errsm1(a)));
upper = max(max(errsmshift(a)),max(errsm1(a)));
plot([lower,upper],[lower,upper],'r-','LineWidth',2) % ref slope
title('shifted v basic')
xlabel('shifted')
ylabel('basic')
set(gca,'FontSize',18,'FontWeight','Bold')
print -dpng sfall4.png

%%%%% FIGURE(5) IS FOR INFO, NOT FOR MANUSCRIPT
% Check alt shifted softmax: error against bound over the vectors indexed by a.
figure(5)
clf
%subplot(2,2,2)
plot(smbdaltshift(a),errsmaltshift(a),'*')
hold on
ylim([0 80])
% The reference line must span the bound that is plotted on the x axis
% (smbdaltshift over the subset a), not the bound of the shifted algorithm.
plot([min(smbdaltshift(a)),max(smbdaltshift(a))],[min(smbdaltshift(a)),max(smbdaltshift(a))],'r-','LineWidth',2) % ref slope
title('alt shifted')
xlabel('error bound')
ylabel('f.p. error')
set(gca,'FontSize',18,'FontWeight','Bold')


%%%%% FIGURE(6) IS FOR MANUSCRIPT
% Pairwise scatter plots of the softmax errors over the vectors indexed by
% a, with the shifted algorithm on the x axis of every panel. Each panel
% gets a red line from (lower,lower) to (upper,upper), where lower/upper
% are the extreme values over both plotted error sets. Saved as softpic.png.
figure(6)
clf

% Panel 1: shifted against basic
subplot(2,2,1)
plot(errsmshift(a), errsm1(a), '*')
hold('on')
lower = min([min(errsmshift(a)), min(errsm1(a))]);
upper = max([max(errsmshift(a)), max(errsm1(a))]);
plot([lower,upper], [lower,upper], 'r-', 'LineWidth',2)
title('Shifted v Basic')
xlabel('shifted')
ylabel('basic')
set(gca, 'FontSize',FS)
axis('equal')
axis('square')
ylim([0 10])
xlim([0 10])
set(gca, 'TickLength',[0.02 0.025])

% Panel 2: shifted against alternative
subplot(2,2,2)
plot(errsmshift(a), errsm1b(a), '*')
hold('on')
lower = min([min(errsmshift(a)), min(errsm1b(a))]);
upper = max([max(errsmshift(a)), max(errsm1b(a))]);
plot([lower,upper], [lower,upper], 'r-', 'LineWidth',2)
title('Shifted v Alt.')
xlabel('shifted')
ylabel('alt')
set(gca, 'FontSize',FS)
axis('equal')
axis('square')
ylim([0 10])
xlim([0 10])
set(gca, 'TickLength',[0.02 0.025])

% Panel 3 (bottom row): shifted against alternative shifted
subplot(2,2,[3 4])
plot(errsmshift(a), errsmaltshift(a), '*')
hold('on')
lower = min([min(errsmshift(a)), min(errsmaltshift(a))]);
upper = max([max(errsmshift(a)), max(errsmaltshift(a))]);
plot([lower,upper], [lower,upper], 'r-', 'LineWidth',2)
title('Shifted v Alt. Shifted')
xlabel('shifted')
ylabel('alt. shifted')
set(gca, 'FontSize',FS)
axis('equal')
axis('square')
ylim([0 10])
xlim([0 10])
set(gca, 'TickLength',[0.02 0.025])

print('-dpng', 'softpic.png')


% Figures 7-9 plot the sum of the computed softmax entries, on a log scale,
% for the vectors indexed by a. The x-axis limit follows the number of
% plotted points instead of being hard-coded, so the plots stay correct if
% L or the number of overflows changes (2025 points for L = 2500 with 475
% overflows). The max with 2 keeps the limits strictly increasing.
npts = max(2, numel(a));

%%% FIG 7 creates a figure for the manuscript %%%%%
figure(7)
clf
semilogy(sumchop(a),'ro');
hold on
semilogy(sumchopb(a),'bx')
%title('Thm 3.3 is red, Thm 3.4 is blue');
xlim([1 npts])
title('Softmax sum test: basic')
ylabel('sum(g)')
xlabel('Data points')
legend('basic','alternative','Location','NorthEast')
set(gca,'FontSize',16)
print -dpng sumpic2.png

%%% FIG 8 does not create a figure for the manuscript %%%%%
figure(8)
clf
semilogy(sumhigh(a),'ro')
title('single precision')
print -dpng sumpic1.png

%%% FIG 9 creates a figure for the manuscript %%%%%
figure(9)
clf
semilogy(sumchopshift(a),'ro')
hold on
semilogy(sumaltshift(a),'bx')
%title('Thm 4.3 is red, Thm 4.4 is blue');
xlim([1 npts])
title('Softmax sum test: shifted')
ylabel('sum(g)')
xlabel('Data points')
legend('shifted','alternative','Location','NorthEast')
set(gca,'FontSize',16)
print -dpng sumpic3.png
% back to top