Raw File
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
from functools import partial
import itertools
import operator
import random
import unittest
from typing import NamedTuple, Tuple

from absl.testing import absltest
from absl.testing import parameterized

import jax
import jax.random
from jax import config
from jax import dtypes
from jax.experimental import sparse
from jax.experimental.sparse.bcoo import BCOOInfo
from jax import lax
from jax._src.lib import cusparse
from jax._src.lib import xla_bridge
from jax import jit
from jax import tree_util
from jax import vmap
from jax._src import test_util as jtu
from jax._src.lax.lax import remaining, DotDimensionNumbers
from jax import xla
import jax.numpy as jnp
from jax.util import split_list
import numpy as np
import scipy.sparse
config.parse_flags_with_absl()
FLAGS = config.FLAGS

MATMUL_TOL = {
  np.float32: 1E-5,
  np.float64: 1E-10,
  np.complex64: 1e-5,
  np.complex128: 1E-10,
}

class BcooDotGeneralProperties(NamedTuple):
  lhs_shape: Tuple[int]
  rhs_shape: Tuple[int]
  dtype: np.dtype
  n_batch: int
  n_dense: int
  dimension_numbers: DotDimensionNumbers

  def testcase_name(self):
    return "_{}_{}_nbatch={}_ndense={}_dimension_numbers={}".format(
      jtu.format_shape_dtype_string(self.lhs_shape, self.dtype),
      jtu.format_shape_dtype_string(self.rhs_shape, self.dtype),
      self.n_batch, self.n_dense, self.dimension_numbers)


def _iter_subsets(s):
  return itertools.chain.from_iterable(itertools.combinations(s, n) for n in range(len(s) + 1))

def _generate_bcoo_dot_general_properties(shapes, dtypes) -> BcooDotGeneralProperties:
  """Generator of properties for bcoo_dot_general tests."""
  rng = random.Random(0)

  for shape in shapes:
    for n_batch in range(len(shape) + 1):
      for n_dense in range(len(shape) + 1 - n_batch):
        n_sparse = len(shape) - n_batch - n_dense
        subsets = split_list(range(len(shape)), [n_batch, n_sparse])
        for batch_dims in _iter_subsets(range(n_batch)):
          for contracting_dims in _iter_subsets(remaining(range(n_batch + n_sparse), batch_dims)):
            # We want coverage of permutations & dtypes without generating hundreds of thousands
            # of test cases; we do this by deterministic pseudo-random sampling instead of iterating.
            rhs_permute = rng.sample(range(len(shape)), len(shape))
            lhs_permute = list(itertools.chain.from_iterable(
              rng.sample(subset, len(subset)) for subset in subsets))
            yield BcooDotGeneralProperties(
              lhs_shape=tuple(shape[p] for p in lhs_permute),
              rhs_shape=tuple(shape[p] for p in rhs_permute),
              dtype=rng.choice(dtypes),
              n_batch=n_batch,
              n_dense=n_dense,
              dimension_numbers=(
                ([lhs_permute.index(d) for d in contracting_dims], [rhs_permute.index(d) for d in contracting_dims]),
                ([lhs_permute.index(d) for d in batch_dims], [rhs_permute.index(d) for d in batch_dims])
              ),
            )


all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex


def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default):
  def _rand_sparse(shape, dtype, nse=nse):
    rand = rand_method(rng)
    size = np.prod(shape).astype(int)
    if 0 <= nse < 1:
      nse = nse * size
    nse = min(size, int(nse))
    M = rand(shape, dtype)
    indices = rng.choice(size, size - nse, replace=False)
    M.flat[indices] = 0
    return post(M)
  return _rand_sparse


class cuSparseTest(jtu.JaxTestCase):
  def gpu_dense_conversion_warning_context(self, dtype):
    if jtu.device_under_test() == "gpu" and np.issubdtype(dtype, np.integer):
      return self.assertWarns(sparse.CuSparseEfficiencyWarning)
    return contextlib.nullcontext()

  def gpu_matmul_warning_context(self, dtype):
    if jtu.device_under_test() == "gpu" and dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
      return self.assertWarns(sparse.CuSparseEfficiencyWarning)
    return contextlib.nullcontext()

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in all_dtypes))
  def test_csr_todense(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=scipy.sparse.csr_matrix)
    M = rng(shape, dtype)

    args = (M.data, M.indices, M.indptr)
    todense = lambda *args: sparse.csr_todense(*args, shape=M.shape)

    self.assertArraysEqual(M.toarray(), todense(*args))
    with self.gpu_dense_conversion_warning_context(dtype):
      self.assertArraysEqual(M.toarray(), jit(todense)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_todense_ad(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    data, indices, indptr = sparse.csr_fromdense(M, nse=(M != 0).sum())
    row, col = sparse.util._csr_to_coo(indices, indptr)
    f = lambda data: sparse.csr_todense(data, indices, indptr, shape=M.shape)

    # Forward-mode
    primals, tangents = jax.jvp(f, [data], [jnp.ones_like(data)])
    self.assertArraysEqual(primals, f(data))
    self.assertArraysEqual(tangents, jnp.zeros_like(M).at[row, col].set(1))

    # Reverse-mode
    primals, vjp_fun = jax.vjp(f, data)
    data_out, = vjp_fun(primals)
    self.assertArraysEqual(primals, f(data))
    self.assertArraysEqual(data_out, data)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_fromdense_ad(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    nse = (M != 0).sum()
    f = lambda M: sparse.csr_fromdense(M, nse=nse)

    # Forward-mode
    primals, tangents = jax.jvp(f, [M], [jnp.ones_like(M)])
    self.assertArraysEqual(primals[0], f(M)[0])
    self.assertArraysEqual(primals[1], f(M)[1])
    self.assertArraysEqual(primals[2], f(M)[2])
    self.assertArraysEqual(tangents[0], jnp.ones(nse, dtype=dtype))
    self.assertEqual(tangents[1].dtype, dtypes.float0)
    self.assertEqual(tangents[2].dtype, dtypes.float0)

    # Reverse-mode
    primals, vjp_fun = jax.vjp(f, M)
    M_out, = vjp_fun(primals)
    self.assertArraysEqual(primals[0], f(M)[0])
    self.assertArraysEqual(primals[1], f(M)[1])
    self.assertArraysEqual(primals[2], f(M)[2])
    self.assertArraysEqual(M_out, M)

  @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        jtu.format_shape_dtype_string(bshape, dtype)),
       "shape": shape, "dtype": dtype, "bshape": bshape}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for bshape in [shape[-1:] + s for s in [(), (1,), (3,)]]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_matmul_ad(self, shape, dtype, bshape):
    csr_matmul = sparse.csr_matvec if len(bshape) == 1 else sparse.csr_matmat
    tol = {np.float32: 1E-5, np.float64: 1E-12, np.complex64: 1E-5, np.complex128: 1E-12}

    rng = rand_sparse(self.rng(), post=jnp.array)
    rng_b = jtu.rand_default(self.rng())

    M = rng(shape, dtype)
    data, indices, indptr = sparse.csr_fromdense(M, nse=(M != 0).sum())
    x = rng_b(bshape, dtype)
    xdot = rng_b(bshape, dtype)

    # Forward-mode with respect to the vector
    f_dense = lambda x: M @ x
    f_sparse = lambda x: csr_matmul(data, indices, indptr, x, shape=M.shape)
    v_sparse, t_sparse = jax.jvp(f_sparse, [x], [xdot])
    v_dense, t_dense = jax.jvp(f_dense, [x], [xdot])
    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse-mode with respect to the vector
    primals_dense, vjp_dense = jax.vjp(f_dense, x)
    primals_sparse, vjp_sparse = jax.vjp(f_sparse, x)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)

    # Forward-mode with respect to nonzero elements of the matrix
    f_sparse = lambda data: csr_matmul(data, indices, indptr, x, shape=M.shape)
    f_dense = lambda data: sparse.csr_todense(data, indices, indptr, shape=M.shape) @ x
    data = rng((len(data),), data.dtype)
    data_dot = rng((len(data),), data.dtype)
    v_sparse, t_sparse = jax.jvp(f_sparse, [data], [data_dot])
    v_dense, t_dense = jax.jvp(f_dense, [data], [data_dot])

    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse-mode with respect to nonzero elements of the matrix
    primals_dense, vjp_dense = jax.vjp(f_dense, data)
    primals_sparse, vjp_sparse = jax.vjp(f_sparse, data)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in all_dtypes))
  def test_csr_fromdense(self, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_csr = scipy.sparse.csr_matrix(M)

    nse = M_csr.nnz
    index_dtype = jnp.int32
    fromdense = lambda M: sparse.csr_fromdense(M, nse=nse, index_dtype=jnp.int32)

    data, indices, indptr = fromdense(M)
    self.assertArraysEqual(data, M_csr.data.astype(dtype))
    self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
    self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

    with self.gpu_dense_conversion_warning_context(dtype):
      data, indices, indptr = jit(fromdense)(M)
    self.assertArraysEqual(data, M_csr.data.astype(dtype))
    self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
    self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in all_dtypes
      for transpose in [True, False]))
  def test_csr_matvec(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    v_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=scipy.sparse.csr_matrix)
    M = rng(shape, dtype)
    v = v_rng(op(M).shape[1], dtype)

    args = (M.data, M.indices, M.indptr, v)
    matvec = lambda *args: sparse.csr_matvec(*args, shape=M.shape, transpose=transpose)

    self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
    with self.gpu_matmul_warning_context(dtype):
      self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in all_dtypes
      for transpose in [True, False]))
  def test_csr_matmat(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=scipy.sparse.csr_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.indices, M.indptr, B)
    matmat = lambda *args: sparse.csr_matmat(*args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
    with self.gpu_matmul_warning_context(dtype):
      self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in all_dtypes))
  def test_coo_todense(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=scipy.sparse.coo_matrix)
    M = rng(shape, dtype)

    args = (M.data, M.row, M.col)
    todense = lambda *args: sparse.coo_todense(*args, shape=M.shape)

    self.assertArraysEqual(M.toarray(), todense(*args))
    with self.gpu_dense_conversion_warning_context(dtype):
      self.assertArraysEqual(M.toarray(), jit(todense)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in all_dtypes))
  def test_coo_fromdense(self, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_coo = scipy.sparse.coo_matrix(M)

    nse = M_coo.nnz
    index_dtype = jnp.int32
    fromdense = lambda M: sparse.coo_fromdense(M, nse=nse, index_dtype=jnp.int32)

    data, row, col = fromdense(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

    with self.gpu_dense_conversion_warning_context(dtype):
      data, indices, indptr = jit(fromdense)(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in all_dtypes
      for transpose in [True, False]))
  def test_coo_matvec(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    v_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=scipy.sparse.coo_matrix)
    M = rng(shape, dtype)
    v = v_rng(op(M).shape[1], dtype)

    args = (M.data, M.row, M.col, v)
    matvec = lambda *args: sparse.coo_matvec(*args, shape=M.shape, transpose=transpose)

    self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
    with self.gpu_matmul_warning_context(dtype):
      self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in all_dtypes
      for transpose in [True, False]))
  def test_coo_matmat(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=scipy.sparse.coo_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.row, M.col, B)
    matmat = lambda *args: sparse.coo_matmat(*args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
    with self.gpu_matmul_warning_context(dtype):
      self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

  def test_coo_matmat_layout(self):
    # Regression test for https://github.com/google/jax/issues/7533
    d = jnp.array([1.0, 2.0, 3.0, 4.0])
    i = jnp.array([0, 0, 1, 2])
    j = jnp.array([0, 2, 0, 0])
    shape = (3, 3)

    x = jnp.arange(9).reshape(3, 3).astype(d.dtype)

    def f(x):
      return sparse.coo_matmat(d, i, j, x.T, shape=shape)

    result = f(x)
    result_jit = jit(f)(x)

    self.assertAllClose(result, result_jit)

  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  def test_gpu_translation_rule(self):
    version = xla_bridge.get_backend().platform_version
    cuda_version = None if version == "<unknown>" else int(version.split()[-1])
    if cuda_version is None or cuda_version < 11000:
      self.assertFalse(cusparse and cusparse.is_supported)
      self.assertNotIn(sparse.csr_todense_p,
                       xla._backend_specific_translations["gpu"])
    else:
      self.assertTrue(cusparse and cusparse.is_supported)
      self.assertIn(sparse.csr_todense_p,
                    xla._backend_specific_translations["gpu"])

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
         jtu.format_shape_dtype_string(shape, dtype), mat_type),
       "shape": shape, "dtype": dtype, "mat_type": mat_type}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for mat_type in ['csr', 'coo']))
  def test_extra_nse(self, shape, dtype, mat_type):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    nse = (M != 0).sum() + 5
    fromdense = getattr(sparse, f"{mat_type}_fromdense")
    todense = getattr(sparse, f"{mat_type}_todense")
    args = fromdense(M, nse=nse, index_dtype=jnp.int32)
    M_out = todense(*args, shape=M.shape)
    self.assertArraysEqual(M, M_out)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_todense_ad(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    data, row, col = sparse.coo_fromdense(M, nse=(M != 0).sum())
    f = lambda data: sparse.coo_todense(data, row, col, shape=M.shape)

    # Forward-mode
    primals, tangents = jax.jvp(f, [data], [jnp.ones_like(data)])
    self.assertArraysEqual(primals, f(data))
    self.assertArraysEqual(tangents, jnp.zeros_like(M).at[row, col].set(1))

    # Reverse-mode
    primals, vjp_fun = jax.vjp(f, data)
    data_out, = vjp_fun(primals)
    self.assertArraysEqual(primals, f(data))
    self.assertArraysEqual(data_out, data)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_fromdense_ad(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    nse = (M != 0).sum()
    f = lambda M: sparse.coo_fromdense(M, nse=nse)

    # Forward-mode
    primals, tangents = jax.jvp(f, [M], [jnp.ones_like(M)])
    self.assertArraysEqual(primals[0], f(M)[0])
    self.assertArraysEqual(primals[1], f(M)[1])
    self.assertArraysEqual(primals[2], f(M)[2])
    self.assertArraysEqual(tangents[0], jnp.ones(nse, dtype=dtype))
    self.assertEqual(tangents[1].dtype, dtypes.float0)
    self.assertEqual(tangents[2].dtype, dtypes.float0)

    # Reverse-mode
    primals, vjp_fun = jax.vjp(f, M)
    M_out, = vjp_fun(primals)
    self.assertArraysEqual(primals[0], f(M)[0])
    self.assertArraysEqual(primals[1], f(M)[1])
    self.assertArraysEqual(primals[2], f(M)[2])
    self.assertArraysEqual(M_out, M)

  @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        jtu.format_shape_dtype_string(bshape, dtype)),
       "shape": shape, "dtype": dtype, "bshape": bshape}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for bshape in [shape[-1:] + s for s in [(), (1,), (3,)]]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_matmul_ad(self, shape, dtype, bshape):
    coo_matmul = sparse.coo_matvec if len(bshape) == 1 else sparse.coo_matmat
    tol = {np.float32: 1E-5, np.float64: 1E-12, np.complex64: 1E-5, np.complex128: 1E-12}

    rng = rand_sparse(self.rng(), post=jnp.array)
    rng_b = jtu.rand_default(self.rng())

    M = rng(shape, dtype)
    data, row, col = sparse.coo_fromdense(M, nse=(M != 0).sum())
    x = rng_b(bshape, dtype)
    xdot = rng_b(bshape, dtype)

    # Forward-mode with respect to the vector
    f_dense = lambda x: M @ x
    f_sparse = lambda x: coo_matmul(data, row, col, x, shape=M.shape)
    v_sparse, t_sparse = jax.jvp(f_sparse, [x], [xdot])
    v_dense, t_dense = jax.jvp(f_dense, [x], [xdot])
    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse-mode with respect to the vector
    primals_dense, vjp_dense = jax.vjp(f_dense, x)
    primals_sparse, vjp_sparse = jax.vjp(f_sparse, x)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)

    # Forward-mode with respect to nonzero elements of the matrix
    f_sparse = lambda data: coo_matmul(data, row, col, x, shape=M.shape)
    f_dense = lambda data: sparse.coo_todense(data, row, col, shape=M.shape) @ x
    data = rng((len(data),), data.dtype)
    data_dot = rng((len(data),), data.dtype)
    v_sparse, t_sparse = jax.jvp(f_sparse, [data], [data_dot])
    v_dense, t_dense = jax.jvp(f_dense, [data], [data_dot])

    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse-mode with respect to nonzero elements of the matrix
    primals_dense, vjp_dense = jax.vjp(f_dense, data)
    primals_sparse, vjp_sparse = jax.vjp(f_sparse, data)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)

class BCOOTest(jtu.JaxTestCase):
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in all_dtypes
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_empty(self, shape, dtype, n_batch, n_dense):
    M = sparse.empty(shape, dtype=dtype, n_batch=n_batch, n_dense=n_dense)
    self.assertIsInstance(M, sparse.BCOO)
    self.assertEqual(M.nse, 0)
    self.assertEqual(M.n_batch, n_batch)
    self.assertEqual(M.n_dense, n_dense)
    self.assertEqual(M.dtype, dtype)
    self.assertArraysEqual(M.todense(), jnp.empty(shape, dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in all_dtypes
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_dense_round_trip(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    n_sparse = M.ndim - n_batch - n_dense
    nse = int(sparse.bcoo._bcoo_nse(M, n_batch=n_batch, n_dense=n_dense))
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
    data_jit, indices_jit = jit(partial(sparse.bcoo_fromdense, nse=nse, n_batch=n_batch, n_dense=n_dense))(M)
    self.assertArraysEqual(data, data_jit)
    self.assertArraysEqual(indices, indices_jit)

    assert data.dtype == dtype
    assert data.shape == shape[:n_batch] + (nse,) + shape[n_batch + n_sparse:]
    assert indices.dtype == jnp.int32  # TODO: test passing this arg
    assert indices.shape == shape[:n_batch] + (nse, n_sparse)

    todense = partial(sparse.bcoo_todense, spinfo=BCOOInfo(shape))
    self.assertArraysEqual(M, todense(data, indices))
    self.assertArraysEqual(M, jit(todense)(data, indices))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_todense_ad(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    todense = partial(sparse.bcoo_todense, indices=indices, spinfo=BCOOInfo(shape))
    j1 = jax.jacfwd(todense)(data)
    j2 = jax.jacrev(todense)(data)
    hess = jax.hessian(todense)(data)
    self.assertArraysAllClose(j1, j2)
    self.assertEqual(j1.shape, M.shape + data.shape)
    self.assertEqual(hess.shape, M.shape + 2 * data.shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_fromdense_ad(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    nse = int(sparse.bcoo._bcoo_nse(M, n_batch=n_batch, n_dense=n_dense))

    def fromdense(M):
      return sparse.bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense)[0]
    data = fromdense(M)

    j1 = jax.jacfwd(fromdense)(M)
    j2 = jax.jacrev(fromdense)(M)
    hess = jax.hessian(fromdense)(M)
    self.assertArraysAllClose(j1, j2)
    self.assertEqual(j1.shape, data.shape + M.shape)
    self.assertEqual(hess.shape, data.shape + 2 * M.shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_dense_round_trip_batched(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    n_sparse = M.ndim - n_batch - n_dense
    nse = int(sparse.bcoo._bcoo_nse(M, n_batch=n_batch, n_dense=n_dense))

    fromdense = partial(sparse.bcoo_fromdense, nse=nse, n_dense=n_dense)
    todense = partial(sparse.bcoo_todense, spinfo=BCOOInfo(shape[n_batch:]))
    for i in range(n_batch):
      fromdense = jax.vmap(fromdense)
      todense = jax.vmap(todense)

    data, indices = fromdense(M)

    assert data.dtype == dtype
    assert data.shape == shape[:n_batch] + (nse,) + shape[n_batch + n_sparse:]
    assert indices.dtype == jnp.int32  # TODO: test passing this arg
    assert indices.shape == shape[:n_batch] + (nse, n_sparse)

    self.assertArraysEqual(M, todense(data, indices))
    self.assertArraysEqual(M, jit(todense)(data, indices))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_extract(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M)
    data2 = sparse.bcoo_extract(indices, M)
    self.assertArraysEqual(data, data2)
    data3 = jit(sparse.bcoo_extract)(indices, M)
    self.assertArraysEqual(data, data3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_extract_ad(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    extract = partial(sparse.bcoo_extract, indices)
    j1 = jax.jacfwd(extract)(M)
    j2 = jax.jacrev(extract)(M)
    hess = jax.hessian(extract)(M)
    self.assertArraysAllClose(j1, j2)
    self.assertEqual(j1.shape, data.shape + M.shape)
    self.assertEqual(hess.shape, data.shape + 2 * M.shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_transpose(self, shape, dtype, n_batch, n_dense):
    n_sparse = len(shape) - n_batch - n_dense
    rng = self.rng()
    sprng = rand_sparse(rng)
    M = sprng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    permutation = np.concatenate([
      rng.permutation(range(n_batch)),
      rng.permutation(range(n_batch, n_batch + n_sparse)),
      rng.permutation(range(n_batch + n_sparse, len(shape)))]).astype(int)

    M_T = M.transpose(permutation)
    trans = partial(sparse.bcoo_transpose, spinfo=BCOOInfo(shape), permutation=permutation)
    self.assertArraysEqual(M_T, sparse.bcoo_todense(*trans(data, indices), spinfo=BCOOInfo(M_T.shape)))
    self.assertArraysEqual(M_T, sparse.bcoo_todense(*jit(trans)(data, indices), spinfo=BCOOInfo(M_T.shape)))

    # test batched
    def trans(M):
      return M.transpose([p - n_batch for p in permutation[n_batch:]])
    for _ in range(n_batch):
      trans = jax.vmap(trans)
    Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
    self.assertArraysEqual(trans(M), trans(Msp).todense())

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_transpose_ad(self, shape, dtype, n_batch, n_dense):
    n_sparse = len(shape) - n_batch - n_dense
    rng = self.rng()
    sprng = rand_sparse(self.rng())

    M = sprng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    permutation = np.concatenate([
      rng.permutation(range(n_batch)),
      rng.permutation(range(n_batch, n_batch + n_sparse)),
      rng.permutation(range(n_batch + n_sparse, len(shape)))]).astype(int)

    def f_sparse(data):
      return sparse.bcoo_transpose(data, indices, spinfo=BCOOInfo(shape), permutation=permutation)[0]

    jf_sparse = jax.jacfwd(f_sparse)(data)
    jr_sparse = jax.jacrev(f_sparse)(data)

    tol = {}
    if jtu.device_under_test() == "tpu":
      tol = {np.float32: 5E-3}

    # TODO(jakevdp) also test against dense version?
    self.assertAllClose(jf_sparse, jr_sparse, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(1, len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_todense_partial_batch(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    M1 = sparse.bcoo_todense(data, indices[:1], spinfo=BCOOInfo(M.shape))
    M2 = sparse.bcoo_todense(data, jnp.stack(shape[0] * [indices[0]]), spinfo=BCOOInfo(M.shape))
    self.assertAllClose(M1, M2)

    M3 = sparse.bcoo_todense(data[:1], indices, spinfo=BCOOInfo(M.shape))
    M4 = sparse.bcoo_todense(jnp.stack(shape[0] * [data[0]]), indices, spinfo=BCOOInfo(M.shape))
    self.assertAllClose(M3, M4)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": props.testcase_name(), "props": props}
      for props in _generate_bcoo_dot_general_properties(
        shapes=[(5,), (2, 3), (2, 3, 4), (2, 3, 4, 4)],
        dtypes=jtu.dtypes.floating + jtu.dtypes.complex,
      )))
  def test_bcoo_dot_general(self, props: BcooDotGeneralProperties):
    rng = jtu.rand_small(self.rng())
    rng_sparse = rand_sparse(self.rng())

    def args_maker():
      lhs = rng_sparse(props.lhs_shape, props.dtype)
      rhs = rng(props.rhs_shape, props.dtype)
      data, indices = sparse.bcoo_fromdense(lhs, n_batch=props.n_batch, n_dense=props.n_dense)
      return data, indices, lhs, rhs

    def f_dense(data, indices, lhs, rhs):
      return lax.dot_general(lhs, rhs, dimension_numbers=props.dimension_numbers)

    def f_sparse(data, indices, lhs, rhs):
      return sparse.bcoo_dot_general(data, indices, rhs, lhs_spinfo=BCOOInfo(lhs.shape),
                                     dimension_numbers=props.dimension_numbers)

    tol = {'float32': 3E-2} if jtu.device_under_test() == 'tpu' else {}
    self._CheckAgainstNumpy(f_dense, f_sparse, args_maker, tol=tol)
    self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker, tol=tol)
    # TODO(jakevdp): In rare cases, this fails python_should_be_executing check. Why?
    # self._CompileAndCheck(f_sparse, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": props.testcase_name(), "props": props}
      for props in _generate_bcoo_dot_general_properties(
        shapes=[(5,), (2, 3), (2, 3, 4), (2, 3, 4, 4)],
        dtypes=jtu.dtypes.floating + jtu.dtypes.complex,
      )))
  def test_bcoo_rdot_general(self, props: BcooDotGeneralProperties):
    rng = jtu.rand_small(self.rng())
    rng_sparse = rand_sparse(self.rng())

    lhs_shape, rhs_shape = props.rhs_shape, props.lhs_shape
    dimension_numbers = tuple(d[::-1] for d in props.dimension_numbers)

    def args_maker():
      lhs = rng_sparse(lhs_shape, props.dtype)
      rhs = rng(rhs_shape, props.dtype)
      data, indices = sparse.bcoo_fromdense(rhs, n_batch=props.n_batch, n_dense=props.n_dense)
      return data, indices, lhs, rhs

    def f_dense(data, indices, lhs, rhs):
      return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)

    def f_sparse(data, indices, lhs, rhs):
      return sparse.bcoo_rdot_general(lhs, data, indices, rhs_spinfo=BCOOInfo(rhs.shape),
                                      dimension_numbers=dimension_numbers)

    tol = {'float32': 3E-2} if jtu.device_under_test() == 'tpu' else {}
    self._CheckAgainstNumpy(f_dense, f_sparse, args_maker, tol=tol)
    self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker, tol=tol)
    # TODO(jakevdp): In rare cases, this fails python_should_be_executing check. Why?
    # self._CompileAndCheck(f_sparse, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers, n_batch, n_dense),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers,
       "n_batch": n_batch, "n_dense": n_dense}
      for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 2, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 2, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1),
      ]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_bcoo_dot_general_partial_batch(self, lhs_shape, rhs_shape, dtype,
                                          dimension_numbers, n_batch, n_dense):
    rng = jtu.rand_small(self.rng())
    rng_sparse = rand_sparse(self.rng())

    X = rng_sparse(lhs_shape, dtype)
    data, indices = sparse.bcoo_fromdense(X, n_batch=n_batch, n_dense=n_dense)
    Y = rng(rhs_shape, dtype)

    def f_dense(X, Y):
      return lax.dot_general(X, Y, dimension_numbers=dimension_numbers)

    def f_sparse(data, indices, Y):
      return sparse.bcoo_dot_general(data, indices, Y, lhs_spinfo=BCOOInfo(X.shape),
                                     dimension_numbers=dimension_numbers)

    for data, indices in itertools.product([data, data[:1]], [indices, indices[:1]]):
      X = sparse.bcoo_todense(data, indices, spinfo=BCOOInfo(X.shape))
      self.assertAllClose(f_dense(X, Y), f_sparse(data, indices, Y))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers, n_batch, n_dense),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers,
       "n_batch": n_batch, "n_dense": n_dense}
      for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [
          ((4, 5), (5, 3), (([1], [0]), ([], [])), 0, 0),
          ((2, 4, 5), (2, 5, 3), (([2], [1]), ([0], [0])), 1, 0),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 0),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 2, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 2, 0),
          # This requires contraction over dense dimensions, which is not yet implemented:
          # ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1),
      ]
      for dtype in jtu.dtypes.floating))
  def test_bcoo_dot_general_ad(self, lhs_shape, rhs_shape, dtype,
                               dimension_numbers, n_batch, n_dense):
    rng = jtu.rand_small(self.rng())
    rng_sparse = rand_sparse(self.rng())

    X = rng_sparse(lhs_shape, dtype)
    data, indices = sparse.bcoo_fromdense(X, n_batch=n_batch, n_dense=n_dense)
    Y = rng(rhs_shape, dtype)

    # gradient with respect to rhs
    def f_dense(Y):
      return lax.dot_general(X, Y, dimension_numbers=dimension_numbers)

    def f_sparse(Y):
      return sparse.bcoo_dot_general(data, indices, Y, lhs_spinfo=BCOOInfo(X.shape),
                                     dimension_numbers=dimension_numbers)

    jf_dense = jax.jacfwd(f_dense)(Y)
    jr_dense = jax.jacrev(f_dense)(Y)
    jf_sparse = jax.jacfwd(f_sparse)(Y)
    jr_sparse = jax.jacrev(f_sparse)(Y)

    tol = {}
    if jtu.device_under_test() == "tpu":
      tol = {np.float32: 5E-3}

    self.assertAllClose(jf_dense, jf_sparse, rtol=tol)
    self.assertAllClose(jr_dense, jr_sparse, rtol=tol)
    self.assertAllClose(jf_sparse, jr_sparse, rtol=tol)

    # gradient with respect to lhs
    def g_dense(X):
      return lax.dot_general(X, Y, dimension_numbers=dimension_numbers)

    def g_sparse(data):
      return sparse.bcoo_dot_general(data, indices, Y, lhs_spinfo=BCOOInfo(X.shape),
                                     dimension_numbers=dimension_numbers)

    jf_dense = jax.jacfwd(g_dense)(X)
    jr_dense = jax.jacrev(g_dense)(X)
    jf_sparse = jax.jacfwd(g_sparse)(data)
    jr_sparse = jax.jacrev(g_sparse)(data)

    tol = {}
    if jtu.device_under_test() == "tpu":
      tol = {np.float32: 5E-3}

    self.assertAllClose(jf_dense, jr_dense, rtol=tol)
    self.assertAllClose(jf_sparse, jr_sparse, rtol=tol)

    # Extract the sparse jacobian from the dense & compare.
    def extract(X):
      return sparse.bcoo_extract(indices, X)
    for i in range(g_dense(X).ndim):
      extract = jax.vmap(extract)
    self.assertAllClose(extract(jf_dense), jf_sparse, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers, n_batch, n_dense),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers,
       "n_batch": n_batch, "n_dense": n_dense}
      for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 0, 0),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 0, 1),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 0, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 1),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 0, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 1, 2),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1),
      ]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_bcoo_dot_general_sampled(self, lhs_shape, rhs_shape, dtype, dimension_numbers, n_batch, n_dense):
    rng = jtu.rand_default(self.rng())
    sprng = rand_sparse(self.rng())
    out_shape = lax.dot_general(
      jnp.zeros(lhs_shape), jnp.zeros(rhs_shape),
      dimension_numbers=dimension_numbers).shape

    args_maker = lambda: [
      rng(lhs_shape, dtype), rng(rhs_shape, dtype),
      sparse.BCOO.fromdense(sprng(out_shape, dtype),
                            n_batch=n_batch, n_dense=n_dense).indices]

    def dense_fun(lhs, rhs, indices):
      AB = lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
      return sparse.bcoo_extract(indices, AB)
    def sparse_fun(lhs, rhs, indices):
      return sparse.bcoo_dot_general_sampled(
                lhs, rhs, indices, dimension_numbers=dimension_numbers)

    tol = {}
    if jtu.device_under_test() == "tpu":
      tol = {np.float32: 5E-3}

    self._CheckAgainstNumpy(dense_fun, sparse_fun, args_maker, tol=tol)
    # TODO: python_should_be_executing check occasionally fails... why?
    # self._CompileAndCheck(sparse_fun, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers, n_batch, n_dense),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers,
       "n_batch": n_batch, "n_dense": n_dense}
      for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 1),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 2, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 1),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 2, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1),
      ]
      for dtype in jtu.dtypes.floating))
  def test_bcoo_dot_general_sampled_ad(self, lhs_shape, rhs_shape, dtype, dimension_numbers, n_batch, n_dense):
    rng = jtu.rand_default(self.rng())
    sprng = rand_sparse(self.rng())
    out_shape = lax.dot_general(
      jnp.zeros(lhs_shape), jnp.zeros(rhs_shape),
      dimension_numbers=dimension_numbers).shape

    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    indices = sparse.BCOO.fromdense(sprng(out_shape, dtype),
                                    n_batch=n_batch, n_dense=n_dense).indices

    def dense_fun(lhs, rhs, indices):
      AB = lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
      return sparse.bcoo_extract(indices, AB)
    def sparse_fun(lhs, rhs, indices):
      return sparse.bcoo_dot_general_sampled(
                lhs, rhs, indices, dimension_numbers=dimension_numbers)

    jf_dense = jax.jacfwd(dense_fun)(lhs, rhs, indices)
    jf_sparse = jax.jacfwd(sparse_fun)(lhs, rhs, indices)
    jr_dense = jax.jacrev(dense_fun)(lhs, rhs, indices)
    jr_sparse = jax.jacrev(sparse_fun)(lhs, rhs, indices)

    tol = {}
    if jtu.device_under_test() == "tpu":
      tol = {np.float32: 5E-3}

    self.assertAllClose(jf_sparse, jf_dense, atol=tol)
    self.assertAllClose(jr_sparse, jr_dense, atol=tol)
    self.assertAllClose(jf_sparse, jr_sparse, atol=tol)

  @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}[n_batch={}]_{}[n_batch={}]_swap={}_dims={}".format(
        jtu.format_shape_dtype_string(lhs_shape, dtype), lhs_n_batch,
        jtu.format_shape_dtype_string(rhs_shape, dtype), rhs_n_batch,
        swap, dimension_numbers),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape,
       "lhs_n_batch": lhs_n_batch, "rhs_n_batch": rhs_n_batch,
       "dimension_numbers": dimension_numbers, "swap": swap, "dtype": dtype}
      for lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dimension_numbers in [
          # (batched) outer products (no contraction)
          ((5,), 0, (6,), 0, (([], []), ([], []))),
          ((3, 5), 0, (2, 4), 0, (([], []), ([], []))),
          ((3, 5), 1, (3, 4), 1, (([], []), ([0], [0]))),
          # (batched) vector-vector products
          ((5,), 0, (5,), 0, (([0], [0]), ([], []))),
          ((7,), 0, (7,), 0, (([0], [0]), ([], []))),
          ((5, 7), 1, (7,), 0, (([1], [0]), ([], []))),
          ((2, 3, 4), 2, (2, 4), 1, (([2], [1]), ([0], [0]))),
          ((2, 3, 4), 2, (2, 4), 1, (([2], [1]), ([], []))),
          ((2, 3, 4), 2, (3, 4), 1, (([2], [1]), ([1], [0]))),
          ((2, 3, 4), 2, (3, 4), 1, (([2], [1]), ([], []))),
          # (batched) matrix-vector products
          ((5, 7), 0, (7,), 0, (([1], [0]), ([], []))),
          ((2, 3, 4), 1, (4,), 0, (([2], [0]), ([], []))),
          ((2, 3, 4), 1, (2, 4), 1, (([2], [1]), ([0], [0]))),
          ((3, 2, 4), 1, (3, 4), 1, (([2], [1]), ([0], [0]))),
          ((2, 3, 4), 0, (2,), 0, (([0], [0]), ([], []))),
          # (batched) matrix-matrix products
          ((5, 7), 0, (7, 3), 0, (([1], [0]), ([], []))),
          ((2, 3, 4), 1, (4, 3), 0, (([2], [0]), ([], []))),
          ((2, 3, 4), 1, (2, 4, 3), 1, (([2], [1]), ([0], [0]))),
          # more general operations
          ((2, 3, 4, 3), 1, (2, 4, 3, 4), 1, (([2, 3], [1, 2]), ([0], [0]))),
          ((2, 3, 4, 3, 1), 2, (3, 2, 3, 4), 2, (([2, 3], [3, 2]), ([0, 1], [1, 0]))),
      ]
      for swap in [True, False]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_bcoo_spdot_general(self, lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dtype, swap, dimension_numbers):
    if swap:
      dimension_numbers = tuple(d[::-1] for d in dimension_numbers)
      lhs_shape, rhs_shape = rhs_shape, lhs_shape
      lhs_n_batch, rhs_n_batch = rhs_n_batch, lhs_n_batch

    lhs_n_sparse = len(lhs_shape) - lhs_n_batch
    rhs_batch = dimension_numbers[1][1]
    lhs_contracting = dimension_numbers[0][0]
    should_error = (rhs_n_batch > len(rhs_batch) and lhs_n_sparse > len(lhs_contracting))

    sprng = rand_sparse(self.rng())
    def args_maker():
      x = sprng(lhs_shape, dtype)
      y = sprng(rhs_shape, dtype)
      xsp = sparse.BCOO.fromdense(x, n_batch=lhs_n_batch)
      ysp = sparse.BCOO.fromdense(y, n_batch=rhs_n_batch)
      return x, y, xsp, ysp

    def f_dense(x, y, xsp, ysp):
      return lax.dot_general(x, y, dimension_numbers=dimension_numbers)

    def f_sparse(x, y, xsp, ysp):
      shape = sparse.bcoo._dot_general_validated_shape(xsp.shape, ysp.shape, dimension_numbers)
      data, indices = sparse.bcoo_spdot_general(xsp.data, xsp.indices, ysp.data, ysp.indices,
                                                lhs_spinfo=xsp._info, rhs_spinfo=ysp._info,
                                                dimension_numbers=dimension_numbers)
      return sparse.bcoo_todense(data, indices, spinfo=BCOOInfo(shape))

    tol = {"complex128": 1E-14}
    if should_error:
      with self.assertRaisesRegex(ValueError, ".*cannot have unused batch dims on rhs with unused sparse dims on lhs."):
        f_sparse(*args_maker())
    else:
      self._CheckAgainstNumpy(f_dense, f_sparse, args_maker, tol=tol)
      self._CheckAgainstNumpy(jit(f_dense), jit(f_sparse), args_maker, tol=tol)
      # TODO(jakevdp): This occasionally fails python_should_be_executing check. Why?
      # self._CompileAndCheck(f_sparse, args_maker)

  def test_bcoo_spdot_general_nse(self):
    # vector-vector product -> nse=1
    x = sparse.BCOO.fromdense(jnp.arange(3))
    self.assertEqual((x @ x).nse, 1)

    # matrix-vector product -> nse matches matrix
    M = sparse.BCOO.fromdense(jnp.arange(6).reshape(2, 3))
    self.assertEqual((M @ x).nse, M.nse)

    # matrix-matrix product -> product of nse
    N = sparse.BCOO.fromdense(jnp.arange(12).reshape(3, 4))
    self.assertEqual((M @ N).nse, M.nse * N.nse)
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}[n_batch={}]_rhs_shape={}[n_batch={}]_dimension_numbers={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype), lhs_n_batch,
               jtu.format_shape_dtype_string(rhs_shape, dtype), rhs_n_batch,
               dimension_numbers),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers,
       "lhs_n_batch": lhs_n_batch, "rhs_n_batch": rhs_n_batch}
      for lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dimension_numbers in [
          ((4, 5), 0, (5,), 0, (([1], [0]), ([], []))),
          ((2, 4, 5), 1, (5,), 0, (([2], [0]), ([], []))),
          ((4, 5), 0, (5, 3), 0, (([1], [0]), ([], []))),
          ((2, 4, 5), 1, (2, 5, 3), 1, (([2], [1]), ([0], [0]))),
      ]
      for dtype in jtu.dtypes.floating))
  def test_bcoo_spdot_general_ad(self, lhs_shape, rhs_shape, dtype,
                                 dimension_numbers, lhs_n_batch, rhs_n_batch):
    rng = rand_sparse(self.rng())

    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)

    lhs_sp = sparse.BCOO.fromdense(lhs, n_batch=lhs_n_batch)
    rhs_sp = sparse.BCOO.fromdense(rhs, n_batch=rhs_n_batch)

    def f_dense(lhs_data, rhs_data):
      lhs = sparse.BCOO((lhs_data, lhs_sp.indices), shape=lhs_sp.shape).todense()
      rhs = sparse.BCOO((rhs_data, rhs_sp.indices), shape=rhs_sp.shape).todense()
      return (lhs @ rhs).sum()

    def f_sparse(lhs_data, rhs_data):
      lhs = sparse.BCOO((lhs_data, lhs_sp.indices), shape=lhs_sp.shape)
      rhs = sparse.BCOO((rhs_data, rhs_sp.indices), shape=rhs_sp.shape)
      return (lhs @ rhs).sum()

    tol = {}
    if jtu.device_under_test() == "tpu":
      tol = {np.float32: 5E-2}

    jf_dense_0 = jax.jacfwd(f_dense, argnums=0)(lhs_sp.data, rhs_sp.data)
    jf_sparse_0 = jax.jacfwd(f_sparse, argnums=0)(lhs_sp.data, rhs_sp.data)
    self.assertAllClose(jf_dense_0, jf_sparse_0, rtol=tol)

    jf_dense_1 = jax.jacfwd(f_dense, argnums=1)(lhs_sp.data, rhs_sp.data)
    jf_sparse_1 = jax.jacfwd(f_sparse, argnums=1)(lhs_sp.data, rhs_sp.data)
    self.assertAllClose(jf_dense_1, jf_sparse_1, rtol=tol)

    jf_dense_0, jf_dense_1 = jax.jacfwd(f_dense, argnums=(0, 1))(lhs_sp.data, rhs_sp.data)
    jf_sparse_0, jf_sparse_1 = jax.jacfwd(f_sparse, argnums=(0, 1))(lhs_sp.data, rhs_sp.data)
    self.assertAllClose(jf_dense_0, jf_sparse_0, rtol=tol)
    self.assertAllClose(jf_dense_1, jf_sparse_1, rtol=tol)

  @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}[n_batch={}]_{}[n_batch={}]_in_axes={}".format(
        jtu.format_shape_dtype_string(lhs_shape, dtype), lhs_n_batch,
        jtu.format_shape_dtype_string(rhs_shape, dtype), rhs_n_batch,
        in_axes),
       "lhs_shape": lhs_shape, "lhs_n_batch": lhs_n_batch,
       "rhs_shape": rhs_shape, "rhs_n_batch": rhs_n_batch,
       "dtype": dtype, "in_axes": in_axes}
      for lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, in_axes in [
        ((3, 5), 1, (3, 5), 1, 0),
        ((3, 4, 5), 1, (3, 5), 1, 0),
        ((3, 4, 5), 2, (3, 5), 1, 0),
        # TODO(jakevdp): test these once unequal batches are implemented
        # ((4, 5), 1, (5,), 0, (0, None)),
        # ((3, 4, 5), 1, (5,), 0, (0, None)),
        # ((4, 5), 0, (3, 5), 1, (None, 0)),
      ]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_bcoo_spmm_batched(self, lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dtype, in_axes):
    sprng = rand_sparse(self.rng())
    def args_maker():
      x = sprng(lhs_shape, dtype)
      y = sprng(rhs_shape, dtype)
      xsp = sparse.BCOO.fromdense(x, n_batch=lhs_n_batch)
      ysp = sparse.BCOO.fromdense(y, n_batch=rhs_n_batch)
      return x, y, xsp, ysp

    def f_dense(x, y, _, __):
      return jax.vmap(operator.matmul, in_axes=in_axes)(x, y)
    def f_sparse(_, __, x, y):
      return jax.vmap(operator.matmul, in_axes=in_axes)(x, y)

    args = args_maker()
    result_dense = f_dense(*args)
    result_sparse = f_sparse(*args)
    self.assertAllClose(result_dense, result_sparse.todense())
    result_sparse_jit = jax.jit(f_sparse)(*args)
    self.assertAllClose(result_dense, result_sparse_jit.todense())

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}_nse={}_remove_zeros={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, nse, remove_zeros),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense,
       "nse": nse, "remove_zeros": remove_zeros}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)
      for nse in [None, np.prod(shape) - 1]
      for remove_zeros in [True, False]))
  def test_bcoo_sum_duplicates(self, shape, dtype, n_batch, n_dense, nse, remove_zeros):
    rng = self.rng()
    rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
    M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense)
    for i, s in enumerate(shape[n_batch:len(shape) - n_dense]):
      M.indices = M.indices.at[..., i].set(rng.randint(0, s, size=M.nse))
    dedupe = partial(M.sum_duplicates, nse=nse, remove_zeros=remove_zeros)
    jit_dedupe = jax.jit(dedupe)

    M_dedup = dedupe()
    self.assertAllClose(M.todense(), M_dedup.todense())
    if nse:
      self.assertEqual(M_dedup.nse, nse)

    if not nse:
      with self.assertRaisesRegex(ValueError, ".*nse argument"):
        jit_dedupe()
    else:
      M_dedup = jit_dedupe()
      self.assertAllClose(M.todense(), M_dedup.todense())
      self.assertEqual(M_dedup.nse, nse)

  def test_bcoo_sum_duplicates_inferred_nse(self):
    x = sparse.BCOO.fromdense(jnp.diag(jnp.arange(4)))
    self.assertEqual(x.nse, 3)
    y = x + x.T
    self.assertEqual(y.nse, 6)
    y2 = y.sum_duplicates()
    self.assertEqual(y2.nse, 3)
    self.assertArraysEqual(y.todense(), y2.todense())

  def test_bcoo_sum_duplicates_remove_zeros(self):
    data = jnp.array([0, 1, 0, 0])
    indices = jnp.array([[0], [1], [2], [3]])
    x = sparse.BCOO((data, indices), shape=(4,))
    self.assertEqual(x.nse, 4)

    y1 = x.sum_duplicates(remove_zeros=True)
    self.assertArraysEqual(x.todense(), y1.todense())
    self.assertEqual(y1.nse, 1)

    y2 = x.sum_duplicates(remove_zeros=False)
    self.assertArraysEqual(x.todense(), y2.todense())
    self.assertEqual(y2.nse, x.nse)

  def test_bcoo_sum_duplicates_padding(self):
    # Regression test for https://github.com/google/jax/issues/8163
    size = 3
    data = jnp.array([1, 0, 0])
    indices = jnp.array([1, size, size])[:, None]
    x = sparse.BCOO((data, indices), shape=(3,))
    y = x.sum_duplicates(nse=x.nse)
    self.assertArraysEqual(x.todense(), y.todense())
    self.assertArraysEqual(x.indices, y.indices)
    self.assertArraysEqual(x.data, y.data)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}_axes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, axes),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, "axes": axes}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)
      for naxes in range(len(shape))
      for axes in itertools.combinations(range(len(shape)), naxes)))
  def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
    data_out, indices_out, shape_out = sparse.bcoo_reduce_sum(data, indices, spinfo=BCOOInfo(shape), axes=axes)
    result_dense = M.sum(axes)
    result_sparse = sparse.bcoo_todense(data_out, indices_out, spinfo=BCOOInfo(shape_out))
    tol = {np.float32: 1E-6, np.float64: 1E-14}
    self.assertAllClose(result_dense, result_sparse, atol=tol, rtol=tol)

  @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
        jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
        jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
      }
      for lhs_shape, rhs_shape in [[(3,), (3,)],
                                   [(3, 4), (4,)],
                                   [(4,), (4, 5)],
                                   [(3, 4), (4, 5)],
                                   [(3, 4), (2, 4, 5)],
                                   [(2, 3, 4), (4, 5)],
                                   [(2, 3, 4), (2, 4, 5)]]
      for lhs_dtype in all_dtypes
      for rhs_dtype in all_dtypes))
  def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
    rng = jtu.rand_default(self.rng())
    lhs = jnp.array(rng(lhs_shape, lhs_dtype))
    rhs = jnp.array(rng(rhs_shape, rhs_dtype))

    # Note: currently, batch dimensions in matmul must correspond to batch
    # dimensions in the sparse representation.
    lhs_sp = sparse.BCOO.fromdense(lhs, n_batch=max(0, len(lhs_shape) - 2))
    rhs_sp = sparse.BCOO.fromdense(rhs, n_batch=max(0, len(rhs_shape) - 2))

    out1 = lhs @ rhs
    out2 = lhs_sp @ rhs
    out3 = lhs @ rhs_sp

    tol = {np.float64: 1E-13, np.complex128: 1E-13,
           np.float32: 1E-6, np.complex64: 1E-6}
    self.assertAllClose(out1, out2, rtol=tol)
    self.assertAllClose(out1, out3, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_n_batch={}_n_dense={}".format(
        jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
        jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
        n_batch, n_dense),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "n_batch": n_batch, "n_dense": n_dense,
      }
      for lhs_shape, rhs_shape in [[(3,), ()], [(3,), (1,)], [(3,), (3,)],
                                   [(3, 4), ()], [(3, 4), (4,)], [(3, 4), (3, 1)], [(3, 4), (3, 4)],
                                   [(3, 4, 5), (4, 5)], [(3, 4, 5), (3, 1, 1)], [(3, 4, 5), (1, 4, 1)]]
      for n_batch in range(len(lhs_shape) + 1)
      for n_dense in range(len(lhs_shape) + 1 - n_batch)
      for lhs_dtype in all_dtypes
      for rhs_dtype in all_dtypes))
  def test_bcoo_mul_dense(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, n_batch, n_dense):
    rng_lhs = rand_sparse(self.rng())
    rng_rhs = jtu.rand_default(self.rng())
    lhs = jnp.array(rng_lhs(lhs_shape, lhs_dtype))
    rhs = jnp.array(rng_rhs(rhs_shape, rhs_dtype))

    sp = lambda x: sparse.BCOO.fromdense(x, n_batch=n_batch, n_dense=n_dense)

    out1 = lhs * rhs
    out2 = (sp(lhs) * rhs).todense()
    out3 = (rhs * sp(lhs)).todense()

    tol = {np.float64: 1E-13, np.complex128: 1E-13,
           np.float32: 1E-6, np.complex64: 1E-6}
    self.assertAllClose(out1, out2, rtol=tol)
    self.assertAllClose(out1, out3, rtol=tol)
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_n_batch={}_n_dense={}".format(
        jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
        jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
        n_batch, n_dense),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "n_batch": n_batch, "n_dense": n_dense,
      }
      # TODO(jakevdp): add broadcasted shapes (from bcoo_mul_dense) once sparse-sparse mul
      # supports inputs of differing rank.
      for lhs_shape, rhs_shape in [[(3,), (1,)], [(3,), (3,)],
                                   [(3, 4), (1, 1)], [(3, 4), (1, 4)], [(3, 4), (3, 1)], [(3, 4), (3, 4)],
                                   [(3, 4, 5), (1, 4, 5)], [(3, 4, 5), (3, 1, 1)], [(3, 4, 5), (1, 4, 1)]]
      # TODO(jakevdp): add tests for batch & dense dimensions.
      for n_batch in range(len(lhs_shape) + 1)
      for n_dense in range(len(lhs_shape) + 1 - n_batch)
      for lhs_dtype in all_dtypes
      for rhs_dtype in all_dtypes))
  def test_bcoo_mul_sparse(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    lhs = jnp.array(rng(lhs_shape, lhs_dtype))
    rhs = jnp.array(rng(rhs_shape, rhs_dtype))

    sp = lambda x: sparse.BCOO.fromdense(x, n_batch=n_batch, n_dense=n_dense)

    out1 = lhs * rhs
    out2 = (sp(lhs) * sp(rhs)).todense()

    tol = {np.float64: 1E-13, np.complex128: 1E-13,
           np.float32: 1E-6, np.complex64: 1E-6}
    self.assertAllClose(out1, out2, rtol=tol)

  def test_bcoo_mul_sparse_with_duplicates(self):
    # Regression test for https://github.com/google/jax/issues/8888
    indices = jnp.array([[0, 1, 0, 0, 1, 1],
                         [1, 0, 1, 2, 0, 2]]).T
    data = jnp.array([1, 2, 3, 4, 5, 6])
    mat = sparse.BCOO((data, indices), shape=(3, 3))
    self.assertArraysEqual((mat * mat).todense(), mat.todense() * mat.todense())

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_n_batch={}_n_dense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
       for shape in [(), (3,), (3, 5), (3, 5, 4)]
       for dtype in all_dtypes
       for n_batch in range(len(shape) + 1)
       for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_broadcast_in_dim(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    x = jnp.array(rng(shape, dtype))
    xsp = sparse.BCOO.fromdense(x, n_batch=n_batch, n_dense=n_dense)

    self.assertArraysEqual(xsp[None].todense(), x[None])
    if len(shape) >= 1:
      self.assertArraysEqual(xsp[:, None].todense(), x[:, None])
      self.assertArraysEqual(xsp[:, None, None].todense(), x[:, None, None])
    if len(shape) >= 2:
      self.assertArraysEqual(xsp[:, :, None].todense(), x[:, :, None])
      self.assertArraysEqual(xsp[:, None, :, None].todense(), x[:, None, :, None])

  def test_bcoo_vmap_shape(self, shape=(2, 3, 4, 5), dtype=np.float32):
    # This test checks that BCOO shape metadata interacts correctly with vmap.
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)

    def make_bcoo(M):
      return sparse.BCOO.fromdense(M, nse=np.prod(M.shape[:-1], dtype=int), n_dense=1)

    for _ in range(3):
      make_bcoo = jax.vmap(make_bcoo)
      Msp = make_bcoo(M)
      self.assertEqual(Msp.shape, M.shape)
      self.assertArraysEqual(Msp.todense(), M)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_unbatch(self, shape, dtype, n_batch, n_dense):
    rng_sparse = rand_sparse(self.rng())
    M1 = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense)
    M2 = M1._unbatch()
    self.assertEqual(M2.n_batch, 0)
    self.assertEqual(M1.n_dense, M2.n_dense)
    self.assertEqual(M1.shape, M2.shape)
    self.assertEqual(M1.dtype, M2.dtype)
    self.assertArraysEqual(M1.todense(), M2.todense())

  def test_bcoo_bad_fillvals(self):
    # Extra values have 100 rather than zero. This lets us check that logic is
    # properly ignoring these indices.
    data = jnp.array([1, 2, 3, 100, 100])
    indices = jnp.array([1, 2, 3, 5, 5])[:, None]
    x_sp = sparse.BCOO((data, indices), shape=(5,))
    x_de = x_sp.todense()

    data = jnp.array([3, 2, 100, 100])
    indices = jnp.array([2, 3, 5, 5])[:, None]
    y_sp = sparse.BCOO((data, indices), shape=(5,))
    y_de = y_sp.todense()

    self.assertArraysEqual(x_de, jnp.array([0, 1, 2, 3, 0]))
    self.assertArraysEqual(y_de, jnp.array([0, 0, 3, 2, 0]))

    self.assertArraysEqual(x_sp.sum_duplicates().todense(), x_de)
    self.assertArraysEqual(y_sp.sum_duplicates().todense(), y_de)

    # reduce_sum:
    self.assertArraysEqual(x_sp.sum(), x_de.sum())

    # bcoo_dot_general
    self.assertArraysEqual(x_sp @ y_de, x_de @ y_de)

    # bcoo_spdot_general
    self.assertArraysEqual((x_sp @ y_sp).todense(), x_de @ y_de)
    self.assertArraysEqual((y_sp @ x_sp).todense(), y_de @ x_de)


class SparseGradTest(jtu.JaxTestCase):
  def test_sparse_grad(self):
    rng_sparse = rand_sparse(self.rng())
    rng = jtu.rand_default(self.rng())

    y = rng(5, "float32")
    X = rng_sparse((10, 5), "float32")
    Xsp = sparse.BCOO.fromdense(X)

    def f(X, y):
      return jnp.sum(X @ y)

    grad_dense = jax.grad(f, argnums=0)(X, y)
    grad_sparse = sparse.grad(f, argnums=0)(Xsp, y)

    # extract sparse gradient from dense gradient
    indices = tuple(Xsp.indices.T)
    grad_sparse_from_dense = jnp.zeros_like(grad_dense).at[indices].set(grad_dense[indices])

    self.assertArraysEqual(grad_sparse.todense(), grad_sparse_from_dense)


class SparseObjectTest(jtu.JaxTestCase):
  def test_repr(self):
    M = sparse.BCOO.fromdense(jnp.arange(5, dtype='float32'))
    self.assertEqual(repr(M), "BCOO(float32[5], nse=4)")

    M_invalid = sparse.BCOO(([], []), shape=(100,))
    self.assertEqual(repr(M_invalid), "BCOO(<invalid>)")

  @parameterized.named_parameters(
    {"testcase_name": "_{}{}".format(cls.__name__, shape), "cls": cls, "shape": shape}
    for cls in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO]
    for shape in ([2, 5], [5, 3]))
  def test_empty(self, cls, shape):
    sparse_format = cls.__name__.lower()
    M = sparse.empty(shape, sparse_format=sparse_format)
    self.assertIsInstance(M, cls)
    self.assertEqual(M.nse, 0)
    self.assertArraysEqual(M.todense(), jnp.empty(shape))

  @parameterized.named_parameters(
    {"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}
    for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])
  def test_block_until_ready(self, Obj, shape=(5, 8), dtype=np.float32):
    rng = rand_sparse(self.rng(), post=Obj.fromdense)
    M = rng(shape, dtype)
    self.assertEqual(M.shape, M.block_until_ready().shape)
    self.assertArraysEqual(M.data, M.block_until_ready().data)
    self.assertArraysEqual(M.todense(), M.block_until_ready().todense())

  @parameterized.named_parameters(
    {"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}
    for Obj in [jnp.array, sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])
  def test_todense(self, Obj, shape=(5, 8), dtype=np.float32):
    rng = rand_sparse(self.rng())
    M_dense = rng(shape, dtype)
    M = jnp.array(M_dense) if Obj is jnp.array else Obj.fromdense(M_dense)
    self.assertArraysEqual(sparse.todense(M), M_dense)
    self.assertArraysEqual(jit(sparse.todense)(M), M_dense)

  def test_todense_scalar(self):
    self.assertEqual(sparse.todense(1.0), 1.0)
    self.assertEqual(jit(sparse.todense)(1.0), 1.0)

  @parameterized.named_parameters(
    {"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}
    for Obj in [jnp.array, sparse.BCOO])
  def test_todense_batching(self, Obj, shape=(5, 8), dtype=np.float32):
    rng = rand_sparse(self.rng())
    M_dense = rng(shape, dtype)
    if Obj is sparse.BCOO:
      M = sparse.BCOO.fromdense(M_dense, n_batch=1)
    else:
      M = jnp.asarray(M_dense)
    self.assertArraysEqual(vmap(sparse.todense)(M), M_dense)
    self.assertArraysEqual(jit(vmap(sparse.todense))(M), M_dense)

  @parameterized.named_parameters(
    {"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}
    for Obj in [jnp.array, sparse.BCOO])
  def test_todense_ad(self, Obj, shape=(3,), dtype=np.float32):
    M_dense = jnp.array([1., 2., 3.])
    M = M_dense if Obj is jnp.array else Obj.fromdense(M_dense)
    bufs, tree = tree_util.tree_flatten(M)
    jac = jnp.eye(M.shape[0], dtype=M.dtype)
    jac1 = jax.jacfwd(lambda *bufs: sparse.todense_p.bind(*bufs, tree=tree))(*bufs)
    jac2 = jax.jacrev(lambda *bufs: sparse.todense_p.bind(*bufs, tree=tree))(*bufs)
    self.assertArraysEqual(jac1, jac2)
    self.assertArraysEqual(jac, jac2)

  @parameterized.named_parameters(
    {"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}
    for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])
  def test_attrs(self, Obj, shape=(5, 8), dtype=np.float16):
    rng = rand_sparse(self.rng(), post=Obj.fromdense)
    M = rng(shape, dtype)

    assert isinstance(M, Obj)
    assert M.shape == shape
    assert M.size == np.prod(shape)
    assert M.ndim == len(shape)
    assert M.dtype == dtype
    assert M.nse == (M.todense() != 0).sum()
    assert M.data.dtype == dtype

    with self.assertRaises(TypeError):
      hash(M)

    if isinstance(M, sparse.CSR):
      assert len(M.data) == len(M.indices)
      assert len(M.indptr) == M.shape[0] + 1
    elif isinstance(M, sparse.CSC):
      assert len(M.data) == len(M.indices)
      assert len(M.indptr) == M.shape[1] + 1
    elif isinstance(M, sparse.COO):
      assert len(M.data) == len(M.row) == len(M.col)
    elif isinstance(M, sparse.BCOO):
      assert M.data.shape[M.n_batch] == M.indices.shape[-2]
      assert M.indices.shape[-1] == M.n_sparse
    else:
      raise ValueError("Obj={Obj} not expected.")

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
      {"testcase_name": "_{}_Obj={}".format(
        jtu.format_shape_dtype_string(shape, dtype), Obj.__name__),
       "shape": shape, "dtype": dtype, "Obj": Obj}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
    for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO]))
  def test_dense_round_trip(self, shape, dtype, Obj):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    Msparse = Obj.fromdense(M)
    self.assertArraysEqual(M, Msparse.todense())

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
      {"testcase_name": "_{}_Obj={}".format(
        jtu.format_shape_dtype_string(shape, dtype), Obj.__name__),
       "shape": shape, "dtype": dtype, "Obj": Obj}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
    for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO]))
  def test_transpose(self, shape, dtype, Obj):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    Msparse = Obj.fromdense(M)
    self.assertArraysEqual(M.T, Msparse.T.todense())

  @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
      {"testcase_name": "_{}_Obj={}_bshape={}".format(
        jtu.format_shape_dtype_string(shape, dtype), Obj.__name__, bshape),
       "shape": shape, "dtype": dtype, "Obj": Obj, "bshape": bshape}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for bshape in [shape[-1:] + s for s in [(), (3,), (4,)]]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
    for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO]))
  def test_matmul(self, shape, dtype, Obj, bshape):
    rng = rand_sparse(self.rng(), post=jnp.array)
    rng_b = jtu.rand_default(self.rng())
    M = rng(shape, dtype)
    Msp = Obj.fromdense(M)

    # Test matching type
    x = rng_b(bshape, dtype)
    x = jnp.asarray(x)
    self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL)

    # Test mismatched type
    x = rng_b(bshape, np.int32)
    x = jnp.asarray(x)
    self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}({})".format(
        input_type.__name__,
        jtu.format_shape_dtype_string(shape, dtype)),
       "input_type": input_type, "shape": shape, "dtype": dtype}
      for input_type in [scipy.sparse.coo_matrix, scipy.sparse.csr_matrix, scipy.sparse.csc_matrix]
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_bcoo_from_scipy_sparse(self, input_type, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_sparse = input_type(M)
    M_bcoo = sparse.BCOO.from_scipy_sparse(M_sparse)
    self.assertArraysEqual(M, M_bcoo.todense())

  def test_bcoo_methods(self):
    M = jnp.arange(12).reshape(3, 4)
    Msp = sparse.BCOO.fromdense(M)

    self.assertArraysEqual(-M, (-Msp).todense())

    self.assertArraysEqual(2 * M, (2 * Msp).todense())
    self.assertArraysEqual(M * 2, (Msp * 2).todense())

    self.assertArraysEqual(M + M, (Msp + Msp).todense())

    self.assertArraysEqual(M.sum(0), Msp.sum(0).todense())
    self.assertArraysEqual(M.sum(1), Msp.sum(1).todense())
    self.assertArraysEqual(M.sum(), Msp.sum())

class SparseRandomTest(jtu.JaxTestCase):
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_indices_dtype={}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), indices_dtype, n_batch, n_dense),
       "shape": shape, "dtype": dtype, "indices_dtype": indices_dtype,
       "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for indices_dtype in jtu.dtypes.integer
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_random_bcoo(self, shape, dtype, indices_dtype, n_batch, n_dense):
    key = jax.random.PRNGKey(1701)
    mat = sparse.random_bcoo(
        key, shape=shape, dtype=dtype, indices_dtype=indices_dtype,
        n_batch=n_batch, n_dense=n_dense)

    mat_dense = mat.todense()
    self.assertEqual(mat_dense.shape, shape)
    self.assertEqual(mat_dense.dtype, dtype)
    self.assertEqual(mat.indices.dtype, indices_dtype)

    n_sparse = len(shape) - n_batch - n_dense
    batch_shape, sparse_shape, dense_shape = split_list(shape, [n_batch, n_sparse])

    approx_expected_num_nonzero = (
      np.ceil(0.2 * np.prod(sparse_shape))
      * np.prod(batch_shape) * np.prod(dense_shape))
    num_nonzero = (mat_dense != 0).sum()
    self.assertAlmostEqual(num_nonzero, approx_expected_num_nonzero, delta=2)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
back to top