https://github.com/google/jax
Tip revision: fa758bcff564ba7ac99fc13df0185f49a899aafc authored by jax authors on 09 July 2021, 20:06:13 UTC
Merge pull request #7233 from zhangqiaorjc:update_xla_2
jax.numpy.rst

jax.numpy package
=================

.. currentmodule:: jax.numpy

.. automodule:: jax.numpy

Implements the NumPy API, using the primitives in :mod:`jax.lax`.
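
Because the functions share NumPy's names and signatures, ``jax.numpy`` can often
be used as a drop-in replacement. The sketch below uses arbitrary input values. It
compares a small computation against NumPy and then prints the traced program (a
jaxpr) produced by :func:`jax.make_jaxpr`, in which the call to
:func:`jax.numpy.sum` appears as a :mod:`jax.lax` reduction primitive. The exact
jaxpr text varies between JAX versions.

.. code-block:: python

   import numpy as np
   import jax
   import jax.numpy as jnp

   x_np = np.linspace(0.0, 1.0, 5, dtype=np.float32)
   x = jnp.asarray(x_np)

   # Same function names and call signatures as NumPy.
   np_total = np.sum(np.sin(x_np))
   jnp_total = jnp.sum(jnp.sin(x))
   print(np_total, jnp_total)

   # The traced program is expressed with jax.lax primitives
   # (for jnp.sum, a reduce_sum).
   jaxpr_text = str(jax.make_jaxpr(jnp.sum)(x))
   print(jaxpr_text)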

While JAX tries to follow the NumPy API as closely as possible, sometimes JAX
cannot follow NumPy exactly.

* Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays
  in-place cannot be implemented in JAX. However, JAX can often provide an
  alternative API that is purely functional. For example, instead of in-place
  array updates (:code:`x[i] = y`), JAX provides an alternative pure indexed
  update function :func:`jax.ops.index_update`.

* Relatedly, some NumPy functions return views of arrays when possible (for
  example, :func:`numpy.transpose` and :func:`numpy.reshape`). The JAX versions of
  such functions return copies instead, although XLA can often optimize such
  copies away when a sequence of operations is compiled with :func:`jax.jit`.

* NumPy promotes values to the :code:`float64` type very aggressively. JAX is
  sometimes less aggressive about type promotion (see :ref:`type-promotion`).
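
The three differences above can be observed directly. The following sketch assumes
a recent JAX release, in which the indexed update from the first bullet is written
``x.at[i].set(y)``; :func:`jax.ops.index_update` was later deprecated and removed in
favor of this syntax. The array contents are arbitrary. Because JAX arrays cannot be
mutated, the copy made by a JAX ``reshape`` never shows up as aliasing, unlike a
NumPy view.

.. code-block:: python

   import numpy as np
   import jax.numpy as jnp

   # 1. Immutability: item assignment on a JAX array raises TypeError.
   x = jnp.zeros(4)
   immutable = False
   try:
       x[1] = 5.0
   except TypeError:
       immutable = True

   # The functional alternative returns a new array and leaves x untouched.
   y = x.at[1].set(5.0)

   # 2. Views vs. copies: writing into a NumPy reshape result writes through
   #    to the original buffer; a JAX reshape result is an independent value.
   a = np.arange(6)
   a.reshape(2, 3)[0, 0] = 100             # a[0] is now 100
   b = jnp.arange(6)
   c = b.reshape(2, 3).at[0, 0].set(100)   # b is unchanged

   # 3. Type promotion: NumPy promotes int32 with float16 to float64,
   #    while JAX keeps the narrower float16.
   np_dtype = np.promote_types(np.int32, np.float16)
   jnp_dtype = jnp.promote_types(jnp.int32, jnp.float16)
   print(np_dtype, jnp_dtype)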

A small number of NumPy operations that have data-dependent output shapes are
incompatible with :func:`jax.jit` compilation. The XLA compiler requires that
shapes of arrays be known at compile time. While it would be possible to provide
a JAX implementation of an API such as :func:`numpy.nonzero`, we would be unable
to JIT-compile it because the shape of its output depends on the contents of the
input data.
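
The following sketch illustrates the restriction with :func:`jax.numpy.nonzero`. It
assumes a JAX release whose ``jnp.nonzero`` accepts the ``size`` and ``fill_value``
keywords, which fix the output length ahead of time so that the function can be
traced under :func:`jax.jit`; results shorter than ``size`` are padded with
``fill_value``. The helper names ``nonzero_dynamic`` and ``nonzero_static`` are
illustrative only.

.. code-block:: python

   import jax
   import jax.numpy as jnp

   x = jnp.array([0, 3, 0, 5])

   # Eagerly (outside jit), the number of outputs can follow the data.
   (idx,) = jnp.nonzero(x)

   @jax.jit
   def nonzero_dynamic(v):
       # Output length depends on the values in v, so tracing fails.
       return jnp.nonzero(v)[0]

   jit_failed = False
   try:
       nonzero_dynamic(x)
   except jax.errors.ConcretizationTypeError:
       jit_failed = True

   @jax.jit
   def nonzero_static(v):
       # size fixes the output length at trace time; unused slots get fill_value.
       return jnp.nonzero(v, size=4, fill_value=-1)[0]

   padded = nonzero_static(x)
   print(idx, padded)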

Not every function in NumPy is implemented; contributions are welcome!

.. Generate the list below as follows:
   >>> import jax.numpy, numpy
   >>> fns = set(dir(numpy)) & set(dir(jax.numpy)) - set(jax.numpy._NOT_IMPLEMENTED)
   >>> print('\n'.join('    ' + x for x in fns if callable(getattr(jax.numpy, x))))  # doctest: +SKIP

   # Finally, sort the list using sort(1), which is different than Python's
   # sorted() function.

.. autosummary::
  :toctree: _autosummary

    abs
    absolute
    add
    all
    allclose
    alltrue
    amax
    amin
    angle
    any
    append
    apply_along_axis
    apply_over_axes
    arange
    arccos
    arccosh
    arcsin
    arcsinh
    arctan
    arctan2
    arctanh
    argmax
    argmin
    argsort
    argwhere
    around
    array
    array_equal
    array_equiv
    array_repr
    array_split
    array_str
    asarray
    atleast_1d
    atleast_2d
    atleast_3d
    average
    bartlett
    bincount
    bitwise_and
    bitwise_not
    bitwise_or
    bitwise_xor
    blackman
    block
    bool_
    broadcast_arrays
    broadcast_shapes
    broadcast_to
    c_
    can_cast
    cbrt
    cdouble
    ceil
    character
    choose
    clip
    column_stack
    complex_
    complex128
    complex64
    complexfloating
    ComplexWarning
    compress
    concatenate
    conj
    conjugate
    convolve
    copysign
    corrcoef
    correlate
    cos
    cosh
    count_nonzero
    cov
    cross
    csingle
    cumprod
    cumproduct
    cumsum
    deg2rad
    degrees
    delete
    diag
    diagflat
    diag_indices
    diag_indices_from
    diagonal
    diff
    digitize
    divide
    divmod
    dot
    double
    dsplit
    dstack
    dtype
    ediff1d
    einsum
    einsum_path
    empty
    empty_like
    equal
    exp
    exp2
    expand_dims
    expm1
    extract
    eye
    fabs
    finfo
    fix
    flatnonzero
    flexible
    flip
    fliplr
    flipud
    float_
    float16
    float32
    float64
    floating
    float_power
    floor
    floor_divide
    fmax
    fmin
    fmod
    frexp
    full
    full_like
    gcd
    geomspace
    gradient
    greater
    greater_equal
    hamming
    hanning
    heaviside
    histogram
    histogram_bin_edges
    histogram2d
    histogramdd
    hsplit
    hstack
    hypot
    i0
    identity
    iinfo
    imag
    in1d
    indices
    inexact
    inner
    int_
    int16
    int32
    int64
    int8
    integer
    interp
    intersect1d
    invert
    isclose
    iscomplex
    iscomplexobj
    isfinite
    isin
    isinf
    isnan
    isneginf
    isposinf
    isreal
    isrealobj
    isscalar
    issubdtype
    issubsctype
    iterable
    ix_
    kaiser
    kron
    lcm
    ldexp
    left_shift
    less
    less_equal
    lexsort
    linspace
    load
    log
    log10
    log1p
    log2
    logaddexp
    logaddexp2
    logical_and
    logical_not
    logical_or
    logical_xor
    logspace
    mask_indices
    matmul
    max
    maximum
    mean
    median
    meshgrid
    mgrid
    min
    minimum
    mod
    modf
    moveaxis
    msort
    multiply
    nanargmax
    nanargmin
    nancumprod
    nancumsum
    nanmax
    nanmean
    nanmedian
    nanmin
    nanpercentile
    nanprod
    nanquantile
    nanstd
    nansum
    nan_to_num
    nanvar
    ndarray
    ndim
    negative
    nextafter
    nonzero
    not_equal
    number
    object_
    ogrid
    ones
    ones_like
    outer
    packbits
    pad
    percentile
    piecewise
    poly
    polyadd
    polyder
    polyint
    polymul
    polysub
    polyval
    positive
    power
    prod
    product
    promote_types
    ptp
    quantile
    r_
    rad2deg
    radians
    ravel
    ravel_multi_index
    real
    reciprocal
    remainder
    repeat
    reshape
    resize
    result_type
    right_shift
    rint
    roll
    rollaxis
    roots
    rot90
    round
    round_
    row_stack
    save
    savez
    searchsorted
    select
    set_printoptions
    setdiff1d
    setxor1d
    shape
    sign
    signbit
    signedinteger
    sin
    sinc
    single
    sinh
    size
    sometrue
    sort
    sort_complex
    split
    sqrt
    square
    squeeze
    stack
    std
    subtract
    sum
    swapaxes
    take
    take_along_axis
    tan
    tanh
    tensordot
    tile
    trace
    transpose
    trapz
    tri
    tril
    tril_indices
    tril_indices_from
    trim_zeros
    triu
    triu_indices
    triu_indices_from
    true_divide
    trunc
    uint16
    uint32
    uint64
    uint8
    unique
    union1d
    unpackbits
    unravel_index
    unsignedinteger
    unwrap
    vander
    var
    vdot
    vectorize
    vsplit
    vstack
    where
    zeros
    zeros_like

jax.numpy.fft
-------------

.. automodule:: jax.numpy.fft

.. autosummary::
  :toctree: _autosummary

  fft
  fft2
  fftfreq
  fftn
  fftshift
  hfft
  ifft
  ifft2
  ifftn
  ifftshift
  ihfft
  irfft
  irfft2
  irfftn
  rfft
  rfft2
  rfftfreq
  rfftn
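
The functions in :mod:`jax.numpy.fft` take the same arguments as their
:mod:`numpy.fft` counterparts. As a minimal sketch with an arbitrary real signal,
the example below round-trips the signal through ``rfft`` and ``irfft`` and checks
the spectrum against NumPy. Passing ``n`` to ``irfft`` restores the original
length, because ``rfft`` of a length-``n`` signal keeps only ``n // 2 + 1``
frequency bins. The comparisons use a loose tolerance because JAX computes in
32-bit precision by default.

.. code-block:: python

   import numpy as np
   import jax.numpy as jnp

   x = jnp.array([1.0, 2.0, 0.0, -1.0, 0.5, 3.0])

   spectrum = jnp.fft.rfft(x)                      # 6 // 2 + 1 = 4 complex bins
   x_back = jnp.fft.irfft(spectrum, n=x.shape[0])  # n restores the original length

   np_spectrum = np.fft.rfft(np.asarray(x))
   print(spectrum.shape, np.allclose(spectrum, np_spectrum, atol=1e-4))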

jax.numpy.linalg
----------------

.. automodule:: jax.numpy.linalg

.. autosummary::
  :toctree: _autosummary

  cholesky
  cond
  det
  eig
  eigh
  eigvals
  eigvalsh
  inv
  lstsq
  matrix_power
  matrix_rank
  multi_dot
  norm
  pinv
  qr
  slogdet
  solve
  svd
  tensorinv
  tensorsolve
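
The functions in :mod:`jax.numpy.linalg` mirror :mod:`numpy.linalg`. As a minimal
sketch, the example below solves a small linear system with
:func:`jax.numpy.linalg.solve` instead of forming an explicit inverse, and computes
the eigenvalues of the same symmetric matrix with
:func:`jax.numpy.linalg.eigvalsh`. The matrix and right-hand side are arbitrary;
the expected solution ``[2, 3]`` can be checked by hand (``3*2 + 3 = 9`` and
``2 + 2*3 = 8``).

.. code-block:: python

   import numpy as np
   import jax.numpy as jnp

   a = jnp.array([[3.0, 1.0], [1.0, 2.0]])
   b = jnp.array([9.0, 8.0])

   # Solve a @ x = b directly instead of computing inv(a) @ b.
   x = jnp.linalg.solve(a, b)
   residual = jnp.linalg.norm(a @ x - b)

   # Eigenvalues of the symmetric matrix a, in ascending order as in NumPy.
   w = jnp.linalg.eigvalsh(a)
   print(x, residual, w)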

JAX DeviceArray
---------------
The JAX :class:`~jax.numpy.DeviceArray` is the core array object in JAX: you can
think of it as the equivalent of a :class:`numpy.ndarray` backed by a memory buffer
on a single device. As with :class:`numpy.ndarray`, most users will not need to
instantiate :class:`DeviceArray` objects manually; instead, they will create them via
:mod:`jax.numpy` functions such as :func:`~jax.numpy.array`, :func:`~jax.numpy.arange`,
:func:`~jax.numpy.linspace`, and the others listed above.
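
A minimal sketch of that workflow: arrays are produced by :mod:`jax.numpy`
functions rather than by a constructor, and :func:`numpy.asarray` copies the values
back into a host :class:`numpy.ndarray`. It assumes a recent JAX release, in which
:class:`DeviceArray` has been superseded by the unified ``jax.Array`` type, so the
type check below uses ``jax.Array`` rather than :class:`DeviceArray`.

.. code-block:: python

   import numpy as np
   import jax
   import jax.numpy as jnp

   # Created by jax.numpy functions, not by instantiating an array class.
   x = jnp.arange(4) * 2
   print(type(x), x.dtype)
   is_jax_array = isinstance(x, jax.Array)  # jax.Array in recent releases

   # Copy the values back into an ordinary NumPy array on the host.
   host = np.asarray(x)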

.. autoclass:: jax.numpy.DeviceArray

.. autoclass:: jaxlib.xla_extension.DeviceArrayBase

.. autoclass:: jaxlib.xla_extension.DeviceArray
   :members:
   :inherited-members: