https://github.com/google/jax
Tip revision: f1f4840a0d7af7b206c50874d877627cdeddb004 authored by jax authors on 09 March 2023, 18:59:33 UTC
Merge pull request #14881 from skye:version
typing.py
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
The JAX typing module is where JAX-specific static type annotations live.
This submodule is a work in progress; the proposal behind the types exported here is
described at https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html.

The currently-available types are:

- :class:`jax.Array`: annotation for any JAX array or tracer (i.e. representations of arrays
  within JAX transforms).
- :class:`jax.typing.ArrayLike`: annotation for any value that is safe to implicitly cast to
  a JAX array; this includes :class:`jax.Array`, :class:`numpy.ndarray`, as well as Python
  builtin numeric values (e.g. :class:`int`, :class:`float`, etc.) and numpy scalar values
  (e.g. :class:`numpy.int32`, :class:`numpy.float64`, etc.).

We may add additional types here in future releases.
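
As a rough sketch (not part of this module's API), the snippet below converts one sample value
from each category listed above with :func:`jax.numpy.asarray` and checks that each result is a
:class:`jax.Array`. It also records, at trace time, that the argument seen inside :func:`jax.jit`
(a tracer) passes ``isinstance(x, Array)``. The names ``arraylike_values``, ``seen_inside_jit``
and ``check_traced`` are hypothetical:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import Array

# Hypothetical sample values, one per category covered by ArrayLike.
arraylike_values = [
  jnp.arange(3),    # jax.Array
  np.arange(3),     # numpy.ndarray
  True,             # Python bool
  3,                # Python int
  2.5,              # Python float
  1 + 2j,           # Python complex
  np.int32(7),      # NumPy scalar
  np.float64(1.5),  # NumPy scalar
]

# Each ArrayLike value converts to a jax.Array without any extra arguments.
converted = [jnp.asarray(v) for v in arraylike_values]
print(all(isinstance(a, Array) for a in converted))

# Inside jax.jit the argument is a tracer, which is still an instance of Array.
seen_inside_jit = []

@jax.jit
def check_traced(x):
  # This Python side effect runs once, while the function is being traced.
  seen_inside_jit.append(isinstance(x, Array))
  return x * 2

check_traced(jnp.ones(3))
print(seen_inside_jit)
```

``seen_inside_jit`` is appended to only during tracing, not on every call, because Python side
effects in a jitted function execute when the function is traced.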

JAX Typing Best Practices
-------------------------
When annotating JAX arrays in public API functions, we recommend using :class:`~jax.typing.ArrayLike`
for array inputs, and :class:`~jax.Array` for array outputs.

For example, your function might look like this::

    import numpy as np
    import jax.numpy as jnp
    from jax import Array
    from jax.typing import ArrayLike

    def my_function(x: ArrayLike) -> Array:
      # Runtime type validation:
      if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
        raise TypeError(f"Expected arraylike input; got {x}")

      # Convert input to jax.Array:
      x_arr = jnp.asarray(x)

      # ... do some computation (here, the mean of all elements, which also
      # works for scalar inputs); JAX functions will return Array types:
      result = x_arr.sum() / x_arr.size

      # return an Array
      return result

Most of JAX's public APIs follow this pattern. Note in particular that we recommend that JAX
functions not accept sequences such as :class:`list` or :class:`tuple` in place of arrays. Inside
transforms like :func:`~jax.jit`, each element of a sequence is traced as a separate input, which
adds overhead; batch-wise transforms like :func:`~jax.vmap` or :func:`~jax.pmap` map over each
element of the sequence rather than over the sequence as a whole, which can give unexpected
results. For more information on this, see `Non-array inputs NumPy vs JAX`_.

.. _Non-array inputs NumPy vs JAX: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#non-array-inputs-numpy-vs-jax
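
Both problems with sequence inputs can be observed directly. This sketch uses a hypothetical
helper ``total`` that, contrary to the recommendation above, accepts a list and converts it
internally. :func:`jax.make_jaxpr` shows that a list of three floats is traced as three separate
inputs, while :func:`jax.vmap` maps over axis 0 of each list element instead of over the list:

```python
import jax
import jax.numpy as jnp

def total(x):
  # Hypothetical helper that converts its (possibly list) input internally.
  return jnp.sum(jnp.asarray(x))

# Under tracing a list is a pytree, so each element becomes its own input.
n_inputs_list = len(jax.make_jaxpr(total)([1.0, 2.0, 3.0]).jaxpr.invars)
n_inputs_array = len(jax.make_jaxpr(total)(jnp.array([1.0, 2.0, 3.0])).jaxpr.invars)
print(n_inputs_list, n_inputs_array)  # 3 1

xs = [jnp.arange(3.0), jnp.arange(3.0)]
# vmap over the list batches axis 0 of *each* element, so every call sees a
# list of two scalars and adds them: values [0., 2., 4.], shape (3,).
batched_list = jax.vmap(total)(xs)
# Converting first batches over the rows of a (2, 3) array: values [3., 3.], shape (2,).
batched_array = jax.vmap(total)(jnp.asarray(xs))
print(batched_list, batched_array)
```

Calling :func:`jax.numpy.asarray` once at the call site, as in the ``jnp.asarray(xs)`` line,
gives the array-based behavior.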
"""
from jax._src.typing import (
    ArrayLike as ArrayLike
)