https://github.com/google/jax
Tip revision: 7495f71ea42db7b985427a12a1d366ed3cffda64 authored by jax authors on 03 February 2021, 17:15:38 UTC
Merge pull request #5614 from hawkinsp:jaxlib

jax.ops package
=================

.. currentmodule:: jax.ops

.. automodule:: jax.ops


Indexed update operators
------------------------

JAX is designed for a functional style of programming: JAX arrays are immutable,
so NumPy-style in-place indexed assignment such as ``x[idx] = y`` is not
supported. Instead, JAX provides pure alternatives that return an updated copy
of the array, namely :func:`jax.ops.index_update` and its relatives.
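
As a minimal sketch of the difference (assuming NumPy and a recent JAX release
are installed), the snippet below contrasts NumPy's in-place assignment with
JAX, which rejects ``x[idx] = y`` with a ``TypeError``. It uses the
``x.at[idx].set(y)`` spelling, which this page lists as equivalent to
``jax.ops.index_update(x, jax.ops.index[idx], y)``, so that it also runs on
later JAX releases in which the ``jax.ops.index_*`` functions were removed.

.. code-block:: python

   import numpy as np
   import jax.numpy as jnp

   # NumPy arrays are mutable: indexed assignment changes the array in place.
   a = np.zeros(3)
   a[1] = 5.0  # a is now [0., 5., 0.]

   # JAX arrays are immutable: the same statement raises a TypeError.
   x = jnp.zeros(3)
   try:
       x[1] = 5.0
       raised = False
   except TypeError:
       raised = True

   # The pure alternative returns a new array and leaves x unchanged.
   y = x.at[1].set(5.0)  # y is [0., 5., 0.]; x is still [0., 0., 0.]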

.. autosummary::
  :toctree: _autosummary

    index
    index_update
    index_add
    index_mul
    index_min
    index_max
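
For the functions above, the index argument is built with ``jax.ops.index``, as
in the equivalent expressions in the table in the next section (for example
``jax.ops.index[0, :]``). The sketch below, assuming a recent JAX release and
using the ``x.at[idx]`` form described in the next section, shows a few index
expressions: a row slice, a strided slice, and a pair of integer index arrays.

.. code-block:: python

   import jax.numpy as jnp

   x = jnp.zeros((3, 4))

   # Row slice: set every element of row 0 to 1.
   rows = x.at[0, :].set(1.0)

   # Strided slice: add 2 to columns 0 and 2 of every row.
   cols = x.at[:, ::2].add(2.0)

   # Integer index arrays: set elements (0, 1) and (2, 3) to 7.
   points = x.at[jnp.array([0, 2]), jnp.array([1, 3])].set(7.0)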


Syntactic sugar for indexed update operators
--------------------------------------------

JAX also provides an alternate syntax for these indexed update operators.
Specifically, JAX arrays have an ``at`` property, which can be used as follows
(where ``idx`` can be an arbitrary index expression):

====================  ===================================================
Alternate syntax      Equivalent expression
====================  ===================================================
``x.at[idx].set(y)``  ``jax.ops.index_update(x, jax.ops.index[idx], y)``
``x.at[idx].add(y)``  ``jax.ops.index_add(x, jax.ops.index[idx], y)``
``x.at[idx].mul(y)``  ``jax.ops.index_mul(x, jax.ops.index[idx], y)``
``x.at[idx].min(y)``  ``jax.ops.index_min(x, jax.ops.index[idx], y)``
``x.at[idx].max(y)``  ``jax.ops.index_max(x, jax.ops.index[idx], y)``
====================  ===================================================
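
A minimal sketch of the five operators on a small 1-D array, assuming a recent
JAX release; the comments give the resulting values. ``idx`` is a plain integer
here, but any index expression works in its place.

.. code-block:: python

   import jax.numpy as jnp

   x = jnp.array([1.0, 2.0, 3.0, 4.0])
   idx = 1  # selects the element 2.0

   set_result = x.at[idx].set(10.0)  # [1., 10., 3., 4.]
   add_result = x.at[idx].add(10.0)  # [1., 12., 3., 4.]  (2 + 10)
   mul_result = x.at[idx].mul(10.0)  # [1., 20., 3., 4.]  (2 * 10)
   min_result = x.at[idx].min(0.0)   # [1., 0., 3., 4.]   (min(2, 0))
   max_result = x.at[idx].max(10.0)  # [1., 10., 3., 4.]  (max(2, 10))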

Note that none of these expressions modifies the original ``x``; instead, each
returns a modified copy of ``x``.
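
The sketch below, assuming a recent JAX release, checks the copy semantics
directly: after each update the original ``x`` still holds its old values.
Because the updates are pure, they can also be used inside :func:`jax.jit`; the
JAX documentation notes that under ``jit`` the compiler may perform such an
update in place when the input is not reused afterwards, which affects
performance but not the result.

.. code-block:: python

   import jax
   import jax.numpy as jnp

   x = jnp.arange(5)     # [0, 1, 2, 3, 4]
   y = x.at[0].set(100)  # new array [100, 1, 2, 3, 4]

   @jax.jit
   def bump_first(v):
       # Returns a copy of v with 1 added to its first element.
       return v.at[0].add(1)

   z = bump_first(x)     # new array [1, 1, 2, 3, 4]

   # Neither update modified x.
   unchanged = [int(v) for v in x]  # [0, 1, 2, 3, 4]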

Other operators
---------------

.. autosummary::
  :toctree: _autosummary

    segment_sum
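
:func:`jax.ops.segment_sum` sums the entries of an array that share a segment
id. A minimal sketch, assuming a recent JAX release: ``num_segments`` fixes the
length of the output. Passing it explicitly is what allows the call to run under
:func:`jax.jit`, because otherwise it is computed from the values in
``segment_ids`` and the output shape would depend on data.

.. code-block:: python

   import jax
   import jax.numpy as jnp

   data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
   segment_ids = jnp.array([0, 0, 1, 2, 2])

   # Segment 0 gets 1 + 2, segment 1 gets 3, segment 2 gets 4 + 5.
   totals = jax.ops.segment_sum(data, segment_ids, num_segments=3)  # [3., 3., 9.]

   # With a static num_segments, the same call can be jit-compiled.
   jit_sum = jax.jit(lambda d, s: jax.ops.segment_sum(d, s, num_segments=3))
   totals_jit = jit_sum(data, segment_ids)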