..
   Source: https://github.com/google/jax (file: jax.rst)
   Tip revision: a9a72fa53b4eac1212c5d33388fb140bba105047, authored by jax authors on 10 June 2024, 20:39:35 UTC
   Commit message: Add frontend attributes to Jax. This allows Jax users to annotate Jax code with frontend_attributes which can be traced down to the HLO level, to be used for numerical debugging purposes.

.. currentmodule:: jax

Public API: ``jax`` package
===========================

Subpackages
-----------

.. toctree::
   :maxdepth: 1

   jax.numpy
   jax.scipy
   jax.lax
   jax.random
   jax.sharding
   jax.debug
   jax.dlpack
   jax.distributed
   jax.dtypes
   jax.flatten_util
   jax.image
   jax.nn
   jax.ops
   jax.profiler
   jax.stages
   jax.tree
   jax.tree_util
   jax.typing
   jax.export
   jax.extend
   jax.example_libraries
   jax.experimental

.. toctree::
   :hidden:

   jax.lib

Configuration
-------------

.. autosummary::
   :toctree: _autosummary

   config
   check_tracer_leaks
   checking_leaks
   debug_nans
   debug_infs
   default_device
   default_matmul_precision
   default_prng_impl
   enable_checks
   enable_custom_prng
   enable_custom_vjp_by_custom_transpose
   log_compiles
   numpy_rank_promotion
   transfer_guard

.. _jax-jit:

Just-in-time compilation (:code:`jit`)
--------------------------------------

.. autosummary::
  :toctree: _autosummary

    jit
    disable_jit
    ensure_compile_time_eval
    xla_computation
    make_jaxpr
    eval_shape
    ShapeDtypeStruct
    device_put
    device_put_replicated
    device_put_sharded
    device_get
    default_backend
    named_call
    named_scope
    block_until_ready

.. _jax-grad:

Automatic differentiation
-------------------------

.. autosummary::
  :toctree: _autosummary

    grad
    value_and_grad
    jacfwd
    jacrev
    hessian
    jvp
    linearize
    linear_transpose
    vjp
    custom_gradient
    closure_convert
    checkpoint

``custom_jvp``
~~~~~~~~~~~~~~

.. autosummary::
  :toctree: _autosummary

  custom_jvp
  custom_jvp.defjvp
  custom_jvp.defjvps

``custom_vjp``
~~~~~~~~~~~~~~

.. autosummary::
  :toctree: _autosummary

  custom_vjp
  custom_vjp.defvjp

jax.Array (:code:`jax.Array`)
-----------------------------

.. autosummary::
  :toctree: _autosummary

    Array
    make_array_from_callback
    make_array_from_single_device_arrays
    make_array_from_process_local_data

Array properties and methods
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::
  :toctree: _autosummary

    Array.addressable_shards
    Array.all
    Array.any
    Array.argmax
    Array.argmin
    Array.argpartition
    Array.argsort
    Array.astype
    Array.at
    Array.choose
    Array.clip
    Array.compress
    Array.conj
    Array.conjugate
    Array.copy
    Array.copy_to_host_async
    Array.cumprod
    Array.cumsum
    Array.device
    Array.diagonal
    Array.dot
    Array.dtype
    Array.flat
    Array.flatten
    Array.global_shards
    Array.imag
    Array.is_fully_addressable
    Array.is_fully_replicated
    Array.item
    Array.itemsize
    Array.max
    Array.mean
    Array.min
    Array.nbytes
    Array.ndim
    Array.nonzero
    Array.prod
    Array.ptp
    Array.ravel
    Array.real
    Array.repeat
    Array.reshape
    Array.round
    Array.searchsorted
    Array.shape
    Array.sharding
    Array.size
    Array.sort
    Array.squeeze
    Array.std
    Array.sum
    Array.swapaxes
    Array.take
    Array.to_device
    Array.trace
    Array.transpose
    Array.var
    Array.view
    Array.T
    Array.mT

Vectorization (:code:`vmap`)
----------------------------

.. autosummary::
  :toctree: _autosummary

    vmap
    numpy.vectorize

Parallelization (:code:`pmap`)
------------------------------

.. autosummary::
  :toctree: _autosummary

    pmap
    devices
    local_devices
    process_index
    device_count
    local_device_count
    process_count
    process_indices

Callbacks
---------

.. autosummary::
  :toctree: _autosummary

    pure_callback
    experimental.io_callback
    debug.callback
    debug.print

Miscellaneous
-------------

.. autosummary::
  :toctree: _autosummary

    Device
    print_environment_info
    live_arrays
    clear_caches