Concurrency
===========

JAX has limited support for Python concurrency.

Clients may call JAX APIs (e.g., :func:`~jax.jit` or :func:`~jax.grad`)
concurrently from separate Python threads.
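
As a minimal sketch of this pattern, the example below calls a jitted function
from several Python threads at once. The names ``scaled_sum`` and ``worker`` and
the use of ``concurrent.futures.ThreadPoolExecutor`` are illustrative
assumptions, not part of JAX; any threading mechanism that calls the JAX API
from the outside should behave similarly.

```python
from concurrent.futures import ThreadPoolExecutor

import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x, scale):
    # Ordinary single-threaded JAX code: no threads are used inside the
    # traced body of the function.
    return jnp.sum(x * scale)


def worker(i):
    # Each thread calls the jitted function with its own inputs.
    x = jnp.arange(4.0)
    return float(scaled_sum(x, float(i)))


# Hypothetical driver: four threads call the same jitted function concurrently.
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(worker, range(4)))
```

Here the threads only invoke the JAX API; each call traces and runs
``scaled_sum`` on the thread that made the call.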

Manipulating JAX trace values concurrently from multiple threads is not
permitted. In other words, while you may call functions that use JAX tracing
(e.g., :func:`~jax.jit`) from multiple threads, you must not use threading to
manipulate JAX values inside the implementation of the function ``f`` that is
passed to :func:`~jax.jit`. Doing so will most likely produce a
hard-to-diagnose error from JAX.
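
One way to stay within this rule, sketched below under the assumption that the
work can be restructured, is to keep the body of ``f`` free of threading and to
create threads only at the call site. The helper ``call_f`` and the ``outputs``
dictionary are hypothetical names used for illustration.

```python
import threading

import jax
import jax.numpy as jnp


def f(x):
    # Every operation on the traced value x happens on the thread that is
    # tracing f. Starting threads here that read or combine x (or values
    # derived from it) is the pattern the text above says is not permitted.
    a = jnp.sin(x)
    b = jnp.cos(x)
    return a * b


f_jit = jax.jit(f)

outputs = {}


def call_f(name, value):
    # Threads live outside the traced function and only call the JAX API.
    outputs[name] = f_jit(value)


threads = [
    threading.Thread(target=call_f, args=(name, jnp.float32(value)))
    for name, value in [("a", 0.0), ("b", 1.0)]
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Each thread writes to a distinct key of ``outputs``, so the threads share no
JAX values while ``f`` is being traced.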