https://github.com/google/jax
Tip revision: 2dfc5f962bfcad645860a4c1fece7280d4ac20bb authored by jax authors on 07 October 2020, 05:59:09 UTC
Merge pull request #4467 from google:update-pypi
type_promotion.rst
.. _type-promotion:

Type promotion semantics
========================

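JAX's type promotion rules (i.e., the result of
:func:`jax.numpy.promote_types` for each pair of types) are given by the
following table. Each type is abbreviated by a letter for its kind followed by
its size in bytes (for complex types, the size of each of the real and
imaginary parts), apart from :code:`bfloat16`, which is written "bf". For
example,

* "b1" means :code:`np.bool_`,
* "i2" means :code:`np.int16`,
* "u4" means :code:`np.uint32`,
* "bf" means :code:`jax.numpy.bfloat16`,
* "f2" means :code:`np.float16`, and
* "c8" means :code:`np.complex128`.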
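
A minimal sketch of this naming scheme, assuming only NumPy, is shown below.
:code:`table_name` is a hypothetical helper written for illustration; it does
not handle :code:`bfloat16`, which is not a NumPy dtype and is simply labeled
"bf". Note that the table's "c8" is *not* NumPy's type code :code:`"c8"`, which
counts the total size in bytes and therefore denotes :code:`np.complex64`.

.. code-block:: python

    import numpy as np


    def table_name(dtype):
        """Hypothetical helper: abbreviate a NumPy dtype as the promotion table does."""
        d = np.dtype(dtype)
        # Complex types are labeled by the size of one component (the real or the
        # imaginary part), so complex128 becomes "c8" rather than "c16".
        if np.issubdtype(d, np.complexfloating):
            size = d.itemsize // 2
        else:
            size = d.itemsize
        return f"{d.kind}{size}"  # kind letter: b, u, i, f or c


    print([table_name(t) for t in (np.bool_, np.uint32, np.int16,
                                   np.float16, np.complex64, np.complex128)])
    # prints ['b1', 'u4', 'i2', 'f2', 'c4', 'c8']

    # NumPy's own type code "c8" counts both parts together, so it means complex64.
    assert np.dtype("c8") == np.complex64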

.. raw:: html

    <style>
        #types table {
          border: 2px solid #aaa;
        }

        #types td, #types th {
          border: 1px solid #ddd;
          padding: 3px;
        }
        #types th {
          border-bottom: 1px solid #aaa;
        }
        #types tr:nth-child(even){background-color: #f2f2f2;}
        #types .d {
          background-color: #ccf2cc;
        }
        #types td:first-child{
          background-color: #f2f2f2;
          border-right: 1px solid #aaa;
          font-weight: bold;
        }
        #types tr:first-child{background-color: #f2f2f2;}
    </style>

    <table id="types">
    <tr><th></th><th>b1</th><th>u1</th><th>u2</th><th>u4</th><th>u8</th><th>i1</th><th>i2</th><th>i4</th><th>i8</th><th>bf</th><th>f2</th><th>f4</th><th>f8</th><th>c4</th><th>c8</th></tr>
    <tr><td>b1</td><td>b1</td><td>u1</td><td>u2</td><td>u4</td><td>u8</td><td>i1</td><td>i2</td><td>i4</td><td>i8</td><td class="d">bf</td><td>f2</td><td>f4</td><td>f8</td><td>c4</td><td>c8</td></tr>
    <tr><td>u1</td><td>u1</td><td>u1</td><td>u2</td><td>u4</td><td>u8</td><td>i2</td><td>i2</td><td>i4</td><td>i8</td><td class="d">bf</td><td>f2</td><td>f4</td><td>f8</td><td>c4</td><td>c8</td></tr>
    <tr><td>u2</td><td>u2</td><td>u2</td><td>u2</td><td>u4</td><td>u8</td><td>i4</td><td>i4</td><td>i4</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td>f4</td><td>f8</td><td>c4</td><td>c8</td></tr>
    <tr><td>u4</td><td>u4</td><td>u4</td><td>u4</td><td>u4</td><td>u8</td><td>i8</td><td>i8</td><td>i8</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td class="d">f4</td><td>f8</td><td class="d">c4</td><td>c8</td></tr>
    <tr><td>u8</td><td>u8</td><td>u8</td><td>u8</td><td>u8</td><td>u8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td class="d">bf</td><td class="d">f2</td><td class="d">f4</td><td>f8</td><td class="d">c4</td><td>c8</td></tr>
    <tr><td>i1</td><td>i1</td><td>i2</td><td>i4</td><td>i8</td><td>f8</td><td>i1</td><td>i2</td><td>i4</td><td>i8</td><td class="d">bf</td><td>f2</td><td>f4</td><td>f8</td><td>c4</td><td>c8</td></tr>
    <tr><td>i2</td><td>i2</td><td>i2</td><td>i4</td><td>i8</td><td>f8</td><td>i2</td><td>i2</td><td>i4</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td>f4</td><td>f8</td><td>c4</td><td>c8</td></tr>
    <tr><td>i4</td><td>i4</td><td>i4</td><td>i4</td><td>i8</td><td>f8</td><td>i4</td><td>i4</td><td>i4</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td class="d">f4</td><td>f8</td><td class="d">c4</td><td>c8</td></tr>
    <tr><td>i8</td><td>i8</td><td>i8</td><td>i8</td><td>i8</td><td>f8</td><td>i8</td><td>i8</td><td>i8</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td class="d">f4</td><td>f8</td><td class="d">c4</td><td>c8</td></tr>
    <tr><td>bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">f4</td><td class="d">f4</td><td class="d">f8</td><td class="d">c4</td><td class="d">c8</td></tr>
    <tr><td>f2</td><td>f2</td><td>f2</td><td class="d">f2</td><td class="d">f2</td><td class="d">f2</td><td>f2</td><td class="d">f2</td><td class="d">f2</td><td class="d">f2</td><td class="d">f4</td><td>f2</td><td>f4</td><td>f8</td><td>c4</td><td>c8</td></tr>
    <tr><td>f4</td><td>f4</td><td>f4</td><td>f4</td><td class="d">f4</td><td class="d">f4</td><td>f4</td><td>f4</td><td class="d">f4</td><td class="d">f4</td><td class="d">f4</td><td>f4</td><td>f4</td><td>f8</td><td>c4</td><td>c8</td></tr>
    <tr><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td class="d">f8</td><td>f8</td><td>f8</td><td>f8</td><td>c8</td><td>c8</td></tr>
    <tr><td>c4</td><td>c4</td><td>c4</td><td>c4</td><td class="d">c4</td><td class="d">c4</td><td>c4</td><td>c4</td><td class="d">c4</td><td class="d">c4</td><td class="d">c4</td><td>c4</td><td>c4</td><td>c8</td><td>c4</td><td>c8</td></tr>
    <tr><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td class="d">c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td></tr>
    </table><p>

.. The table above was generated by the following Python code.
    import numpy as np
    import jax.numpy as jnp

    types = [np.bool_, np.uint8, np.uint16, np.uint32, np.uint64,
             np.int8, np.int16, np.int32, np.int64,
             jnp.bfloat16, np.float16, np.float32, np.float64,
             np.complex64, np.complex128]

    def name(d):
      d = np.dtype(d)
      if d == np.dtype(jnp.bfloat16):
        return "bf"
      return "{}{}".format(
        d.kind,
        d.itemsize // 2 if np.issubdtype(d, np.complexfloating) else d.itemsize)

    out = "<tr><th></th>"
    for t in types:
      out += "<th>{}</th>".format(name(t))
    out += "</tr>\n"

    for t1 in types:
      out += "<tr><td>{}</td>".format(name(t1))
      for t2 in types:
        t = jnp.promote_types(t1, t2)
        different = jnp.bfloat16 in (t1, t2) or t != np.promote_types(t1, t2)
        out += "<td{}>{}</td>".format(" class=\"d\"" if different else "", name(t))
      out += "</tr>\n"

    print(out)

JAX's type promotion rules differ from those of NumPy, as given by
:func:`numpy.promote_types`, in the cells highlighted with a green background
in the table above. There are two key differences:

* When promoting a boolean or integer type against a floating-point or complex
  type, JAX always returns the floating-point or complex type.

  Accelerator devices, such as GPUs and TPUs, either pay a significant
  performance penalty to use 64-bit floating-point types (GPUs) or do not
  support 64-bit floating-point types at all (TPUs). Classic NumPy's promotion
  rules are too willing to overpromote to 64-bit types, which is problematic for
  a system designed to run on accelerators. For example, NumPy promotes
  :code:`int32` with :code:`float32` to :code:`float64`, whereas JAX keeps
  :code:`float32`.

  JAX uses floating-point promotion rules that are better suited to modern
  accelerator devices and are less aggressive about promoting to wider
  floating-point types. The promotion rules used by JAX for floating-point types
  are similar to those used by PyTorch.

* JAX supports the
  `bfloat16 <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_
  non-standard 16-bit floating-point type
  (:code:`jax.numpy.bfloat16`), which is useful for neural network training.
  Against every boolean and integer type the result is :code:`bfloat16`. The
  only notable promotion behavior is with respect to IEEE-754 :code:`float16`:
  promoting :code:`bfloat16` with :code:`float16` yields :code:`float32`.
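
The first difference can be restated as a small override on top of NumPy's own
rule. The following is a minimal sketch, not JAX's implementation:
:code:`sketch_promote_types` is a hypothetical helper, and it handles only
standard NumPy dtypes, since NumPy has no :code:`bfloat16`. By inspection of the
table above, it reproduces every cell that does not involve :code:`bfloat16`:
whenever a boolean or integer type meets a floating-point or complex type, it
returns the latter unchanged, and otherwise it defers to
:func:`numpy.promote_types`.

.. code-block:: python

    import numpy as np


    def sketch_promote_types(a, b):
        """Hypothetical restatement of the table for non-bfloat16 dtypes."""
        a, b = np.dtype(a), np.dtype(b)
        a_inexact = np.issubdtype(a, np.inexact)  # floating-point or complex
        b_inexact = np.issubdtype(b, np.inexact)
        # Mixing a bool/int with a float/complex: keep the float/complex type as-is.
        if a_inexact and not b_inexact:
            return a
        if b_inexact and not a_inexact:
            return b
        # Both bool/int, or both float/complex: NumPy's rule matches the table.
        return np.promote_types(a, b)


    # NumPy widens int32 with float32 to float64; the table keeps float32.
    print(np.promote_types(np.int32, np.float32),
          sketch_promote_types(np.int32, np.float32))
    # prints float64 float32

Writing the rule as an override makes explicit that, outside :code:`bfloat16`,
the highlighted cells are exactly the mixed integer/floating-point cells. For
:code:`bfloat16` itself, entries can be checked directly with
:func:`jax.numpy.promote_types`; for example,
:code:`jax.numpy.promote_types(jax.numpy.bfloat16, np.float16)` corresponds to
the table's :code:`float32` entry.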