https://github.com/higham/logsumexp-softmax-tests
Tip revision: a6e66ff1ff1d30800406481af2cf8720c6221691 authored by Nick Higham on 26 August 2020, 14:24:16 UTC
Updated reference in readme.
Updated reference in readme.
Tip revision: a6e66ff
chop.m
function [c,options] = chop(x,options)
%CHOP    Round matrix elements to lower precision.
%   CHOP(X,options) is the matrix obtained by rounding the elements of
%   the array X to a lower precision arithmetic with one of several
%   forms of rounding.  X should be single precision or double precision
%   and the output will have the same type.
%   The arithmetic format is specified by options.format, which is one of
%     'b', 'bfloat16'           - bfloat16,
%     'h', 'half', 'fp16'       - IEEE half precision (the default),
%     's', 'single', 'fp32'     - IEEE single precision,
%     'd', 'double', 'fp64'     - IEEE double precision,
%     'c', 'custom'             - custom format.
%   In the last case the (base 2) format is defined by
%   options.params, which is a 2-vector [t, emax] where t is the
%   number of bits in the significand (including the hidden bit) and
%   emax is the maximum value of the exponent.  The values of t and emax
%   are built-in for b, h, s, d and will automatically be returned in
%   options.params.  For a custom format t must not exceed 11 when X is
%   single and 25 when X is double.
%   options.subnormal specifies whether subnormal numbers are supported
%   (if they are not, subnormals are flushed to zero):
%      0 = do not support subnormals (the default for bfloat16),
%      1 = support subnormals (the default for fp16, fp32 and fp64).
%   The form of rounding is specified by options.round:
%     1: round to nearest using round to even last bit to break ties
%        (the default);
%     2: round towards plus infinity (round up);
%     3: round towards minus infinity (round down);
%     4: round towards zero;
%     5: stochastic rounding - round to the next larger or next smaller
%        f.p. (floating-point) number with probability proportional to
%        1 minus the distance to those f.p. numbers;
%     6: stochastic rounding - round to the next larger or next smaller
%        f.p. number with equal probability.
%   For stochastic rounding, exact f.p. numbers are not changed.
%   Overflow: for round modes 1, 5 and 6 any X with magnitude at least
%   xmax plus half a unit in the last place becomes +/-inf.  For the
%   directed modes 2, 3 and 4 finite values beyond +/-xmax go to +/-inf
%   in the direction being rounded towards and saturate at +/-xmax
%   otherwise.  Infinite inputs are returned unchanged.
%   If options.flip = 1 (default 0) then each element of the rounded result
%   has, with probability options.p (default 0.5), a randomly chosen bit
%   in its significand flipped.
%   On the first call: if options is omitted or only partially specified
%   the defaults stated above are used.
%   On subsequent calls: if options is omitted or empty then the values used
%   in the previous call are re-used; for any missing fields the
%   default is used.
%   The options structure is stored internally in a persistent variable
%   and can be obtained with [~,options] = CHOP.
%
%   Errors are raised for an unrecognized options.format, for a custom
%   format with no params available, and for a custom precision that is
%   too large for the class of X.
%
%   References:
%   [1] IEEE Standard for Floating-Point Arithmetic, IEEE Std 754-2008 (revision
%       of IEEE Std 754-1985), 58, IEEE Computer Society, 2008; pages 8,
%       13. https://ieeexplore.ieee.org/document/4610935
%   [2] Intel Corporation, BFLOAT16---hardware numerics definition, Nov. 2018,
%       White paper. Document number 338302-001US.
%       https://software.intel.com/en-us/download/bfloat16-hardware-numerics-definition

persistent fpopts

% True only when the caller passed a non-empty options structure.
have_options = (nargin == 2) && ~isempty(options);

if isempty(fpopts) && ~have_options
    % First call and nothing specified: install the documented defaults.
    fpopts.format = 'h'; fpopts.subnormal = 1;
    fpopts.round = 1; fpopts.flip = 0; fpopts.p = 0.5;
elseif have_options
    % Options supplied: every missing or empty field reverts to its default.
    if isfield(options,'format') && ~isempty(options.format)
        fpopts.format = options.format;
    else
        fpopts.format = 'h';
    end

    if isfield(options,'subnormal') && ~isempty(options.subnormal)
        fpopts.subnormal = options.subnormal;
    elseif ismember(fpopts.format, {'b','bfloat16'})
        fpopts.subnormal = 0;
    else
        fpopts.subnormal = 1;
    end

    if isfield(options,'round') && ~isempty(options.round)
        fpopts.round = options.round;
    else
        fpopts.round = 1;
    end

    if isfield(options,'flip') && ~isempty(options.flip)
        fpopts.flip = options.flip;
    else
        fpopts.flip = 0;
    end

    if isfield(options,'p') && ~isempty(options.p)
        fpopts.p = options.p;
    else
        fpopts.p = 0.5;
    end
end

% Determine t (significand bits including the hidden bit) and emax.
switch fpopts.format
    case {'h','half','fp16'}
        % Significand: 10 bits plus 1 hidden. Exponent: 5 bits.
        t = 11; emax = 15;
        fpopts.params = [t emax];
    case {'b','bfloat16'}
        % Significand: 7 bits plus 1 hidden. Exponent: 8 bits.
        t = 8; emax = 127;
        fpopts.params = [t emax];
    case {'s','single','fp32'}
        % Significand: 23 bits plus 1 hidden. Exponent: 8 bits.
        t = 24; emax = 127;
        fpopts.params = [t emax];
    case {'d','double','fp64'}
        % Significand: 52 bits plus 1 hidden. Exponent: 11 bits.
        t = 53; emax = 1023;
        fpopts.params = [t emax];
    case {'c','custom'}
        if have_options && isfield(options,'params') && ~isempty(options.params)
            new_t = options.params(1);
            new_emax = options.params(2);
            % Need "p_2 \ge 2p_1 + 2" to avoid double rounding problems.
            % Validate before touching the persistent state so that a
            % rejected precision is not remembered by later calls.
            if isa(x,'single') && (new_t > 11)
                error(['Precision of the custom format must be less than ' ...
                       '12 if working in single.']);
            elseif isa(x,'double') && (new_t > 25)
                error(['Precision of the custom format must be less than ' ...
                       '26 if working in double.']);
            end
            fpopts.params(1) = new_t;
            fpopts.params(2) = new_emax;
        elseif ~isfield(fpopts,'params') || isempty(fpopts.params)
            % Reached both when options were omitted and when options were
            % given without params; params left by an earlier call are
            % otherwise re-used.
            error('Must specify options.params with options.format = ''c''.')
        end
        t = fpopts.params(1); emax = fpopts.params(2);
    otherwise
        error('Unrecognized format.')
end

if nargout == 2, options = fpopts; end
if nargin == 0 || isempty(x)
    % Options-only call: nothing to round.
    if nargout >= 1, c = []; end
    return
end

% The precision is passed on to roundit through fpopts when bit flipping
% is requested.
if fpopts.flip == 1, fpopts.t = t; end

emin = 1-emax;            % Exponent of smallest normalized number.
xmin = 2^emin;            % Smallest positive normalized number.
emins = emin + 1 - t;     % Exponent of smallest positive subnormal number.
xmins = 2^emins;          % Smallest positive subnormal number.
xmax = 2^emax * (2-2^(1-t));   % Largest finite number.

% Use the representation:
% x = 2^e * d_1.d_2...d_{t-1} * s, s = 1 or -1.
c = x;
e = floor(log2(abs(x)));
is_sub = (e < emin & e >= emins);   % Elements in the subnormal range.
is_norm = ~is_sub;

% Scale so that the last retained bit has unit weight, round, scale back.
c(is_norm) = pow2(roundit(pow2(x(is_norm), t-1-e(is_norm)), fpopts), ...
                  e(is_norm)-(t-1));

if any(is_sub(:))
    % Subnormals carry fewer significant bits the further e is below emin.
    t1 = t - max(emin-e(is_sub),0);
    c(is_sub) = pow2(roundit( pow2(x(is_sub), t1-1-e(is_sub)), fpopts ), ...
                     e(is_sub)-(t1-1));
    if fpopts.subnormal == 0
        c(is_sub) = 0;   % Flush subnormals to zero.
    end
end

% Overflow handling depends on the rounding direction [1, Sec. 7.4].
switch fpopts.round
    case 2
        % Round up: positive overflow to +inf, negative saturates at -xmax.
        c(x > xmax) = inf;
        c(x < -xmax) = -xmax;
    case 3
        % Round down: positive saturates at xmax, negative overflow to -inf.
        c(x > xmax) = xmax;
        c(x < -xmax) = -inf;
    case 4
        % Round towards zero: saturate at +/-xmax.
        c(x > xmax) = xmax;
        c(x < -xmax) = -xmax;
    otherwise
        % Any number larger than xboundary rounds to inf [1, p. 16].
        xboundary = 2^emax * (2-(1/2)*2^(1-t));
        c(x >= xboundary) = inf;     % Overflow to +inf.
        c(x <= -xboundary) = -inf;   % Overflow to -inf.
end

% Infinite inputs are not overflows; they are returned unchanged.
inf_in = isinf(x);
c(inf_in) = x(inf_in);

% NOTE(review): every |x| below the smallest subnormal is set to zero
% regardless of options.round or options.subnormal.
c(abs(x) < xmins) = 0;   % Underflow to zero.