https://github.com/higham/logsumexp-softmax-tests
Raw File
Tip revision: a6e66ff1ff1d30800406481af2cf8720c6221691 authored by Nick Higham on 26 August 2020, 14:24:16 UTC
Updated reference in readme.
Tip revision: a6e66ff
roundit.m
function [y,options] = roundit(x,options)
%ROUNDIT   Round a matrix to integer entries, with various options.
%   [Y,OPTIONS] = ROUNDIT(X,OPTIONS) rounds the matrix X to have integer
%   entries, as specified by OPTIONS.round:
%     1: round to nearest integer using round to even to break ties
%        (the default),
%     2: round towards plus infinity (round up),
%     3: round towards minus infinity (round down),
%     4: round towards zero,
%     5: stochastic rounding - round to the next larger or next smaller
%        integer with probability equal to 1 minus the distance to
%        those integers,
%     6: stochastic rounding - round to the next larger or next smaller
%        integer with equal probability.
%   For stochastic rounding, exact integers are not changed.
%   If OPTIONS.flip = 1 (default 0) then each element of the rounded result
%   has, with probability OPTIONS.p (default 0.5), a randomly chosen bit
%   (in its binary representation) flipped. The flipped bit is one of the
%   OPTIONS.t-1 low-order bits (bit positions 0 to OPTIONS.t-2), so
%   OPTIONS.t must be supplied whenever OPTIONS.flip is set.
%   The integers are assumed to have moduli on the interval [0,2^OPTIONS.t-1],
%   which defines the range of bits in their binary representation.
%   The returned OPTIONS has any missing round, flip and p fields filled in
%   with their defaults.

if nargin < 2 || isempty(options)
    options.round = 1; options.flip = 0; options.p = 0.5;
end
if ~isfield(options,'round'), options.round = 1; end
if ~isfield(options,'flip'), options.flip = 0; end
if ~isfield(options,'p'), options.p = 0.5; end

% Sign function with mysign(0) = 1, so that a zero entry keeps a nonzero
% magnitude (e.g. after a bit flip) when the sign is reattached.
mysign = @(x) sign(x) + (x==0);

switch options.round

  case 1
    % Round to nearest integer using round to even to break ties.
    % The built-in ROUND function rounds ties away from zero, so exact ties
    % with an even integer part (rem(y,2) == 0.5) are shifted down by one
    % before rounding: 2.5 -> round(1.5) = 2, whereas 3.5 -> round(3.5) = 4.
    y = abs(x);
    u = round(y - (rem(y,2) == 0.5));
    u(u == -1) = 0; % y = 0.5 gives round(-0.5) = -1; the correct result is 0.
    y = sign(x).*u;

  case 2
    % Round towards plus infinity.
    y = ceil(x);

  case 3
    % Round towards minus infinity.
    y = floor(x);

  case 4
    % Round towards zero. FIX acts elementwise; an IF test on the whole
    % matrix would apply a single branch to every entry and give the wrong
    % answer for matrices containing both signs.
    y = fix(x);

  case {5, 6}
    % Stochastic rounding of the magnitudes; the sign is reattached at the end.
    y = abs(x);
    frac = y - floor(y);
    k = find(frac ~= 0);  % Exact integers are left unchanged.
    if isempty(k)
        y = x;
    else
        rnd = rand(length(k),1);
        vals = frac(k);  vals = vals(:);

        switch options.round
          case 5 % Round up with probability frac = 1 - (distance to ceil).
            j = (rnd <= vals);
          case 6 % Round up or down with equal probability.
            j = (rnd <= 0.5);
        end
        y(k(j)) = ceil(y(k(j)));
        y(k(~j)) = floor(y(k(~j)));
        y = mysign(x).*y;
    end

  otherwise
    error('Unsupported value of options.round.')

end

if options.flip

    if ~isfield(options,'t')
        error('options.t must be specified when options.flip is set.')
    end
    temp = rand(size(y));
    k = find(temp <= options.p); % Indices of elements to have a bit flipped.
    if ~isempty(k)
        u = abs(y(k));
        % Random bit flip in the binary representation of each |y(k)|:
        % b(i) in 1..t-1 selects the bit 2^(b(i)-1) to flip in u(i).
        % b is sized like u so that a row-vector y does not make BITXOR
        % expand a row against a column.
        b = randi(options.t-1, size(u));
        % Flip selected bits.
        u = bitxor(u,2.^(b-1));
        y(k) = mysign(y(k)).*u;
    end

end
back to top