https://github.com/tensorly/tensorly
testing.py
import sys
from inspect import getfullargspec
import numpy as np
from tensorly import backend as T
def assert_array_equal(a, b, *args, **kwargs):
    """Assert that two backend tensors are equal.

    Both inputs are converted with the backend's ``to_numpy`` and then
    compared with ``np.testing.assert_array_equal``. Any extra positional
    or keyword arguments are forwarded to the NumPy function unchanged.
    """
    first = T.to_numpy(a)
    second = T.to_numpy(b)
    np.testing.assert_array_equal(first, second, *args, **kwargs)
def assert_array_almost_equal(a, b, *args, **kwargs):
    """Assert that two backend tensors are almost equal.

    Both inputs are converted with the backend's ``to_numpy`` and then
    compared with ``np.testing.assert_array_almost_equal``. Any extra
    positional or keyword arguments are forwarded to the NumPy function
    unchanged.
    """
    first = T.to_numpy(a)
    second = T.to_numpy(b)
    np.testing.assert_array_almost_equal(first, second, *args, **kwargs)
def assert_allclose(
    actual, desired, rtol=1e-07, atol=0, equal_nan=True, err_msg="", verbose=True
):
    """Check if two arrays are equal up to a given relative and absolute tolerance.

    Both ``actual`` and ``desired`` are converted with the backend's
    ``to_numpy`` before being handed to ``np.testing.assert_allclose``.

    Parameters
    ----------
    actual : tensor
        Array obtained.
    desired : tensor
        Array expected.
    rtol : float, default is 1e-07
        Relative tolerance, forwarded to NumPy.
    atol : float, default is 0
        Absolute tolerance, forwarded to NumPy.
    equal_nan : bool, default is True
        Forwarded to NumPy's ``equal_nan`` argument.
    err_msg : str, default is ""
        Forwarded to NumPy's ``err_msg`` argument.
    verbose : bool, default is True
        Forwarded to NumPy's ``verbose`` argument.

    See the `NumPy documentation <https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html>`_ for more details.
    """
    np.testing.assert_allclose(
        T.to_numpy(actual),
        T.to_numpy(desired),
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        err_msg=err_msg,
        verbose=verbose,
    )
def assert_equal(actual, desired, *args, **kwargs):
    """Assert that two objects are equal, unwrapping backend tensors first.

    Each input that the backend recognises as a tensor is converted with
    ``to_numpy``; a converted array of shape ``(1,)`` is further reduced to
    its single element. Non-tensor inputs are passed through untouched.
    The results are compared with ``np.testing.assert_equal``, which also
    receives any extra positional and keyword arguments.
    """

    def _unwrap(value):
        # Anything that is not a backend tensor is compared as-is.
        if not T.is_tensor(value):
            return value
        array = T.to_numpy(value)
        if array.shape == (1,):
            return array[0]
        return array

    lhs = _unwrap(actual)
    rhs = _unwrap(desired)
    np.testing.assert_equal(lhs, rhs, *args, **kwargs)
def _get_defaultkwargs(func):
    """Return a dict mapping every argument of ``func`` that has a default to that default.

    Covers both positional-or-keyword arguments with defaults and
    keyword-only arguments with defaults; arguments without a default
    are left out.
    """
    spec = getfullargspec(func)
    positional_defaults = spec.defaults or ()
    keyword_only_defaults = spec.kwonlydefaults or {}

    # Defaults always belong to the trailing positional arguments.
    first_with_default = len(spec.args) - len(positional_defaults)
    names_with_default = spec.args[first_with_default:]

    collected = dict(zip(names_with_default, positional_defaults))
    collected.update(keyword_only_defaults)
    return collected
def _get_decomposition_checker(supposed_kwargs, output_length):
    """Build a stand-in decomposition function that verifies the kwargs it receives.

    This is a utility function used to automate testing of the object oriented interface.

    Arguments
    ---------
    supposed_kwargs : dict
        Keyword arguments that must be present in every call to the returned
        function, mapped to the value each one is expected to have.
    output_length : int
        The number of outputs the returned function should produce.

    Returns
    -------
    function
        Function that accepts arbitrary positional and keyword arguments,
        asserts that every key of ``supposed_kwargs`` was passed as a keyword
        argument with an equal value, and returns a list of ``output_length``
        ``None`` entries.
    """

    def decomposition_function(*args, **kwargs):
        for name in supposed_kwargs:
            expected = supposed_kwargs[name]
            np.testing.assert_(
                name in kwargs,
                "All arguments with a default must be passed as keyword argument when the decomposition class calls the decomposition function",
            )
            received = kwargs[name]
            np.testing.assert_(received == expected)
        # Placeholder outputs so that callers unpacking the result keep working.
        return [None for _ in range(output_length)]

    return decomposition_function
def assert_class_wrapper_correctly_passes_arguments(
    monkeypatch,
    decomposition_function,
    DecompositionClass,
    ignore_args=None,
    decomposition_output_length=2,
    **extra_args
):
    """Check that a decomposition class forwards every defaulted argument to its function.

    Every argument of ``decomposition_function`` that has a default value is
    given a sentinel string. The function is then replaced, in the module
    where it is defined, by a checker that asserts each sentinel arrives as a
    keyword argument. Finally the class is instantiated with the sentinels
    (plus ``extra_args``) and ``fit(None)`` is called on it.

    This code must be used in a test run with the PyTest framework.

    Arguments:
    ----------
    monkeypatch : pytest.monkeypatch
        Monkeypatch fixture
    decomposition_function : Function
        Decomposition function wrapped by the class
    DecompositionClass : Class
        Class that wraps the function
    ignore_args : iterable
        Arguments that shouldn't be checked. Each entry must be an argument
        of ``decomposition_function`` with a default value, otherwise a
        ``KeyError`` is raised.
    decomposition_output_length : int
        Number of outputs from the decomposition function
    **extra_args
        Extra keyword-arguments passed to the decomposition class

    Example:
    --------
    Here is a simple example to check that the CP class' arguments match that of the parafac function.

    >>> from tensorly.decomposition import parafac, CP
    ... def test_cp(monkeypatch):
    ...     assert_class_wrapper_correctly_passes_arguments(monkeypatch, parafac, CP, ignore_args={'return_errors'}, rank=3)
    """
    sentinel = "this_is_used_to_test_correct_passing_of_arguments"
    expected_kwargs = dict.fromkeys(
        _get_defaultkwargs(decomposition_function), sentinel
    )

    # Drop the arguments the caller asked not to verify.
    for ignored in ignore_args if ignore_args is not None else ():
        del expected_kwargs[ignored]

    checker = _get_decomposition_checker(expected_kwargs, decomposition_output_length)

    # Swap the real function for the checker in the module that defines it.
    owner_module = sys.modules[decomposition_function.__module__]
    monkeypatch.setattr(owner_module, decomposition_function.__name__, checker)

    estimator = DecompositionClass(**extra_args, **expected_kwargs)
    estimator.fit(None)
# Module-level aliases for NumPy's generic assertion helpers.
assert_, assert_raises = np.testing.assert_, np.testing.assert_raises