https://github.com/hgomersall/pyFFTW
Tip revision: 159c4adbe57345fa49f1bfae35ce0afe06188394 authored by Henry Gomersall on 06 February 2023, 14:55:02 UTC
[FIX] removed development debugging output.
[FIX] removed development debugging output.
Tip revision: 159c4ad
test_pyfftw_numpy_interface.py
# Copyright 2015 Knowledge Economy Developments Ltd
#
# Henry Gomersall
# heng@kedevelopments.co.uk
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
from pyfftw import interfaces, _supported_types, _all_types_np
from .test_pyfftw_base import run_test_suites, np_fft
from ._get_default_args import get_default_args
from distutils.version import LooseVersion
import unittest
import numpy
import warnings
import copy
warnings.filterwarnings('always')
if LooseVersion(numpy.version.version) <= LooseVersion('1.6.2'):
# We overwrite the broken _cook_nd_args with a fixed version.
from ._cook_nd_args import _cook_nd_args
numpy.fft.fftpack._cook_nd_args = _cook_nd_args
complex_dtypes = []
real_dtypes = []
if '32' in _supported_types:
complex_dtypes.extend([numpy.complex64]*2)
real_dtypes.extend([numpy.float16, numpy.float32])
if '64' in _supported_types:
complex_dtypes.append(numpy.complex128)
real_dtypes.append(numpy.float64)
if 'ld' in _supported_types:
complex_dtypes.append(numpy.clongdouble)
real_dtypes.append(numpy.longdouble)
def make_complex_data(shape, dtype):
ar, ai = dtype(numpy.random.randn(2, *shape))
return ar + 1j*ai
def make_real_data(shape, dtype):
return dtype(numpy.random.randn(*shape))
def _numpy_fft_has_norm_kwarg():
"""returns True if numpy's fft supports the norm keyword argument
This should be true for numpy >= 1.10
"""
# return LooseVersion(numpy.version.version) >= LooseVersion('1.10')
try:
np_fft.fft(numpy.ones(4), norm=None)
return True
except TypeError:
return False
if _numpy_fft_has_norm_kwarg() and numpy.__version__ < '1.13':
# use version of numpy.fft.rfft* with normalisation bug fixed
# The patched version here, corresponds to the following bugfix PR:
# https://github.com/numpy/numpy/pull/8445
from numpy.fft import fftpack as fftpk
def rfft_fix(a, n=None, axis=-1, norm=None):
# from numpy.fft import fftpack_lite as fftpack
# from numpy.fft.fftpack import _raw_fft, _unitary, _real_fft_cache
a = numpy.array(a, copy=True, dtype=float)
output = fftpk._raw_fft(a, n, axis, fftpk.fftpack.rffti,
fftpk.fftpack.rfftf, fftpk._real_fft_cache)
if fftpk._unitary(norm):
if n is None:
n = a.shape[axis]
output *= 1 / numpy.sqrt(n)
return output
def rfftn_fix(a, s=None, axes=None, norm=None):
a = numpy.array(a, copy=True, dtype=float)
s, axes = fftpk._cook_nd_args(a, s, axes)
a = rfft_fix(a, s[-1], axes[-1], norm)
for ii in range(len(axes)-1):
a = fftpk.fft(a, s[ii], axes[ii], norm)
return a
def rfft2_fix(a, s=None, axes=(-2, -1), norm=None):
return rfftn_fix(a, s, axes, norm)
np_fft.rfft = rfft_fix
np_fft.rfft2 = rfft2_fix
np_fft.rfftn = rfftn_fix
functions = {
'fft': 'complex',
'ifft': 'complex',
'rfft': 'r2c',
'irfft': 'c2r',
'rfftn': 'r2c',
'hfft': 'c2r',
'ihfft': 'r2c',
'irfftn': 'c2r',
'rfft2': 'r2c',
'irfft2': 'c2r',
'fft2': 'complex',
'ifft2': 'complex',
'fftn': 'complex',
'ifftn': 'complex'}
acquired_names = ('fftfreq', 'fftshift', 'ifftshift')
if LooseVersion(numpy.version.version) >= LooseVersion('1.8'):
acquired_names += ('rfftfreq', )
class InterfacesNumpyFFTTestModule(unittest.TestCase):
''' A really simple test suite to check the module works as expected.
'''
def test_acquired_names(self):
for each_name in acquired_names:
numpy_fft_attr = getattr(numpy.fft, each_name)
acquired_attr = getattr(interfaces.numpy_fft, each_name)
self.assertIs(numpy_fft_attr, acquired_attr)
class InterfacesNumpyFFTTestFFT(unittest.TestCase):
io_dtypes = {
'complex': (complex_dtypes, make_complex_data),
'r2c': (real_dtypes, make_real_data),
'c2r': (complex_dtypes, make_complex_data)}
validator_module = np_fft
test_interface = interfaces.numpy_fft
func = 'fft'
axes_kw = 'axis'
threads_arg_name = 'threads'
overwrite_input_flag = 'overwrite_input'
default_s_from_shape_slicer = slice(-1, None)
if numpy.__version__ >= '1.20.0':
test_shapes = (
((100,), {}),
((128, 64), {'axis': 0}),
((128, 32), {'axis': -1}),
((59, 100), {}),
((59, 99), {'axis': -1}),
((59, 99), {'axis': 0}),
((32, 32, 4), {'axis': 1}),
((32, 32, 2), {'axis': 1, 'norm': 'ortho'}),
((32, 32, 2), {'axis': 1, 'norm': None}),
((32, 32, 2), {'axis': 1, 'norm': 'backward'}),
((32, 32, 2), {'axis': 1, 'norm': 'forward'}),
((64, 128, 16), {}),
)
else:
test_shapes = (
((100,), {}),
((128, 64), {'axis': 0}),
((128, 32), {'axis': -1}),
((59, 100), {}),
((59, 99), {'axis': -1}),
((59, 99), {'axis': 0}),
((32, 32, 4), {'axis': 1}),
((32, 32, 2), {'axis': 1, 'norm': 'ortho'}),
((32, 32, 2), {'axis': 1, 'norm': None}),
((64, 128, 16), {}),
)
# invalid_s_shapes is:
# (size, invalid_args, error_type, error_string)
invalid_args = (
((100,), ((100, 200),), TypeError, ''),
((100, 200), ((100, 200),), TypeError, ''),
((100,), (100, (-2, -1)), TypeError, ''),
((100,), (100, -20), IndexError, ''))
realinv = False
has_norm_kwarg = _numpy_fft_has_norm_kwarg()
@property
def test_data(self):
for test_shape, kwargs in self.test_shapes:
axes = self.axes_from_kwargs(kwargs)
s = self.s_from_kwargs(test_shape, kwargs)
if not self.has_norm_kwarg and 'norm' in kwargs:
kwargs.pop('norm')
if self.realinv:
test_shape = list(test_shape)
test_shape[axes[-1]] = test_shape[axes[-1]]//2 + 1
test_shape = tuple(test_shape)
yield test_shape, s, kwargs
def __init__(self, *args, **kwargs):
super(InterfacesNumpyFFTTestFFT, self).__init__(*args, **kwargs)
# Assume python 3, but keep backwards compatibility
if not hasattr(self, 'assertRaisesRegex'):
self.assertRaisesRegex = self.assertRaisesRegexp
def validate(self, array_type, test_shape, dtype,
s, kwargs, copy_func=copy.copy):
# Do it without the cache
# without:
interfaces.cache.disable()
self._validate(array_type, test_shape, dtype, s, kwargs,
copy_func=copy_func)
def munge_input_array(self, array, kwargs):
return array
def _validate(self, array_type, test_shape, dtype,
s, kwargs, copy_func=copy.copy):
input_array = self.munge_input_array(
array_type(test_shape, dtype), kwargs)
orig_input_array = copy_func(input_array)
np_input_array = numpy.asarray(input_array)
# Why are long double inputs copied to double precision? It's what
# numpy silently does anyways as of v1.10 but helps with backward
# compatibility and scipy.
# https://github.com/pyFFTW/pyFFTW/pull/189#issuecomment-356449731
if np_input_array.dtype == 'clongdouble':
np_input_array = numpy.complex128(input_array)
elif np_input_array.dtype == 'longdouble':
np_input_array = numpy.float64(input_array)
with warnings.catch_warnings(record=True) as w:
# We catch the warnings so as to pick up on when
# a complex array is turned into a real array
if 'axes' in kwargs:
validator_kwargs = {'axes': kwargs['axes']}
elif 'axis' in kwargs:
validator_kwargs = {'axis': kwargs['axis']}
else:
validator_kwargs = {}
if self.has_norm_kwarg and 'norm' in kwargs:
validator_kwargs['norm'] = kwargs['norm']
try:
test_out_array = getattr(self.validator_module, self.func)(
copy_func(np_input_array), s, **validator_kwargs)
except Exception as e:
interface_exception = None
try:
getattr(self.test_interface, self.func)(
copy_func(input_array), s, **kwargs)
except Exception as _interface_exception:
# It's necessary to assign the exception to the
# already defined variable in Python 3.
# See http://www.python.org/dev/peps/pep-3110/#semantic-changes
interface_exception = _interface_exception
# If the test interface raised, so must this.
self.assertEqual(type(interface_exception), type(e),
msg='Interface exception raised. ' +
'Testing for: ' + repr(e))
return
try:
output_array = getattr(self.test_interface, self.func)(
copy_func(np_input_array), s, **kwargs)
except NotImplementedError as e:
# check if exception due to missing precision
msg = repr(e)
if 'Rebuild pyFFTW with support for' in msg:
self.skipTest(msg)
else:
raise
if (functions[self.func] == 'r2c'):
if numpy.iscomplexobj(input_array):
if len(w) > 0:
# Make sure a warning is raised
self.assertIs(
w[-1].category, numpy.ComplexWarning)
self.assertTrue(
numpy.allclose(output_array, test_out_array,
rtol=1e-2, atol=1e-4))
if _all_types_np.get(np_input_array.real.dtype, "") in _supported_types:
# supported precisions should not be converted
self.assertEqual(np_input_array.real.dtype,
output_array.real.dtype)
if (not self.overwrite_input_flag in kwargs or
not kwargs[self.overwrite_input_flag]):
self.assertTrue(numpy.allclose(input_array,
orig_input_array))
return output_array
def axes_from_kwargs(self, kwargs):
default_args = get_default_args(
getattr(self.test_interface, self.func))
if 'axis' in kwargs:
axes = (kwargs['axis'],)
elif 'axes' in kwargs:
axes = kwargs['axes']
if axes is None:
axes = default_args['axes']
else:
if 'axis' in default_args:
# default 1D
axes = (default_args['axis'],)
else:
# default nD
axes = default_args['axes']
if axes is None:
axes = (-1,)
return axes
def s_from_kwargs(self, test_shape, kwargs):
''' Return either a scalar s or a tuple depending on
whether axis or axes is specified
'''
default_args = get_default_args(
getattr(self.test_interface, self.func))
if 'axis' in kwargs:
s = test_shape[kwargs['axis']]
elif 'axes' in kwargs:
axes = kwargs['axes']
if axes is not None:
s = []
for each_axis in axes:
s.append(test_shape[each_axis])
else:
# default nD
s = []
try:
for each_axis in default_args['axes']:
s.append(test_shape[each_axis])
except TypeError:
try:
s = list(test_shape[
self.default_s_from_shape_slicer])
except TypeError:
# We had an integer as the default, so force
# it to be a list
s = [test_shape[self.default_s_from_shape_slicer]]
else:
if 'axis' in default_args:
# default 1D
s = test_shape[default_args['axis']]
else:
# default nD
s = []
try:
for each_axis in default_args['axes']:
s.append(test_shape[each_axis])
except TypeError:
s = None
return s
def test_valid(self):
dtype_tuple = self.io_dtypes[functions[self.func]]
for dtype in dtype_tuple[0]:
for test_shape, s, kwargs in self.test_data:
s = None
self.validate(dtype_tuple[1],
test_shape, dtype, s, kwargs)
def test_on_non_numpy_array(self):
dtype_tuple = self.io_dtypes[functions[self.func]]
array_type = (lambda test_shape, dtype:
dtype_tuple[1](test_shape, dtype).tolist())
for dtype in dtype_tuple[0]:
for test_shape, s, kwargs in self.test_data:
s = None
self.validate(array_type,
test_shape, dtype, s, kwargs)
def test_fail_on_invalid_s_or_axes_or_norm(self):
dtype_tuple = self.io_dtypes[functions[self.func]]
for dtype in dtype_tuple[0]:
for test_shape, args, exception, e_str in self.invalid_args:
input_array = dtype_tuple[1](test_shape, dtype)
if len(args) > 2 and not self.has_norm_kwarg:
# skip tests involving norm argument if it isn't available
continue
self.assertRaisesRegex(exception, e_str,
getattr(self.test_interface, self.func),
*((input_array,) + args))
def test_same_sized_s(self):
dtype_tuple = self.io_dtypes[functions[self.func]]
for dtype in dtype_tuple[0]:
for test_shape, s, kwargs in self.test_data:
self.validate(dtype_tuple[1],
test_shape, dtype, s, kwargs)
def test_bigger_s(self):
dtype_tuple = self.io_dtypes[functions[self.func]]
for dtype in dtype_tuple[0]:
for test_shape, s, kwargs in self.test_data:
try:
for each_axis, length in enumerate(s):
s[each_axis] += 2
except TypeError:
s += 2
self.validate(dtype_tuple[1],
test_shape, dtype, s, kwargs)
def test_smaller_s(self):
dtype_tuple = self.io_dtypes[functions[self.func]]
for dtype in dtype_tuple[0]:
for test_shape, s, kwargs in self.test_data:
try:
for each_axis, length in enumerate(s):
s[each_axis] -= 2
except TypeError:
s -= 2
self.validate(dtype_tuple[1],
test_shape, dtype, s, kwargs)
def check_arg(self, arg, arg_test_values, array_type, test_shape,
dtype, s, kwargs):
'''Check that the correct arg is passed to the builder'''
# We trust the builders to work as expected when passed
# the correct arg (the builders have their own unittests).
return_values = []
input_array = array_type(test_shape, dtype)
def fake_fft(*args, **kwargs):
return_values.append((args, kwargs))
return (args, kwargs)
try:
# Replace the function that is to be used
real_fft = getattr(self.test_interface, self.func)
setattr(self.test_interface, self.func, fake_fft)
_kwargs = kwargs.copy()
for each_value in arg_test_values:
_kwargs[arg] = each_value
builder_args = getattr(self.test_interface, self.func)(
input_array.copy(), s, **_kwargs)
self.assertTrue(builder_args[1][arg] == each_value)
# make sure it was called
self.assertTrue(len(return_values) > 0)
except:
raise
finally:
# Make sure we set it back
setattr(self.test_interface, self.func, real_fft)
# Validate it aswell
for each_value in arg_test_values:
_kwargs[arg] = each_value
builder_args = getattr(self.test_interface, self.func)(
input_array.copy(), s, **_kwargs)
self.validate(array_type, test_shape, dtype, s, _kwargs)
def test_auto_align_input(self):
dtype_tuple = self.io_dtypes[functions[self.func]]
for dtype in dtype_tuple[0]:
for test_shape, s, kwargs in self.test_data:
self.check_arg('auto_align_input', (True, False),
dtype_tuple[1], test_shape, dtype, s, kwargs)
def test_auto_contiguous_input(self):
dtype_tuple = self.io_dtypes[functions[self.func]]
for dtype in dtype_tuple[0]:
for test_shape, s, kwargs in self.test_data:
self.check_arg('auto_contiguous', (True, False),
dtype_tuple[1], test_shape, dtype, s, kwargs)
def test_bigger_and_smaller_s(self):
dtype_tuple = self.io_dtypes[functions[self.func]]
for dtype in dtype_tuple[0]:
i = -1
for test_shape, s, kwargs in self.test_data:
try:
for each_axis, length in enumerate(s):
s[each_axis] += i * 2
i *= i
except TypeError:
s += i * 2
i *= i
self.validate(dtype_tuple[1],
test_shape, dtype, s, kwargs)
def test_dtype_coercian(self):
# Make sure we input a dtype that needs to be coerced
if functions[self.func] == 'r2c':
dtype_tuple = self.io_dtypes['complex']
else:
dtype_tuple = self.io_dtypes['r2c']
for dtype in dtype_tuple[0]:
for test_shape, s, kwargs in self.test_data:
s = None
self.validate(dtype_tuple[1],
test_shape, dtype, s, kwargs)
def test_planner_effort(self):
'''Test the planner effort arg
'''
dtype_tuple = self.io_dtypes[functions[self.func]]
test_shape = (16,)
for dtype in dtype_tuple[0]:
s = None
if self.axes_kw == 'axis':
kwargs = {'axis': -1}
else:
kwargs = {'axes': (-1,)}
for each_effort in ('FFTW_ESTIMATE', 'FFTW_MEASURE',
'FFTW_PATIENT', 'FFTW_EXHAUSTIVE'):
kwargs['planner_effort'] = each_effort
self.validate(
dtype_tuple[1], test_shape, dtype, s, kwargs)
kwargs['planner_effort'] = 'garbage'
self.assertRaisesRegex(ValueError, 'Invalid planner effort',
self.validate,
*(dtype_tuple[1], test_shape, dtype, s, kwargs))
def test_threads_arg(self):
'''Test the threads argument
'''
dtype_tuple = self.io_dtypes[functions[self.func]]
test_shape = (16,)
for dtype in dtype_tuple[0]:
s = None
if self.axes_kw == 'axis':
kwargs = {'axis': -1}
else:
kwargs = {'axes': (-1,)}
self.check_arg(self.threads_arg_name, (1, 2, 5, 10),
dtype_tuple[1], test_shape, dtype, s, kwargs)
kwargs[self.threads_arg_name] = 'bleh'
# Should not work
self.assertRaises(TypeError,
self.validate,
*(dtype_tuple[1], test_shape, dtype, s, kwargs))
def test_overwrite_input(self):
'''Test the overwrite_input flag
'''
dtype_tuple = self.io_dtypes[functions[self.func]]
for dtype in dtype_tuple[0]:
for test_shape, s, _kwargs in self.test_data:
s = None
kwargs = _kwargs.copy()
self.validate(dtype_tuple[1], test_shape, dtype, s, kwargs)
self.check_arg(self.overwrite_input_flag, (True, False),
dtype_tuple[1], test_shape, dtype, s, kwargs)
def test_input_maintained(self):
'''Test to make sure the input is maintained by default.
'''
dtype_tuple = self.io_dtypes[functions[self.func]]
for dtype in dtype_tuple[0]:
for test_shape, s, kwargs in self.test_data:
input_array = dtype_tuple[1](test_shape, dtype)
orig_input_array = input_array.copy()
getattr(self.test_interface, self.func)(
input_array, s, **kwargs)
self.assertTrue(
numpy.alltrue(input_array == orig_input_array))
def test_on_non_writeable_array_issue_92(self):
'''Test to make sure that locked arrays work.
Regression test for issue 92.
'''
def copy_with_writeable(array_to_copy):
array_copy = array_to_copy.copy()
array_copy.flags.writeable = array_to_copy.flags.writeable
return array_copy
dtype_tuple = self.io_dtypes[functions[self.func]]
def array_type(test_shape, dtype):
a = dtype_tuple[1](test_shape, dtype)
a.flags.writeable = False
return a
for dtype in dtype_tuple[0]:
for test_shape, s, kwargs in self.test_data:
s = None
self.validate(array_type,
test_shape, dtype, s, kwargs,
copy_func=copy_with_writeable)
def test_overwrite_input_for_issue_92(self):
'''Tests that trying to overwrite a locked array fails.
'''
a = numpy.zeros((4,))
a.flags.writeable = False
self.assertRaisesRegex(
ValueError,
'overwrite_input cannot be True when the ' +
'input array flags.writeable is False',
interfaces.numpy_fft.fft,
a, overwrite_input=True)
class InterfacesNumpyFFTTestIFFT(InterfacesNumpyFFTTestFFT):
func = 'ifft'
class InterfacesNumpyFFTTestRFFT(InterfacesNumpyFFTTestFFT):
func = 'rfft'
class InterfacesNumpyFFTTestIRFFT(InterfacesNumpyFFTTestFFT):
func = 'irfft'
realinv = True
class InterfacesNumpyFFTTestHFFT(InterfacesNumpyFFTTestFFT):
func = 'hfft'
realinv = True
class InterfacesNumpyFFTTestIHFFT(InterfacesNumpyFFTTestFFT):
func = 'ihfft'
class InterfacesNumpyFFTTestFFT2(InterfacesNumpyFFTTestFFT):
axes_kw = 'axes'
func = 'ifft2'
if numpy.__version__ >= '1.20.0':
test_shapes = (
((128, 64), {'axes': None}),
((128, 32), {'axes': None}),
((128, 32, 4), {'axes': (0, 2)}),
((59, 100), {'axes': (-2, -1)}),
((32, 32), {'axes': (-2, -1), 'norm': 'ortho'}),
((32, 32), {'axes': (-2, -1), 'norm': None}),
((32, 32), {'axes': (-2, -1), 'norm': 'backward'}),
((32, 32), {'axes': (-2, -1), 'norm': 'forward'}),
((64, 128, 16), {'axes': (0, 2)}),
((4, 6, 8, 4), {'axes': (0, 3)}),
)
else:
test_shapes = (
((128, 64), {'axes': None}),
((128, 32), {'axes': None}),
((128, 32, 4), {'axes': (0, 2)}),
((59, 100), {'axes': (-2, -1)}),
((32, 32), {'axes': (-2, -1), 'norm': 'ortho'}),
((32, 32), {'axes': (-2, -1), 'norm': None}),
((64, 128, 16), {'axes': (0, 2)}),
((4, 6, 8, 4), {'axes': (0, 3)}),
)
invalid_args = (
((100,), ((100, 200),), ValueError, ''),
((100, 200), ((100, 200, 100),), ValueError, ''),
((100,), ((100, 200), (-3, -2, -1)), ValueError, ''),
((100, 200), (100, -1), TypeError, ''),
((100, 200), ((100, 200), (-3, -2)), IndexError, 'Invalid axes'),
((100, 200), ((100,), (-3,)), IndexError, 'Invalid axes'),
# pass invalid normalisation string
((100, 200), ((100,), (-3,), 'invalid_norm'), ValueError, ''))
def test_shape_and_s_different_lengths(self):
dtype_tuple = self.io_dtypes[functions[self.func]]
for dtype in dtype_tuple[0]:
for test_shape, s, _kwargs in self.test_data:
kwargs = copy.copy(_kwargs)
try:
s = s[1:]
except TypeError:
self.skipTest('Not meaningful test on 1d arrays.')
del kwargs['axes']
self.validate(dtype_tuple[1],
test_shape, dtype, s, kwargs)
class InterfacesNumpyFFTTestIFFT2(InterfacesNumpyFFTTestFFT2):
func = 'ifft2'
class InterfacesNumpyFFTTestRFFT2(InterfacesNumpyFFTTestFFT2):
func = 'rfft2'
class InterfacesNumpyFFTTestIRFFT2(InterfacesNumpyFFTTestFFT2):
func = 'irfft2'
realinv = True
class InterfacesNumpyFFTTestFFTN(InterfacesNumpyFFTTestFFT2):
func = 'ifftn'
if numpy.__version__ >= '1.20.0':
test_shapes = (
((128, 32, 4), {'axes': None}),
((64, 128, 16), {'axes': (0, 1, 2)}),
((4, 6, 8, 4), {'axes': (0, 3, 1)}),
((4, 6, 4, 4), {'axes': (0, 3, 1), 'norm': 'ortho'}),
((4, 6, 4, 4), {'axes': (0, 3, 1), 'norm': None}),
((4, 6, 4, 4), {'axes': (0, 3, 1), 'norm': 'backward'}),
((4, 6, 4, 4), {'axes': (0, 3, 1), 'norm': 'forward'}),
((4, 6, 8, 4), {'axes': (0, 3, 1, 2)}),
)
else:
test_shapes = (
((128, 32, 4), {'axes': None}),
((64, 128, 16), {'axes': (0, 1, 2)}),
((4, 6, 8, 4), {'axes': (0, 3, 1)}),
((4, 6, 4, 4), {'axes': (0, 3, 1), 'norm': 'ortho'}),
((4, 6, 4, 4), {'axes': (0, 3, 1), 'norm': None}),
((4, 6, 8, 4), {'axes': (0, 3, 1, 2)}),
)
class InterfacesNumpyFFTTestIFFTN(InterfacesNumpyFFTTestFFTN):
func = 'ifftn'
class InterfacesNumpyFFTTestRFFTN(InterfacesNumpyFFTTestFFTN):
func = 'rfftn'
class InterfacesNumpyFFTTestIRFFTN(InterfacesNumpyFFTTestFFTN):
func = 'irfftn'
realinv = True
test_cases = (
InterfacesNumpyFFTTestModule,
InterfacesNumpyFFTTestFFT,
InterfacesNumpyFFTTestIFFT,
InterfacesNumpyFFTTestRFFT,
InterfacesNumpyFFTTestIRFFT,
InterfacesNumpyFFTTestHFFT,
InterfacesNumpyFFTTestIHFFT,
InterfacesNumpyFFTTestFFT2,
InterfacesNumpyFFTTestIFFT2,
InterfacesNumpyFFTTestRFFT2,
InterfacesNumpyFFTTestIRFFT2,
InterfacesNumpyFFTTestFFTN,
InterfacesNumpyFFTTestIFFTN,
InterfacesNumpyFFTTestRFFTN,
InterfacesNumpyFFTTestIRFFTN,)
#test_set = {'InterfacesNumpyFFTTestHFFT': ('test_valid',)}
test_set = None
if __name__ == '__main__':
run_test_suites(test_cases, test_set)