https://github.com/google/jax
Tip revision: 25b4070268448d1cd75220378b3db5c60551ddae authored by jax authors on 26 October 2020, 22:17:45 UTC
Merge pull request #4707 from google:traceback-register
Concurrency
===========

JAX has some limited support for Python concurrency.

Concurrency support is experimental and only lightly tested; please report any
bugs.

Clients may call JAX APIs (e.g., :func:`~jax.jit` or :func:`~jax.grad`)
concurrently from separate Python threads.
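
As a minimal sketch of this usage, the example below calls one jitted function
from several threads through a standard-library thread pool. The names
``scaled_sum`` and ``worker`` are hypothetical and chosen only for
illustration; given the experimental status noted above, treat this as an
illustration rather than a guarantee of behavior.

.. code-block:: python

    from concurrent.futures import ThreadPoolExecutor

    import jax
    import jax.numpy as jnp


    @jax.jit
    def scaled_sum(x):
        # Hypothetical example function: doubles the input and sums it.
        return jnp.sum(x * 2.0)


    def worker(n):
        # Each thread builds its own input and calls the jitted function.
        return float(scaled_sum(jnp.ones(n)))


    # Call the same jitted function concurrently from separate threads.
    with ThreadPoolExecutor(max_workers=4) as pool:
        results = list(pool.map(worker, [1, 2, 3, 4]))

Here the threads only call the JAX API; none of them touches a value that is
being traced on another thread.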

It is not permitted to manipulate JAX trace values concurrently from multiple
threads. In other words, while it is permissible to call functions that use JAX
tracing (e.g., :func:`~jax.jit`) from multiple threads, you must not use
threading to manipulate JAX values inside the implementation of the function
``f`` that is passed to :func:`~jax.jit`. The most likely outcome if you do this
is a mysterious error from JAX.
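
The permitted pattern can be sketched as follows, assuming a hypothetical
``loss`` function and ``worker`` helper: threads are created outside the
traced function, and every operation on a traced value stays on the thread
that started the trace.

.. code-block:: python

    import threading

    import jax
    import jax.numpy as jnp


    def loss(x):
        # All operations on the traced value ``x`` run on the thread that
        # called jax.grad(loss). Starting threads here that operate on ``x``
        # is the pattern the text above prohibits.
        return jnp.sum(x ** 2)


    grads = {}


    def worker(name, x):
        # Permitted: each thread makes its own call into a JAX API.
        grads[name] = jax.grad(loss)(x)


    threads = [
        threading.Thread(target=worker, args=(f"t{i}", jnp.full(3, float(i))))
        for i in range(2)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

Each thread writes to a distinct key of ``grads``, so the threads share no
traced values.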