https://github.com/google/jax
Raw File
Tip revision: 3e5466c2f9c21edfad8a7ea1e973aec1a763d64c authored by jax authors on 18 November 2021, 17:19:17 UTC
Internal change.
Tip revision: 3e5466c
CHANGELOG.md
# Change log

Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html).

<!--
Remember to align the itemized text with the first line of an item within a list.

PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
-->

## jaxlib 0.1.75 (Unreleased)

## jax 0.2.26 (Unreleased)
* [GitHub
  commits](https://github.com/google/jax/compare/jax-v0.2.25...main).

## jaxlib 0.1.74 (Nov 17, 2021)
* Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via
  the host, which is usually slower.
* Added experimental MLIR Python bindings for use by JAX.

## jax 0.2.25 (Nov 10, 2021)
* [GitHub
  commits](https://github.com/google/jax/compare/jax-v0.2.24...jax-v0.2.25).

* New features:
  * (Experimental) `jax.distributed.initialize` exposes multi-host GPU backend.
  * `jax.random.permutation` supports new `independent` keyword argument
    ({jax-issue}`#8430`)
* Breaking changes
  * Moved `jax.experimental.stax` to `jax.example_libraries.stax`
  * Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`

* New features:
  * Added `jax.lax.linalg.qdwh`.

## jax 0.2.24 (Oct 19, 2021)
* [GitHub
  commits](https://github.com/google/jax/compare/jax-v0.2.22...jax-v0.2.24).

* New features:
  * `jax.random.choice` and `jax.random.permutation` now support
    multidimensional arrays and an optional `axis` argument ({jax-issue}`#8158`)
* Breaking changes:
  * `jax.numpy.take` and `jax.numpy.take_along_axis` now require array-like inputs
    (see {jax-issue}`#7737`)

## jaxlib 0.1.73 (Oct 18, 2021)

* Multiple cuDNN versions are now supported for jaxlib GPU `cuda11` wheels.
  * cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN
    installation is new enough, since it supports additional functionality.
  * cuDNN 8.0.5 or newer.

* Breaking changes:
  * The install commands for GPU jaxlib are as follows:

    ```bash
    pip install --upgrade pip

    # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
    pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

    # Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
    pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

    # Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
    pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
    ```

## jax 0.2.22 (Oct 12, 2021)
* [GitHub
  commits](https://github.com/google/jax/compare/jax-v0.2.21...jax-v0.2.22).
* Breaking Changes
  * Static arguments to `jax.pmap` must now be hashable.

    Unhashable static arguments have long been disallowed on `jax.jit`, but they
    were still permitted on `jax.pmap`; `jax.pmap` compared unhashable static
    arguments using object identity.

    This behavior is a footgun, since comparing arguments using
    object identity leads to recompilation each time the object identity
    changes. Instead, we now ban unhashable arguments: if a user of `jax.pmap`
    wants to compare static arguments by object identity, they can define
    `__hash__` and `__eq__` methods on their objects that do that, or wrap their
    objects in an object that has those operations with object identity
    semantics. Another option is to use `functools.partial` to encapsulate the
    unhashable static arguments into the function object.
  * `jax.util.partial` was an accidental export that has now been removed. Use
    `functools.partial` from the Python standard library instead.
* Deprecations
  * The functions `jax.ops.index_update`, `jax.ops.index_add` etc. are
    deprecated and will be removed in a future JAX release. Please use
    [the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html)
    instead, e.g., `x.at[idx].set(y)`. For now, these functions produce a
    `DeprecationWarning`.
* New features:
  * An optimized C++ code-path improving the dispatch time for `pmap` is now the
    default when using jaxlib 0.1.72 or newer. The feature can be disabled using
    the `--experimental_cpp_pmap` flag (or `JAX_CPP_PMAP` environment variable).
  * `jax.numpy.unique` now supports an optional `fill_value` argument ({jax-issue}`#8121`)

## jaxlib 0.1.72 (Oct 12, 2021)
  * Breaking changes:
    * Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports
      CUDA 11.1+.
  * Bug fixes:
    * Fixes https://github.com/google/jax/issues/7461, which caused wrong
      outputs on all platforms due to incorrect buffer aliasing inside the XLA
      compiler.

## jax 0.2.21 (Sept 23, 2021)
* [GitHub
  commits](https://github.com/google/jax/compare/jax-v0.2.20...jax-v0.2.21).
* Breaking Changes
  * `jax.api` has been removed. Functions that were available as `jax.api.*`
    were aliases for functions in `jax.*`; please use the functions in
    `jax.*` instead.
  * `jax.partial`, and `jax.lax.partial` were accidental exports that have now
    been removed. Use `functools.partial` from the Python standard library
    instead.
  * Boolean scalar indices now raise a `TypeError`; previously this silently
    returned wrong results ({jax-issue}`#7925`).
  * Many more `jax.numpy` functions now require array-like inputs, and will error
    if passed a list ({jax-issue}`#7747` {jax-issue}`#7802` {jax-issue}`#7907`).
    See {jax-issue}`#7737` for a discussion of the rationale behind this change.
  * When inside a transformation such as `jax.jit`, `jax.numpy.array` always
    stages the array it produces into the traced computation. Previously
    `jax.numpy.array` would sometimes produce a on-device array, even under
    a `jax.jit` decorator. This change may break code that used JAX arrays to
    perform shape or index computations that must be known statically; the
    workaround is to perform such computations using classic NumPy arrays
    instead.
  * `jnp.ndarray` is now a true base-class for JAX arrays. In particular, this
    means that for a standard numpy array `x`, `isinstance(x, jnp.ndarray)` will
    now return `False` ({jax-issue}`7927`).
* New features:
  * Added {func}`jax.numpy.insert` implementation ({jax-issue}`#7936`).

## jax 0.2.20 (Sept 2, 2021)
* [GitHub
  commits](https://github.com/google/jax/compare/jax-v0.2.19...jax-v0.2.20).
* Breaking Changes
  * `jnp.poly*` functions now require array-like inputs ({jax-issue}`#7732`)
  * `jnp.unique` and other set-like operations now require array-like inputs
    ({jax-issue}`#7662`)

## jaxlib 0.1.71 (Sep 1, 2021)
* Breaking changes:
  * Support for CUDA 11.0 and CUDA 10.1 has been dropped. Jaxlib now supports
    CUDA 10.2 and CUDA 11.1+.

## jax 0.2.19 (Aug 12, 2021)
* [GitHub
  commits](https://github.com/google/jax/compare/jax-v0.2.18...jax-v0.2.19).
* Breaking changes:
  * Support for NumPy 1.17 has been dropped, per the
    [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
    Please upgrade to a supported NumPy version.
  * The `jit` decorator has been added around the implementation of a number of
    operators on JAX arrays. This speeds up dispatch times for common
    operators such as `+`.

    This change should largely be transparent to most users. However, there is
    one known behavioral change, which is that large integer constants may now
    produce an error when passed directly to a JAX operator
    (e.g., `x + 2**40`). The workaround is to cast the constant to an
    explicit type (e.g., `np.float64(2**40)`).
* New features:
  * Improved the support for shape polymorphism in jax2tf for operations that
    need to use a dimension size in array computation, e.g., `jnp.mean`.
    ({jax-issue}`#7317`)
* Bug fixes:
  * Some leaked trace errors from the previous release ({jax-issue}`#7613`)

## jaxlib 0.1.70 (Aug 9, 2021)
* Breaking changes:
  * Support for Python 3.6 has been dropped, per the
    [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
    Please upgrade to a supported Python version.
  * Support for NumPy 1.17 has been dropped, per the
    [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
    Please upgrade to a supported NumPy version.

  * The host_callback mechanism now uses one thread per local device for
    making the calls to the Python callbacks. Previously there was a single
    thread for all devices. This means that the callbacks may now be called
    interleaved. The callbacks corresponding to one device will still be
    called in sequence.

## jax 0.2.18 (July 21 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.17...jax-v0.2.18).

* Breaking changes:
  * Support for Python 3.6 has been dropped, per the
    [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
    Please upgrade to a supported Python version.
  * The minimum jaxlib version is now 0.1.69.
  * The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been
    removed.

* New features:
  * Added a polar decomposition ({py:func}`jax.scipy.linalg.polar`).

* Bug fixes:
  * Tightened the checks for lax.argmin and lax.argmax to ensure they are
    not used with an invalid `axis` value, or with an empty reduction dimension.
    ({jax-issue}`#7196`)


## jaxlib 0.1.69 (July 9 2021)
* Fix bugs in TFRT CPU backend that results in incorrect results.

## jax 0.2.17 (July 9 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...jax-v0.2.17).
* Bug fixes:
  * Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68
    to work around #7229, which caused wrong outputs on CPU due to a concurrency
    problem.
* New features:
  * New SciPy function {py:func}`jax.scipy.special.sph_harm`.
  * Reverse-mode autodiff functions ({func}`jax.grad`,
    {func}`jax.value_and_grad`, {func}`jax.vjp`, and
    {func}`jax.linear_transpose`) support a parameter that indicates which named
    axes should be summed over in the backward pass if they were broadcasted
    over in the forward pass. This enables use of these APIs in a
    non-per-example way inside maps (initially only
    {func}`jax.experimental.maps.xmap`) ({jax-issue}`#6950`).


## jax 0.2.16 (June 23 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.15...jax-v0.2.16).

## jax 0.2.15 (June 23 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.14...jax-v0.2.15).
* New features:
  * [#7042](https://github.com/google/jax/pull/7042) Turned on TFRT CPU backend
    with significant dispatch performance improvements on CPU.
  * The {func}`jax2tf.convert` supports inequalities and min/max for booleans
    ({jax-issue}`#6956`).
  * New SciPy function {py:func}`jax.scipy.special.lpmn_values`.

* Breaking changes:
  * Support for NumPy 1.16 has been dropped, per the
    [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).

* Bug fixes:
  * Fixed bug that prevented round-tripping from JAX to TF and back:
    `jax2tf.call_tf(jax2tf.convert)` ({jax-issue}`#6947`).

## jaxlib 0.1.68 (June 23 2021)
* Bug fixes:
  * Fixed bug in TFRT CPU backend that gets nans when transfer TPU buffer to
    CPU.

## jax 0.2.14 (June 10 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...jax-v0.2.14).
* New features:
  * The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`.
  * A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters
    tracebacks.
  * A new traceback filtering mode using `__tracebackhide__` is now enabled by
    default in sufficiently recent versions of IPython.
  * The {func}`jax2tf.convert` supports shape polymorphism even when the
    unknown dimensions are used in arithmetic operations, e.g., `jnp.reshape(-1)`
    ({jax-issue}`#6827`).
  * The {func}`jax2tf.convert` generates custom attributes with location information
   in TF ops. The code that XLA generates after jax2tf
   has the same location information as JAX/XLA.
  * New SciPy function {py:func}`jax.scipy.special.lpmn`.

* Bug fixes:
  * The {func}`jax2tf.convert` now ensures that it uses the same typing rules
    for Python scalars and for choosing 32-bit vs. 64-bit computations
    as JAX ({jax-issue}`#6883`).
  * The {func}`jax2tf.convert` now scopes the `enable_xla` conversion parameter
    properly to apply only during the just-in-time conversion
    ({jax-issue}`#6720`).
  * The {func}`jax2tf.convert` now converts `lax.dot_general` using the
    `XlaDot` TensorFlow op, for better fidelity w.r.t. JAX numerical precision
    ({jax-issue}`#6717`).
  * The {func}`jax2tf.convert` now has support for inequality comparisons and
    min/max for complex numbers ({jax-issue}`#6892`).

## jaxlib 0.1.67 (May 17 2021)

## jaxlib 0.1.66 (May 11 2021)

* New features:
  * CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher.

    NVidia now promises compatibility between CUDA minor releases starting with
    CUDA 11.1. This means that JAX can release a single CUDA 11.1 wheel that
    is compatible with CUDA 11.2 and 11.3.

    There is no longer a separate jaxlib release for CUDA 11.2 (or higher); use
    the CUDA 11.1 wheel for those versions (cuda111).
  * Jaxlib now bundles `libdevice.10.bc` in CUDA wheels. There should be no need
    to point JAX to a CUDA installation to find this file.
  * Added automatic support for static keyword arguments to the {func}`jit`
    implementation.
  * Added support for pretransformation exception traces.
  * Initial support for pruning unused arguments from {func}`jit` -transformed
    computations.
    Pruning is still a work in progress.
  * Improved the string representation of {class}`PyTreeDef` objects.
  * Added support for XLA's variadic ReduceWindow.
* Bug fixes:
  * Fixed a bug in the remote cloud TPU support when large numbers of arguments
    are passed to a computation.
  * Fix a bug that meant that JAX garbage collection was not triggered by
    {func}`jit` transformed functions.

## jax 0.2.13 (May 3 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.12...jax-v0.2.13).
* New features:
  * When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static
    keyword arguments. A new `static_argnames` option has been added to specify
    keyword arguments as static.
  * {func}`jax.nonzero` has a new optional `size` argument that allows it to
    be used within `jit` ({jax-issue}`#6501`)
  * {func}`jax.numpy.unique` now supports the `axis` argument ({jax-issue}`#6532`).
  * {func}`jax.experimental.host_callback.call` now supports `pjit.pjit` ({jax-issue}`#6569`).
  * Added {func}`jax.scipy.linalg.eigh_tridiagonal` that computes the
    eigenvalues of a tridiagonal matrix. Only eigenvalues are supported at
    present.
  * The order of the filtered and unfiltered stack traces in exceptions has been
    changed. The traceback attached to an exception thrown from JAX-transformed
    code is now filtered, with an `UnfilteredStackTrace` exception
    containing the original trace as the `__cause__` of the filtered exception.
    Filtered stack traces now also work with Python 3.6.
  * If an exception is thrown by code that has been transformed by reverse-mode
    automatic differentiation, JAX now attempts to attach as a `__cause__` of
    the exception a `JaxStackTraceBeforeTransformation` object that contains the
    stack trace that created the original operation in the forward pass.
    Requires jaxlib 0.1.66.

* Breaking changes:
  * The following function names have changed. There are still aliases, so this
    should not break existing code, but the aliases will eventually be removed
    so please change your code.
    * `host_id` --> {func}`~jax.process_index`
    * `host_count` --> {func}`~jax.process_count`
    * `host_ids` --> `range(jax.process_count())`
  * Similarly, the argument to {func}`~jax.local_devices` has been renamed from
    `host_id` to `process_index`.
  * Arguments to {func}`jax.jit` other than the function are now marked as
    keyword-only. This change is to prevent accidental breakage when arguments
    are added to `jit`.
* Bug fixes:
  * The {func}`jax2tf.convert` now works in presence of gradients for functions
    with integer inputs ({jax-issue}`#6360`).
  * Fixed assertion failure in {func}`jax2tf.call_tf` when used with captured
    `tf.Variable` ({jax-issue}`#6572`).

## jaxlib 0.1.65 (April 7 2021)

## jax 0.2.12 (April 1 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...v0.2.12).
* New features
  * New profiling APIs: {func}`jax.profiler.start_trace`,
    {func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
  * {func}`jax.lax.reduce` is now differentiable.
* Breaking changes:
  * The minimum jaxlib version is now 0.1.64.
  * Some profiler APIs names have been changed. There are still aliases, so this
    should not break existing code, but the aliases will eventually be removed
    so please change your code.
    * `TraceContext` --> {func}`~jax.profiler.TraceAnnotation`
    * `StepTraceContext` --> {func}`~jax.profiler.StepTraceAnnotation`
    * `trace_function` --> {func}`~jax.profiler.annotate_function`
  * Omnistaging can no longer be disabled. See [omnistaging](https://github.com/google/jax/blob/main/design_notes/omnistaging.md)
    for more information.
  * Python integers larger than the maximum `int64` value will now lead to an overflow
    in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`).
  * Outside X64 mode, Python integers outside the range representable by `int32` will now lead to an
    `OverflowError` rather than having their value silently truncated.
* Bug fixes:
  * `host_callback` now supports empty arrays in arguments and results ({jax-issue}`#6262`).
  * {func}`jax.random.randint` clips rather than wraps of out-of-bounds limits, and can now generate
    integers in the full range of the specified dtype ({jax-issue}`#5868`)

## jax 0.2.11 (March 23 2021)

* [GitHub
  commits](https://github.com/google/jax/compare/jax-v0.2.10...jax-v0.2.11).
* New features:
  * [#6112](https://github.com/google/jax/pull/6112) added context managers:
    `jax.enable_checks`, `jax.check_tracer_leaks`, `jax.debug_nans`,
    `jax.debug_infs`, `jax.log_compiles`.
  * [#6085](https://github.com/google/jax/pull/6085) added `jnp.delete`

* Bug fixes:
  * [#6136](https://github.com/google/jax/pull/6136) generalized
    `jax.flatten_util.ravel_pytree` to handle integer dtypes.
  * [#6129](https://github.com/google/jax/issues/6129) fixed a bug with handling
    some constants like `enum.IntEnums`
  * [#6145](https://github.com/google/jax/pull/6145) fixed batching issues with
    incomplete beta functions
  * [#6014](https://github.com/google/jax/pull/6014) fixed H2D transfers during
    tracing
  * [#6165](https://github.com/google/jax/pull/6165) avoids OverflowErrors when
    converting some large Python integers to floats
* Breaking changes:
  * The minimum jaxlib version is now 0.1.62.


## jaxlib 0.1.64 (March 18 2021)

## jaxlib 0.1.63 (March 17 2021)

## jax 0.2.10 (March 5 2021)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.9...jax-v0.2.10).
* New features:
  * {func}`jax.scipy.stats.chi2` is now available as a distribution with logpdf and pdf methods.
  * {func}`jax.scipy.stats.betabinom` is now available as a distribution with logpmf and pmf methods.
  * Added {func}`jax.experimental.jax2tf.call_tf` to call TensorFlow functions
    from JAX ({jax-issue}`#5627`)
    and [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)).
  * Extended the batching rule for `lax.pad` to support batching of the padding values.
* Bug fixes:
  * {func}`jax.numpy.take` properly handles negative indices ({jax-issue}`#5768`)
* Breaking changes:
  * JAX's promotion rules were adjusted to make promotion more consistent and
    invariant to JIT. In particular, binary operations can now result in weakly-typed
    values when appropriate. The main user-visible effect of the change is that
    some operations result in outputs of different precision than before; for
    example the expression `jnp.bfloat16(1) + 0.1 * jnp.arange(10)`
    previously returned a `float64` array, and now returns a `bfloat16` array.
    JAX's type promotion behavior is described at {ref}`type-promotion`.
  * {func}`jax.numpy.linspace` now computes the floor of integer values, i.e.,
    rounding towards -inf rather than 0. This change was made to match NumPy
    1.20.0.
  * {func}`jax.numpy.i0` no longer accepts complex numbers. Previously the
    function computed the absolute value of complex arguments. This change was
    made to match the semantics of NumPy 1.20.0.
  * Several {mod}`jax.numpy` functions no longer accept tuples or lists in place
    of array arguments: {func}`jax.numpy.pad`, :func`jax.numpy.ravel`,
    {func}`jax.numpy.repeat`, {func}`jax.numpy.reshape`.
    In general, {mod}`jax.numpy` functions should be used with scalars or array arguments.

## jaxlib 0.1.62 (March 9 2021)

* New features:
  * jaxlib wheels are now built to require AVX instructions on x86-64 machines
    by default. If you want to use JAX on a machine that doesn't support AVX,
    you can build a jaxlib from source using the `--target_cpu_features` flag
    to `build.py`. `--target_cpu_features` also replaces
    `--enable_march_native`.

## jaxlib 0.1.61 (February 12 2021)

## jaxlib 0.1.60 (Febuary 3 2021)

* Bug fixes:
  * Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The
    memory leak was present in jaxlib releases 0.1.58 and 0.1.59.
  * `bool`, `int8`, and `uint8` are now considered safe to cast to
    `bfloat16` NumPy extension type.

## jax 0.2.9 (January 26 2021)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.8...jax-v0.2.9).
* New features:
  * Extend the {mod}`jax.experimental.loops` module with support for pytrees. Improved
    error checking and error messages.
  * Add {func}`jax.experimental.enable_x64` and {func}`jax.experimental.disable_x64`.
    These are context managers which allow X64 mode to be temporarily enabled/disabled
    within a session.
* Breaking changes:
  * {func}`jax.ops.segment_sum` now drops segment IDs that are out of range rather
    than wrapping them into the segment ID space. This was done for performance
    reasons.

## jaxlib 0.1.59 (January 15 2021)

## jax 0.2.8 (January 12 2021)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.7...jax-v0.2.8).
* New features:
  * Add {func}`jax.closure_convert` for use with higher-order custom
    derivative functions. ({jax-issue}`#5244`)
  * Add {func}`jax.experimental.host_callback.call` to call a custom Python
    function on the host and return a result to the device computation.
    ({jax-issue}`#5243`)
* Bug fixes:
  * `jax.numpy.arccosh` now returns the same branch as `numpy.arccosh` for
    complex inputs ({jax-issue}`#5156`)
  * `host_callback.id_tap` now works for `jax.pmap` also. There is an
    optional parameter for `id_tap` and `id_print` to request that the
    device from which the value is tapped be passed as a keyword argument
    to the tap function ({jax-issue}`#5182`).
* Breaking changes:
  * `jax.numpy.pad` now takes keyword arguments. Positional argument `constant_values`
    has been removed. In addition, passing unsupported keyword arguments raises an error.
  * Changes for {func}`jax.experimental.host_callback.id_tap` ({jax-issue}`#5243`):
    * Removed support for `kwargs` for {func}`jax.experimental.host_callback.id_tap`.
      (This support has been deprecated for a few months.)
    * Changed the printing of tuples for {func}`jax.experimental.host_callback.id_print`
      to use '(' instead of '['.
    * Changed the {func}`jax.experimental.host_callback.id_print` in presence of JVP
      to print a pair of primal and tangent. Previously, there were two separate
      print operations for the primals and the tangent.
    * `host_callback.outfeed_receiver` has been removed (it is not necessary,
      and was deprecated a few months ago).
* New features:
  * New flag for debugging `inf`, analagous to that for `NaN` ({jax-issue}`#5224`).

## jax 0.2.7 (Dec 4 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.6...jax-v0.2.7).
* New features:
  * Add `jax.device_put_replicated`
  * Add multi-host support to `jax.experimental.sharded_jit`
  * Add support for differentiating eigenvalues computed by `jax.numpy.linalg.eig`
  * Add support for building on Windows platforms
  * Add support for general in_axes and out_axes in `jax.pmap`
  * Add complex support for `jax.numpy.linalg.slogdet`
* Bug fixes:
  * Fix higher-than-second order derivatives of `jax.numpy.sinc` at zero
  * Fix some hard-to-hit bugs around symbolic zeros in transpose rules
* Breaking changes:
  * `jax.experimental.optix` has been deleted, in favor of the standalone
    `optax` Python package.
  * indexing of JAX arrays with non-tuple sequences now raises a `TypeError`. This type of indexing
    has been deprecated in Numpy since v1.16, and in JAX since v0.2.4.
    See {jax-issue}`#4564`.

## jax 0.2.6 (Nov 18 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.5...jax-v0.2.6).
* New Features:
  * Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter.
    See [README.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
* Breaking change cleanup

  * Raise an error on non-hashable static arguments for jax.jit and
    xla_computation.  See [cb48f42](https://github.com/google/jax/commit/cb48f42).
  * Improve consistency of type promotion behavior ({jax-issue}`#4744`):
    * Adding a complex Python scalar to a JAX floating point number respects the precision of
      the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously
      it returned `complex128`.
    * Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type
      are now independent of the order of arguments. For example:
      `jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)` and
      `jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)` both return `float16`, where previously
      the first returned `float64` and the second returned `float16`.
  * The contents of the (undocumented) `jax.lax_linalg` linear algebra module
    are now exposed publicly as `jax.lax.linalg`.
  * `jax.random.PRNGKey` now produces the same results in and out of JIT compilation
    ({jax-issue}`#4877`).
    This required changing the result for a given seed in a few particular cases:
    * With `jax_enable_x64=False`, negative seeds passed as Python integers now return a different result
      outside JIT mode. For example, `jax.random.PRNGKey(-1)` previously returned
      `[4294967295, 4294967295]`, and now returns `[0, 4294967295]`. This matches the behavior in JIT.
    * Seeds outside the range representable by `int64` outside JIT now result in an `OverflowError`
      rather than a `TypeError`. This matches the behavior in JIT.

    To recover the keys returned previously for negative integers with `jax_enable_x64=False`
    outside JIT, you can use:

    ```
    key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
    ```
  * DeviceArray now raises `RuntimeError` instead of `ValueError` when trying
    to access its value while it has been deleted.

## jaxlib 0.1.58 (January 12ish 2021)

* Fixed a bug that meant JAX sometimes return platform-specific types (e.g.,
  `np.cint`) instead of standard types (e.g., `np.int32`). (#4903)
* Fixed a crash when constant-folding certain int16 operations. (#4971)
* Added an `is_leaf` predicate to {func}`pytree.flatten`.

## jaxlib 0.1.57 (November 12 2020)

* Fixed manylinux2010 compliance issues in GPU wheels.
* Switched the CPU FFT implementation from Eigen to PocketFFT.
* Fixed a bug where the hash of bfloat16 values was not correctly initialized
  and could change (#4651).
* Add support for retaining ownership when passing arrays to DLPack (#4636).
* Fixed a bug for batched triangular solves with sizes greater than 128 but not
  a multiple of 128.
* Fixed a bug when performing concurrent FFTs on multiple GPUs (#3518).
* Fixed a bug in profiler where tools are missing (#4427).
* Dropped support for CUDA 10.0.

## jax 0.2.5 (October 27 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.4...jax-v0.2.5).
* Improvements:
  * Ensure that `check_jaxpr` does not perform FLOPS.  See {jax-issue}`#4650`.
  * Expanded the set of JAX primitives converted by jax2tf.
    See [primitives_with_limited_support.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md).

## jax 0.2.4 (October 19 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.3...jax-v0.2.4).
* Improvements:
  * Add support for `remat` to jax.experimental.host_callback.  See {jax-issue}`#4608`.
* Deprecations

  * Indexing with non-tuple sequences is now deprecated, following a similar deprecation in Numpy.
    In a future release, this will result in a TypeError. See {jax-issue}`#4564`.

## jaxlib 0.1.56 (October 14, 2020)

## jax 0.2.3 (October 14 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.2...jax-v0.2.3).
* The reason for another release so soon is we need to temporarily roll back a
  new jit fastpath while we look into a performance degradation

## jax 0.2.2 (October 13 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.1...jax-v0.2.2).

## jax 0.2.1 (October 6 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.0...jax-v0.2.1).
* Improvements:
  * As a benefit of omnistaging, the host_callback functions are executed (in program
    order) even if the result of the {py:func}`jax.experimental.host_callback.id_print`/
    {py:func}`jax.experimental.host_callback.id_tap` is not used in the computation.

## jax (0.2.0) (September 23 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.77...jax-v0.2.0).
* Improvements:
  * Omnistaging on by default. See {jax-issue}`#3370` and
    [omnistaging](https://github.com/google/jax/blob/main/design_notes/omnistaging.md)

## jax (0.1.77) (September 15 2020)

* Breaking changes:
  * New simplified interface for {py:func}`jax.experimental.host_callback.id_tap` (#4101)

## jaxlib 0.1.55 (September 8, 2020)

* Update XLA:
  * Fix bug in DLPackManagedTensorToBuffer (#4196)

## jax 0.1.76 (September 8, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.75...jax-v0.1.76).

## jax 0.1.75 (July 30, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.74...jax-v0.1.75).
* Bug Fixes:
  * make jnp.abs() work for unsigned inputs (#3914)
* Improvements:
  * "Omnistaging" behavior added behind a flag, disabled by default (#3370)

## jax 0.1.74 (July 29, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.73...jax-v0.1.74).
* New Features:
  * BFGS (#3101)
  * TPU support for half-precision arithmetic (#3878)
* Bug Fixes:
  * Prevent some accidental dtype warnings (#3874)
  * Fix a multi-threading bug in custom derivatives (#3845, #3869)
* Improvements:
  * Faster searchsorted implementation (#3873)
  * Better test coverage for jax.numpy sorting algorithms (#3836)

## jaxlib 0.1.52 (July 22, 2020)

* Update XLA.

## jax 0.1.73 (July 22, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.72...jax-v0.1.73).
* The minimum jaxlib version is now 0.1.51.
* New Features:
  * jax.image.resize. (#3703)
  * hfft and ihfft (#3664)
  * jax.numpy.intersect1d (#3726)
  * jax.numpy.lexsort (#3812)
  * `lax.scan` and the `scan` primitive support an `unroll`
    parameter for loop unrolling when lowering to XLA
    ({jax-issue}`#3738`).
* Bug Fixes:
  * Fix reduction repeated axis error (#3618)
  * Fix shape rule for lax.pad for input dimensions of size 0. (#3608)
  * make psum transpose handle zero cotangents (#3653)
  * Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)
  * Support differentiation through jax.lax.all_to_all (#3733)
  * address nan issue in jax.scipy.special.zeta (#3777)
* Improvements:
  * Many improvements to jax2tf
  * Reimplement argmin/argmax using a single pass variadic reduction. (#3611)
  * Enable XLA SPMD partitioning by default. (#3151)
  * Add support for 0d transpose convolution (#3643)
  * Make LU gradient work for low-rank matrices (#3610)
  * support multiple_results and custom JVPs in jet (#3657)
  * Generalize reduce-window padding to support (lo, hi) pairs. (#3728)
  * Implement complex convolutions on CPU and GPU. (#3735)
  * Make jnp.take work for empty slices of empty arrays. (#3751)
  * Relax dimension ordering rules for dot_general. (#3778)
  * Enable buffer donation for GPU. (#3800)
  * Add support for base dilation and window dilation to reduce window op… (#3803)

## jaxlib 0.1.51 (July 2, 2020)

* Update XLA.
* Add new runtime support for host_callback.

## jax 0.1.72 (June 28, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.71...jax-v0.1.72).
* Bug fixes:
  * Fix an odeint bug introduced in the previous release, see
    {jax-issue}`#3587`.

## jax 0.1.71 (June 25, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.70...jax-v0.1.71).
* The minimum jaxlib version is now 0.1.48.
* Bug fixes:
  * Allow `jax.experimental.ode.odeint` dynamics functions to close over
    values with respect to which we're differentiating
    {jax-issue}`#3562`.

## jaxlib 0.1.50 (June 25, 2020)

* Add support for CUDA 11.0.
* Drop support for CUDA 9.2 (we only maintain support for the last four CUDA
  versions.)
* Update XLA.

## jaxlib 0.1.49 (June 19, 2020)

* Bug fixes:
  * Fix build issue that could result in slow compiles
    (<https://github.com/tensorflow/tensorflow/commit/f805153a25b00d12072bd728e91bb1621bfcf1b1>)

## jaxlib 0.1.48 (June 12, 2020)

* New features:
  * Adds support for fast traceback collection.
  * Adds preliminary support for on-device heap profiling.
  * Implements `np.nextafter` for `bfloat16` types.
  * Complex128 support for FFTs on CPU and GPU.
* Bugfixes:
  * Improved float64 `tanh` accuracy on GPU.
  * float64 scatters on GPU are much faster.
  * Complex matrix multiplication on CPU should be much faster.
  * Stable sorts on CPU should actually be stable now.
  * Concurrency bug fix in CPU backend.

## jax 0.1.70 (June 8, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.69...jax-v0.1.70).
* New features:
  * `lax.switch` introduces indexed conditionals with multiple
    branches, together with a generalization of the `cond`
    primitive
    {jax-issue}`#3318`.

## jax 0.1.69 (June 3, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.68...jax-v0.1.69).

## jax 0.1.68 (May 21, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.67...jax-v0.1.68).
* New features:
  * {func}`lax.cond` supports a single-operand form, taken as the argument
    to both branches
    {jax-issue}`#2993`.
* Notable changes:
  * The format of the `transforms` keyword for the {func}`jax.experimental.host_callback.id_tap`
    primitive has changed {jax-issue}`#3132`.

## jax 0.1.67 (May 12, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.66...jax-v0.1.67).
* New features:
  * Support for reduction over subsets of a pmapped axis using `axis_index_groups`
    {jax-issue}`#2382`.
  * Experimental support for printing and calling host-side Python function from
    compiled code. See [id_print and id_tap](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html)
    ({jax-issue}`#3006`).
* Notable changes:
  * The visibility of names exported from {mod}`jax.numpy` has been
    tightened. This may break code that was making use of names that were
    previously exported accidentally.

## jaxlib 0.1.47 (May 8, 2020)

* Fixes crash for outfeed.

## jax 0.1.66 (May 5, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.65...jax-v0.1.66).
* New features:
  * Support for `in_axes=None` on {func}`pmap`
    {jax-issue}`#2896`.

## jaxlib 0.1.46 (May 5, 2020)

* Fixes crash for linear algebra functions on Mac OS X (#432).
* Fixes an illegal instruction crash caused by using AVX512 instructions when
  an operating system or hypervisor disabled them (#2906).

## jax 0.1.65 (April 30, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.64...jax-v0.1.65).
* New features:
  * Differentiation of determinants of singular matrices
    {jax-issue}`#2809`.
* Bug fixes:
  * Fix {func}`odeint` differentiation with respect to time of ODEs with
    time-dependent dynamics {jax-issue}`#2817`,
    also add ODE CI testing.
  * Fix {func}`lax_linalg.qr` differentiation
    {jax-issue}`#2867`.

## jaxlib 0.1.45 (April 21, 2020)

* Fixes segfault: {jax-issue}`#2755`
* Plumb is_stable option on Sort HLO through to Python.

## jax 0.1.64 (April 21, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.63...jax-v0.1.64).
* New features:
  * Add syntactic sugar for functional indexed updates
    {jax-issue}`#2684`.
  * Add {func}`jax.numpy.linalg.multi_dot` {jax-issue}`#2726`.
  * Add {func}`jax.numpy.unique` {jax-issue}`#2760`.
  * Add {func}`jax.numpy.rint` {jax-issue}`#2724`.
  * Add {func}`jax.numpy.rint` {jax-issue}`#2724`.
  * Add more primitive rules for {func}`jax.experimental.jet`.
* Bug fixes:
  * Fix {func}`logaddexp` and {func}`logaddexp2` differentiation at zero {jax-issue}`#2107`.
  * Improve memory usage in reverse-mode autodiff without {func}`jit`
    {jax-issue}`#2719`.
* Better errors:
  * Improves error message for reverse-mode differentiation of {func}`lax.while_loop`
    {jax-issue}`#2129`.

## jaxlib 0.1.44 (April 16, 2020)

* Fixes a bug where if multiple GPUs of different models were present, JAX
  would only compile programs suitable for the first GPU.
* Bugfix for `batch_group_count` convolutions.
* Added precompiled SASS for more GPU versions to avoid startup PTX compilation
  hang.

## jax 0.1.63 (April 12, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.62...jax-v0.1.63).
* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works).
* Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`.
* Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`.
* Made `jax.numpy.isclose` handle `nan` and `inf` correctly {jax-issue}`#2501`.
* Added several new rules for `jax.experimental.jet` {jax-issue}`#2537`.
* Fixed `jax.experimental.stax.BatchNorm` when `scale`/`center` isn't provided.
* Fix some missing cases of broadcasting in `jax.numpy.einsum` {jax-issue}`#2512`.
* Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan {jax-issue}`#2596` and make `reduce_prod` differentiable to arbitray order {jax-issue}`#2597`.
* Add `batch_group_count` to `conv_general_dilated` {jax-issue}`#2635`.
* Add docstring for `test_util.check_grads` {jax-issue}`#2656`.
* Add `callback_transform` {jax-issue}`#2665`.
* Implement `rollaxis`, `convolve`/`correlate` 1d & 2d, `copysign`,
  `trunc`, `roots`, and `quantile`/`percentile` interpolation options.

## jaxlib 0.1.43 (March 31, 2020)

* Fixed a performance regression for Resnet-50 on GPU.

## jax 0.1.62 (March 21, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.61...jax-v0.1.62).
* JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
* Removed the internal function `lax._safe_mul`, which implemented the
  convention `0. * nan == 0.`. This change means some programs when
  differentiated will produce nans when they previously produced correct
  values, though it ensures nans rather than silently incorrect results are
  produced for other programs. See #2447 and #1052 for details.
* Added an `all_gather` parallel convenience function.
* More type annotations in core code.

## jaxlib 0.1.42 (March 19, 2020)

* jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This
  release fixes it again.
* JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.

## jax 0.1.61 (March 17, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.60...jax-v0.1.61).
* Fixes Python 3.5 support. This will be the last JAX or jaxlib release that
  supports Python 3.5.

## jax 0.1.60 (March 17, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.59...jax-v0.1.60).
* New features:
  * {py:func}`jax.pmap` has `static_broadcast_argnums` argument which allows
    the user to specify arguments that should be treated as compile-time
    constants and should be broadcasted to all devices. It works analogously to
    `static_argnums` in {py:func}`jax.jit`.
  * Improved error messages for when tracers are mistakenly saved in global state.
  * Added {py:func}`jax.nn.one_hot` utility function.
  * Added {mod}`jax.experimental.jet` for exponentially faster
    higher-order automatic differentiation.
  * Added more correctness checking to arguments of {py:func}`jax.lax.broadcast_in_dim`.
* The minimum jaxlib version is now 0.1.41.

## jaxlib 0.1.40 (March 4, 2020)

* Adds experimental support in Jaxlib for TensorFlow profiler, which allows
  tracing of CPU and GPU computations from TensorBoard.
* Includes prototype support for multihost GPU computations that communicate via
  NCCL.
* Improves performance of NCCL collectives on GPU.
* Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and
  RandomGamma implementations.
* Supports device assignments known at XLA compilation time.

## jax 0.1.59 (February 11, 2020)

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.58...jax-v0.1.59).
* Breaking changes

  * The minimum jaxlib version is now 0.1.38.
  * Simplified {py:class}`Jaxpr` by removing the `Jaxpr.freevars` and
    `Jaxpr.bound_subjaxprs`. The call primitives (`xla_call`, `xla_pmap`,
    `sharded_call`, and `remat_call`) get a new parameter `call_jaxpr` with a
    fully-closed (no `constvars`) jaxpr. Also, added a new field `call_primitive`
    to primitives.
* New features:
  * Reverse-mode automatic differentiation (e.g. `grad`) of `lax.cond`, making it
    now differentiable in both modes ({jax-issue}`#2091`)
  * JAX now supports DLPack, which allows sharing CPU and GPU arrays in a
    zero-copy way with other libraries, such as PyTorch.
  * JAX GPU DeviceArrays now support `__cuda_array_interface__`, which is another
    zero-copy protocol for sharing GPU arrays with other libraries such as CuPy
    and Numba.
  * JAX CPU device buffers now implement the Python buffer protocol, which allows
    zero-copy buffer sharing between JAX and NumPy.
  * Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known as slow.

## jaxlib 0.1.39 (February 11, 2020)

* Updates XLA.

## jaxlib 0.1.38 (January 29, 2020)

* CUDA 9.0 is no longer supported.
* CUDA 10.2 wheels are now built by default.

## jax 0.1.58 (January 28, 2020)

* [GitHub commits](https://github.com/google/jax/compare/46014da21...jax-v0.1.58).
* Breaking changes

  * JAX has dropped Python 2 support, because Python 2 reached its end of life on
    January 1, 2020. Please update to Python 3.5 or newer.
* New features

  >   > * Forward-mode automatic differentiation (`jvp`) of while loop
  >   ({jax-issue}`#1980`)
  > * New NumPy and SciPy functions:
  >
  >   * {py:func}`jax.numpy.fft.fft2`
  >   * {py:func}`jax.numpy.fft.ifft2`
  >   * {py:func}`jax.numpy.fft.rfft`
  >   * {py:func}`jax.numpy.fft.irfft`
  >   * {py:func}`jax.numpy.fft.rfft2`
  >   * {py:func}`jax.numpy.fft.irfft2`
  >   * {py:func}`jax.numpy.fft.rfftn`
  >   * {py:func}`jax.numpy.fft.irfftn`
  >   * {py:func}`jax.numpy.fft.fftfreq`
  >   * {py:func}`jax.numpy.fft.rfftfreq`
  >   * {py:func}`jax.numpy.linalg.matrix_rank`
  >   * {py:func}`jax.numpy.linalg.matrix_power`
  >   * {py:func}`jax.scipy.special.betainc`
  > * Batched Cholesky decomposition on GPU now uses a more efficient batched
  >   kernel.

### Notable bug fixes

* With the Python 3 upgrade, JAX no longer depends on `fastcache`, which should
  help with installation.
back to top