Revision f309b82a20358a6a6560b87f8325f73ab06c3123 authored by Jevin Jiang on 05 August 2024, 18:06:42 UTC, committed by jax authors on 06 August 2024, 06:17:11 UTC
We also emulate shuffled store using (store + shuffled load + store) for previous generations.

PiperOrigin-RevId: 659612665
1 parent f255fb7
Raw File
array_api_test.py
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Smoketest for JAX's array API.

The full test suite for the array API is run via the array-api-tests CI;
this is just a minimal smoke test to catch issues early.
"""
from __future__ import annotations

from types import ModuleType

from absl.testing import absltest, parameterized
import jax
import jax.numpy as jnp
from jax._src import config, test_util as jtu
from jax._src.dtypes import _default_types, canonicalize_dtype

# The namespace under test. jax.numpy itself is expected to act as the array
# API namespace; ArrayAPISmokeTest.test_deprecated_import checks that the legacy
# jax.experimental.array_api import resolves to this same object.
ARRAY_API_NAMESPACE = jnp

# NOTE(review): presumably parses absl command-line flags into JAX's config at
# import time, so that options such as enable_x64 (read in
# ArrayAPIInspectionUtilsTest.setUp) take effect. Confirm against jax._src.config.
config.parse_flags_with_absl()

# Public names that must be exposed by the main array API namespace.
# ArrayAPISmokeTest.test_main_namespace asserts that this set is a *subset* of
# names(ARRAY_API_NAMESPACE), so extra public names in jax.numpy are allowed.
# NOTE(review): this list appears to track a specific revision of the array API
# standard (the clip TODO below mentions "Array API 2023"). Confirm and update
# it when the targeted revision changes.
MAIN_NAMESPACE = {
  'abs',
  'acos',
  'acosh',
  'add',
  'all',
  'any',
  'arange',
  'argmax',
  'argmin',
  'argsort',
  'asarray',
  'asin',
  'asinh',
  'astype',
  'atan',
  'atan2',
  'atanh',
  'bitwise_and',
  'bitwise_invert',
  'bitwise_left_shift',
  'bitwise_or',
  'bitwise_right_shift',
  'bitwise_xor',
  'bool',
  'broadcast_arrays',
  'broadcast_to',
  'can_cast',
  'ceil',
  'clip',
  'complex128',
  'complex64',
  'concat',
  'conj',
  'copysign',
  'cos',
  'cosh',
  'cumulative_sum',
  'divide',
  'e',
  'empty',
  'empty_like',
  'equal',
  'exp',
  'expand_dims',
  'expm1',
  'eye',
  'fft',  # Sub-namespace; its contents are checked against FFT_NAMESPACE.
  'finfo',
  'flip',
  'float32',
  'float64',
  'floor',
  'floor_divide',
  'from_dlpack',
  'full',
  'full_like',
  'greater',
  'greater_equal',
  'hypot',
  'iinfo',
  'imag',
  'inf',
  'int16',
  'int32',
  'int64',
  'int8',
  'isdtype',
  'isfinite',
  'isinf',
  'isnan',
  'less',
  'less_equal',
  'linalg',  # Sub-namespace; its contents are checked against LINALG_NAMESPACE.
  'linspace',
  'log',
  'log10',
  'log1p',
  'log2',
  'logaddexp',
  'logical_and',
  'logical_not',
  'logical_or',
  'logical_xor',
  'matmul',
  'matrix_transpose',
  'max',
  'maximum',
  'mean',
  'meshgrid',
  'min',
  'minimum',
  'moveaxis',
  'multiply',
  'nan',
  'negative',
  'newaxis',
  'nonzero',
  'not_equal',
  'ones',
  'ones_like',
  'permute_dims',
  'pi',
  'positive',
  'pow',
  'prod',
  'real',
  'remainder',
  'repeat',
  'reshape',
  'result_type',
  'roll',
  'round',
  'searchsorted',
  'sign',
  'signbit',
  'sin',
  'sinh',
  'sort',
  'sqrt',
  'square',
  'squeeze',
  'stack',
  'std',
  'subtract',
  'sum',
  'take',
  'tan',
  'tanh',
  'tensordot',
  'tile',
  'tril',
  'triu',
  'trunc',
  'uint16',
  'uint32',
  'uint64',
  'uint8',
  'unique_all',
  'unique_counts',
  'unique_inverse',
  'unique_values',
  'unstack',
  'var',
  'vecdot',
  'where',
  'zeros',
  'zeros_like',
}

# Public names that must be exposed by the `linalg` sub-namespace.
# ArrayAPISmokeTest.test_linalg_namespace asserts that this set is a subset of
# names(ARRAY_API_NAMESPACE.linalg). Some entries (e.g. 'matmul', 'tensordot',
# 'vecdot', 'matrix_transpose') also appear in MAIN_NAMESPACE.
LINALG_NAMESPACE = {
  'cholesky',
  'cross',
  'det',
  'diagonal',
  'eigh',
  'eigvalsh',
  'inv',
  'matmul',
  'matrix_norm',
  'matrix_power',
  'matrix_rank',
  'matrix_transpose',
  'outer',
  'pinv',
  'qr',
  'slogdet',
  'solve',
  'svd',
  'svdvals',
  'tensordot',
  'trace',
  'vecdot',
  'vector_norm',
}

# Public names that must be exposed by the `fft` sub-namespace.
# ArrayAPISmokeTest.test_fft_namespace asserts that this set is a subset of
# names(ARRAY_API_NAMESPACE.fft).
FFT_NAMESPACE = {
  'fft',
  'fftfreq',
  'fftn',
  'fftshift',
  'hfft',
  'ifft',
  'ifftn',
  'ifftshift',
  'ihfft',
  'irfft',
  'irfftn',
  'rfft',
  'rfftfreq',
  'rfftn',
}


def names(module: ModuleType) -> set[str]:
  """Return the set of public attribute names of ``module``.

  Any name returned by ``dir(module)`` that begins with an underscore is
  treated as private and excluded.
  """
  public = set()
  for attr in dir(module):
    if attr.startswith('_'):
      continue
    public.add(attr)
  return public


class ArrayAPISmokeTest(absltest.TestCase):
  """Quick checks that the array API surface is exposed by the namespace."""

  def _assert_exports(self, expected, module):
    """Fail unless every name in ``expected`` is a public name of ``module``."""
    self.assertContainsSubset(expected, names(module))

  def test_main_namespace(self):
    self._assert_exports(MAIN_NAMESPACE, ARRAY_API_NAMESPACE)

  def test_linalg_namespace(self):
    self._assert_exports(LINALG_NAMESPACE, ARRAY_API_NAMESPACE.linalg)

  def test_fft_namespace(self):
    self._assert_exports(FFT_NAMESPACE, ARRAY_API_NAMESPACE.fft)

  def test_array_namespace_method(self):
    arr = ARRAY_API_NAMESPACE.arange(20)
    self.assertIsInstance(arr, jax.Array)
    # Arrays must report the namespace they belong to.
    namespace = arr.__array_namespace__()
    self.assertIs(namespace, ARRAY_API_NAMESPACE)

  def test_deprecated_import(self):
    expected_msg = "jax.experimental.array_api import is no longer required"
    with self.assertWarnsRegex(DeprecationWarning, expected_msg):
      import jax.experimental.array_api as legacy_namespace
    # The legacy import path must resolve to the very same namespace object.
    self.assertIs(legacy_namespace, ARRAY_API_NAMESPACE)


class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):
  """Tests for the ``__array_namespace_info__()`` inspection utilities.

  Uses unittest assertion methods rather than bare ``assert`` so the checks
  still run (and report useful diffs) when Python is invoked with ``-O``.
  """

  # Evaluated once when the class body executes; shared by all test methods.
  info = ARRAY_API_NAMESPACE.__array_namespace_info__()

  def setUp(self):
    super().setUp()
    # Expected dtype dicts per Array API kind: dtype name -> jnp.dtype.
    self._boolean = self.build_dtype_dict(["bool"])
    self._signed = self.build_dtype_dict(["int8", "int16", "int32"])
    self._unsigned = self.build_dtype_dict(["uint8", "uint16", "uint32"])
    self._floating = self.build_dtype_dict(["float32"])
    self._complex = self.build_dtype_dict(["complex64"])
    # 64-bit dtypes are only expected when x64 mode is enabled.
    if config.enable_x64.value:
      self._signed["int64"] = jnp.dtype("int64")
      self._unsigned["uint64"] = jnp.dtype("uint64")
      self._floating["float64"] = jnp.dtype("float64")
      self._complex["complex128"] = jnp.dtype("complex128")
    self._integral = self._signed | self._unsigned
    self._numeric = (
      self._signed | self._unsigned | self._floating | self._complex
    )

  def build_dtype_dict(self, dtypes):
    """Return a dict mapping each name in ``dtypes`` to ``jnp.dtype(name)``."""
    return {name: jnp.dtype(name) for name in dtypes}

  def test_capabilities_info(self):
    capabilities = self.info.capabilities()
    self.assertTrue(capabilities["boolean indexing"])
    self.assertFalse(capabilities["data-dependent shapes"])

  def test_default_device_info(self):
    self.assertIsNone(self.info.default_device())

  def test_devices_info(self):
    self.assertEqual(self.info.devices(), jax.devices())

  def test_default_dtypes_info(self):
    # Maps each Array API default-dtype category to the dtype kind character
    # used as a key into `_default_types`.
    kind_by_category = {
      "real floating": "f",
      "complex floating": "c",
      "integral": "i",
      "indexing": "i",
    }
    # Index (rather than `.get`) so a missing kind fails loudly with KeyError
    # instead of passing None on to canonicalize_dtype.
    target_dict = {
      category: canonicalize_dtype(_default_types[kind])
      for category, kind in kind_by_category.items()
    }
    self.assertEqual(self.info.default_dtypes(), target_dict)

  @parameterized.parameters(
    "bool", "signed integer", "unsigned integer", "real floating",
    "complex floating", "integral", "numeric", None,
    (("real floating", "complex floating"),),
    (("integral", "signed integer"),),
    (("integral", "bool"),),
  )
  def test_dtypes_info(self, kind):
    """Check ``info.dtypes(kind=...)`` for single kinds, None, and tuples.

    ``kind`` is a kind name, ``None`` (meaning numeric plus bool), or a tuple
    of kind names whose dtype dicts are merged.
    """
    info_dict = self.info.dtypes(kind=kind)
    control = {
      "bool": self._boolean,
      "signed integer": self._signed,
      "unsigned integer": self._unsigned,
      "real floating": self._floating,
      "complex floating": self._complex,
      "integral": self._integral,
      "numeric": self._numeric,
    }
    if kind is None:
      target_dict = control["numeric"] | self._boolean
    elif isinstance(kind, tuple):
      # Build a fresh dict so the shared per-kind dicts are never mutated.
      target_dict = {}
      for _kind in kind:
        target_dict |= control[_kind]
    else:
      target_dict = control[kind]
    self.assertEqual(info_dict, target_dict)

class ArrayAPIErrors(absltest.TestCase):
  """Checks that array API functions raise wherever the standard requires it."""

  # TODO(micky774): Remove when jnp.clip deprecation is completed
  # (began 2024-4-2) and default behavior is Array API 2023 compliant
  def test_clip_complex(self):
    xp = ARRAY_API_NAMESPACE
    complex_msg = "Complex values have no ordering and cannot be clipped"
    complex_arr = xp.arange(5, dtype=xp.complex64)
    int_arr = xp.arange(5, dtype=xp.int32)
    complex_bound = -1+5j
    # Each case is (operand, keyword bounds); every one must be rejected,
    # whether the complex value is the operand or one of the bounds.
    cases = [
      (complex_arr, {}),
      (complex_arr, {"max": complex_arr}),
      (int_arr, {"min": complex_bound}),
      (int_arr, {"max": complex_bound}),
    ]
    for operand, bounds in cases:
      with self.assertRaisesRegex(ValueError, complex_msg):
        xp.clip(operand, **bounds)


if __name__ == "__main__":
  # Run all test cases through absl's runner when executed as a script.
  absltest.main()
back to top