# Source: https://github.com/hgomersall/pyFFTW (raw file)
# Tip revision: 159c4adbe57345fa49f1bfae35ce0afe06188394 (159c4ad), authored by
# Henry Gomersall on 06 February 2023, 14:55:02 UTC:
#     "[FIX] removed development debugging output."
# File: test_pyfftw_builders.py
# Copyright 2015 Knowledge Economy Developments Ltd
#
# Henry Gomersall
# heng@kedevelopments.co.uk
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#

from pyfftw import builders, empty_aligned, byte_align, FFTW
from pyfftw import _supported_nptypes_complex, _supported_nptypes_real
from pyfftw.builders import _utils as utils
from .test_pyfftw_base import run_test_suites, require
from ._get_default_args import get_default_args

import unittest
import numpy
import numpy as np
# import the numpy fft routines having the rfft normalization bug fix
from .test_pyfftw_numpy_interface import np_fft, _numpy_fft_has_norm_kwarg
import copy
import warnings
# Show every occurrence of a warning instead of only the first one per
# location; presumably so that repeated warnings remain visible to the
# warnings.catch_warnings(record=True) checks in the tests below.
warnings.filterwarnings('always')

# The dtypes exercised by these tests, taken from what pyfftw reports as
# supported.
complex_dtypes, real_dtypes = (
    _supported_nptypes_complex, _supported_nptypes_real)

def make_complex_data(shape, dtype):
    '''Return a random complex array of ``shape``.

    Real and imaginary parts are drawn together as a single
    ``(2,) + shape`` standard-normal sample, cast to ``dtype``, and then
    combined.
    '''
    parts = numpy.random.randn(2, *shape).astype(dtype)
    real_part = parts[0]
    imag_part = parts[1]
    return real_part + 1j*imag_part

def make_real_data(shape, dtype):
    '''Return a standard-normal random array of ``shape`` cast to ``dtype``.'''
    samples = numpy.random.randn(*shape)
    return samples.astype(dtype)


# Transform kind -> (input dtypes to test, factory that builds input data
# as factory(shape, dtype)).
input_dtypes = dict(
    complex=(complex_dtypes, make_complex_data),
    r2c=(real_dtypes, make_real_data),
    c2r=(complex_dtypes, make_complex_data),
)

# Transform kind -> expected output dtypes. test_output_dtype_correct pairs
# these index-wise with the input dtypes above.
output_dtypes = dict(
    complex=complex_dtypes,
    r2c=complex_dtypes,
    c2r=real_dtypes,
)

# Builder / numpy.fft function name -> transform kind.
functions = dict(
    fft='complex',
    ifft='complex',
    rfft='r2c',
    irfft='c2r',
    rfftn='r2c',
    irfftn='c2r',
    rfft2='r2c',
    irfft2='c2r',
    fft2='complex',
    ifft2='complex',
    fftn='complex',
    ifftn='complex',
)


class BuildersTestFFT(unittest.TestCase):
    '''Tests for ``pyfftw.builders.fft``.

    The subclasses below rerun every test here against the other builders
    by overriding ``func`` (and, for the multi-dimensional builders,
    ``axes_kw``, ``test_shapes`` and ``invalid_args``). Transform results
    are compared against the function of the same name in ``np_fft``.
    '''

    # Name of the function under test in both pyfftw.builders and np_fft.
    func = 'fft'
    # Keyword used to pass the transform axis/axes to the builder.
    axes_kw = 'axis'

    # Each entry is (input array shape, builder keyword arguments).
    # NumpyVersion compares versions numerically; a plain string comparison
    # is lexicographic (e.g. '1.9.0' >= '1.20.0' would be True).
    if numpy.lib.NumpyVersion(numpy.__version__) >= '1.20.0':
        # The 'backward' and 'forward' norm values are only tested here.
        test_shapes = (
            ((100,), {}),
            ((128, 64), {'axis': 0}),
            ((128, 32), {'axis': -1}),
            ((59, 100), {}),
            ((32, 32, 4), {'axis': 1}),
            ((32, 32, 4), {'axis': 1, 'norm': 'ortho'}),
            ((32, 32, 4), {'axis': 1, 'norm': None}),
            ((32, 32, 4), {'axis': 1, 'norm': 'backward'}),
            ((32, 32, 4), {'axis': 1, 'norm': 'forward'}),
            ((64, 128, 16), {}),
            )
    else:
        test_shapes = (
            ((100,), {}),
            ((128, 64), {'axis': 0}),
            ((128, 32), {'axis': -1}),
            ((59, 100), {}),
            ((32, 32, 4), {'axis': 1}),
            ((32, 32, 4), {'axis': 1, 'norm': 'ortho'}),
            ((32, 32, 4), {'axis': 1, 'norm': None}),
            ((64, 128, 16), {}),
            )

    # invalid_args is:
    # (input shape, positional args passed after the input array,
    #  expected exception type, regex the exception message must match)
    invalid_args = (
            ((100,), ((100, 200),), TypeError, ''),
            ((100, 200), ((100, 200),), TypeError, ''),
            ((100,), (100, (-2, -1)), TypeError, ''),
            ((100,), (100, -20), IndexError, ''))

    # True for the complex-to-real inverse builders: their input is shrunk
    # to n//2 + 1 along the last transformed axis (see test_data).
    realinv = False
    # When False, 'norm' is stripped from the test kwargs (see test_data).
    has_norm_kwarg = _numpy_fft_has_norm_kwarg()

    # ComplexWarning lives in numpy.exceptions from NumPy 1.25 and was
    # removed from the top-level numpy namespace in NumPy 2.0.
    try:
        _complex_warning = numpy.exceptions.ComplexWarning
    except AttributeError:
        _complex_warning = numpy.ComplexWarning

    def __init__(self, *args, **kwargs):

        super(BuildersTestFFT, self).__init__(*args, **kwargs)

        # Older unittest (e.g. Python 2.7) only has the 'Regexp' spelling.
        if not hasattr(self, 'assertRaisesRegex'):
            self.assertRaisesRegex = self.assertRaisesRegexp

    @property
    def test_data(self):
        '''Yield ``(test_shape, s, kwargs)`` for each entry of test_shapes.

        ``s`` matches the (unshrunk) input shape along the transformed axes
        (see s_from_kwargs). For inverse real transforms the yielded input
        shape is shrunk to ``n//2 + 1`` along the last transformed axis.
        ``kwargs`` is a fresh copy on each iteration, so neither callers nor
        the 'norm' stripping below modify the shared class-level
        test_shapes.
        '''
        for test_shape, shape_kwargs in self.test_shapes:
            kwargs = dict(shape_kwargs)
            axes = self.axes_from_kwargs(kwargs)
            s = self.s_from_kwargs(test_shape, kwargs)

            if not self.has_norm_kwarg:
                kwargs.pop('norm', None)

            if self.realinv:
                test_shape = list(test_shape)
                test_shape[axes[-1]] = test_shape[axes[-1]]//2 + 1
                test_shape = tuple(test_shape)

            yield test_shape, s, kwargs

    def validate_pyfftw_object(self, array_type, test_shape, dtype,
            s, kwargs):
        '''Build and run a pyfftw object for ``self.func``, checking the
        result against numpy.

        :param array_type: one of the make_*_data factories, called as
            ``array_type(test_shape, dtype)`` to create the input array.
        :param s: transform length(s) passed to the builder (may be None).
        :param kwargs: builder keyword arguments; the axis/axes and norm
            entries are also forwarded to the numpy reference function.
        :returns: the object returned by the builder.
        '''
        input_array = array_type(test_shape, dtype)

        # Use char because of potential MSVC related bug.
        # The numpy reference result is computed from a double precision
        # copy of any long double input.
        if input_array.dtype.char == np.dtype('clongdouble').char:
            np_input_array = numpy.complex128(input_array)

        elif input_array.dtype.char == np.dtype('longdouble').char:
            np_input_array = numpy.float64(input_array)

        else:
            np_input_array = input_array

        with warnings.catch_warnings(record=True) as w:
            # We catch the warnings so as to pick up on when
            # a complex array is turned into a real array

            FFTW_object = getattr(builders, self.func)(
                    input_array.copy(), s, **kwargs)

            # We run FFT twice to check two operations don't
            # yield different results (which they might if
            # the state is buggered up).
            output_array = FFTW_object(input_array.copy())
            output_array_2 = FFTW_object(input_array.copy())

            # Forward only the keywords that the numpy function accepts.
            if 'axes' in kwargs:
                np_kwargs = {'axes': kwargs['axes']}
            elif 'axis' in kwargs:
                np_kwargs = {'axis': kwargs['axis']}
            else:
                np_kwargs = {}

            if self.has_norm_kwarg and 'norm' in kwargs:
                np_kwargs['norm'] = kwargs['norm']

            test_out_array = getattr(np_fft, self.func)(
                    np_input_array.copy(), s, **np_kwargs)

            if functions[self.func] == 'r2c':
                if numpy.iscomplexobj(input_array):
                    if len(w) > 0:
                        # If a warning was raised, it must be the
                        # complex -> real cast warning.
                        self.assertIs(
                                w[-1].category, self._complex_warning)

        self.assertTrue(
                numpy.allclose(output_array, test_out_array,
                    rtol=1e-2, atol=1e-4))

        self.assertTrue(
                numpy.allclose(output_array_2, test_out_array,
                    rtol=1e-2, atol=1e-4))

        return FFTW_object

    def axes_from_kwargs(self, kwargs):
        '''Return the tuple of transformed axes implied by ``kwargs``.

        Falls back to the builder's default ``axis``/``axes`` argument, and
        to ``(-1,)`` when that default is None.
        '''
        default_args = get_default_args(getattr(builders, self.func))

        if 'axis' in kwargs:
            axes = (kwargs['axis'],)

        elif 'axes' in kwargs:
            axes = kwargs['axes']
            if axes is None:
                axes = default_args['axes']

        else:
            if 'axis' in default_args:
                # default 1D
                axes = (default_args['axis'],)
            else:
                # default nD
                axes = default_args['axes']

        if axes is None:
            axes = (-1,)

        return axes

    def s_from_kwargs(self, test_shape, kwargs):
        ''' Return either a scalar s or a list depending on
        whether axis or axes is specified.

        The lengths are read from ``test_shape`` along the relevant axes,
        so the resulting ``s`` leaves the transform size unchanged.
        '''
        default_args = get_default_args(getattr(builders, self.func))

        if 'axis' in kwargs:
            s = test_shape[kwargs['axis']]

        elif 'axes' in kwargs:
            axes = kwargs['axes']
            if axes is not None:
                s = []
                for each_axis in axes:
                    s.append(test_shape[each_axis])
            else:
                # default nD
                s = []
                try:
                    for each_axis in default_args['axes']:
                        s.append(test_shape[each_axis])
                except TypeError:
                    s = [test_shape[-1]]

        else:
            if 'axis' in default_args:
                # default 1D
                s = test_shape[default_args['axis']]
            else:
                # default nD
                s = []
                try:
                    for each_axis in default_args['axes']:
                        s.append(test_shape[each_axis])
                except TypeError:
                    s = None

        return s

    def test_valid(self):
        '''With s=None the builder should return a plain FFTW object.'''
        dtype_tuple = input_dtypes[functions[self.func]]

        for dtype in dtype_tuple[0]:
            for test_shape, s, kwargs in self.test_data:
                s = None

                FFTW_object = self.validate_pyfftw_object(dtype_tuple[1],
                        test_shape, dtype, s, kwargs)

                self.assertIs(type(FFTW_object), FFTW)

    def test_output_dtype_correct(self):
        '''The output dtype should be correct given the input dtype.

        It was noted that this is a particular problem on windows 64
        due to longdouble being mapped to double, but the dtype().char
        attribute still being different.
        '''
        inp_dtype_tuple = input_dtypes[functions[self.func]]
        output_dtype_tuple = output_dtypes[functions[self.func]]

        for input_dtype, output_dtype in zip(inp_dtype_tuple[0],
                                             output_dtype_tuple):

            for test_shape, s, kwargs in self.test_data:
                s = None

                FFTW_object = self.validate_pyfftw_object(inp_dtype_tuple[1],
                        test_shape, input_dtype, s, kwargs)

                self.assertEqual(
                    FFTW_object.output_array.dtype.char,
                    np.dtype(output_dtype).char)

    def test_fail_on_invalid_s_or_axes(self):
        '''Each entry of invalid_args should raise the expected error.'''
        dtype_tuple = input_dtypes[functions[self.func]]

        for dtype in dtype_tuple[0]:

            for test_shape, args, exception, e_str in self.invalid_args:
                input_array = dtype_tuple[1](test_shape, dtype)

                self.assertRaisesRegex(exception, e_str,
                        getattr(builders, self.func),
                        *((input_array,) + args))


    def test_same_sized_s(self):
        '''An s equal to the input shape should not need a wrapper.'''
        dtype_tuple = input_dtypes[functions[self.func]]
        for dtype in dtype_tuple[0]:
            for test_shape, s, kwargs in self.test_data:

                FFTW_object = self.validate_pyfftw_object(dtype_tuple[1],
                        test_shape, dtype, s, kwargs)

                self.assertIs(type(FFTW_object), FFTW)

    def test_bigger_s_overwrite_input(self):
        '''Test that FFTWWrapper deals with a destroyed input properly.
        '''
        dtype_tuple = input_dtypes[functions[self.func]]
        for dtype in dtype_tuple[0]:
            for test_shape, s, kwargs in self.test_data:

                try:
                    for each_axis, length in enumerate(s):
                        s[each_axis] += 2
                except TypeError:
                    s += 2

                _kwargs = kwargs.copy()

                if self.func not in ('irfft2', 'irfftn'):
                    # They implicitly overwrite the input anyway
                    _kwargs['overwrite_input'] = True

                FFTW_object = self.validate_pyfftw_object(dtype_tuple[1],
                        test_shape, dtype, s, _kwargs)

                self.assertIs(type(FFTW_object), utils._FFTWWrapper)

    def test_bigger_s(self):
        '''An s larger than the input (zero padding) needs a wrapper.'''
        dtype_tuple = input_dtypes[functions[self.func]]
        for dtype in dtype_tuple[0]:
            for test_shape, s, kwargs in self.test_data:

                try:
                    for each_axis, length in enumerate(s):
                        s[each_axis] += 2
                except TypeError:
                    s += 2

                FFTW_object = self.validate_pyfftw_object(dtype_tuple[1],
                        test_shape, dtype, s, kwargs)

                self.assertIs(type(FFTW_object), utils._FFTWWrapper)

    def test_smaller_s(self):
        '''An s smaller than the input (cropping) needs a wrapper.'''
        dtype_tuple = input_dtypes[functions[self.func]]
        for dtype in dtype_tuple[0]:
            for test_shape, s, kwargs in self.test_data:

                try:
                    for each_axis, length in enumerate(s):
                        s[each_axis] -= 2
                except TypeError:
                    s -= 2

                FFTW_object = self.validate_pyfftw_object(dtype_tuple[1],
                        test_shape, dtype, s, kwargs)

                self.assertIs(type(FFTW_object), utils._FFTWWrapper)

    def test_bigger_and_smaller_s(self):
        '''An s mixing larger and smaller lengths needs a wrapper.'''
        dtype_tuple = input_dtypes[functions[self.func]]
        for dtype in dtype_tuple[0]:
            # Sign of the adjustment; it alternates between -1 and +1 on
            # every adjusted length. (``i *= i`` would stick at +1 after the
            # first step, so lengths would almost only ever grow.)
            i = -1
            for test_shape, s, kwargs in self.test_data:

                try:
                    for each_axis, length in enumerate(s):
                        s[each_axis] += i * 2
                        i *= -1
                except TypeError:
                    s += i * 2
                    i *= -1

                FFTW_object = self.validate_pyfftw_object(dtype_tuple[1],
                        test_shape, dtype, s, kwargs)

                self.assertIs(type(FFTW_object), utils._FFTWWrapper)

    def test_auto_contiguous_input(self):
        '''Check the auto_contiguous flag for FFTW and _FFTWWrapper.'''
        dtype_tuple = input_dtypes[functions[self.func]]

        for dtype in dtype_tuple[0]:
            for test_shape, s, kwargs in self.test_data:
                _kwargs = kwargs.copy()
                s1 = None
                s2 = copy.copy(s)
                try:
                    for each_axis, length in enumerate(s):
                        s2[each_axis] += 2
                except TypeError:
                    s2 += 2

                _test_shape = []
                slices = []
                for each_dim in test_shape:
                    _test_shape.append(each_dim*2)
                    slices.append(slice(None, None, 2))
                slices = tuple(slices)

                input_array = dtype_tuple[1](_test_shape, dtype)[slices]
                # check the input is non contiguous
                self.assertFalse(input_array.flags['C_CONTIGUOUS'] or
                    input_array.flags['F_CONTIGUOUS'])


                # Firstly check the non-contiguous case (for both
                # FFTW and _FFTWWrapper)
                _kwargs['auto_contiguous'] = False

                # We also need to make sure we're not copying due
                # to a trivial misalignment
                _kwargs['auto_align_input'] = False

                FFTW_object = getattr(builders, self.func)(
                        input_array, s1, **_kwargs)

                internal_input_array = FFTW_object.input_array
                flags = internal_input_array.flags
                self.assertIs(input_array, internal_input_array)
                self.assertFalse(flags['C_CONTIGUOUS'] or
                    flags['F_CONTIGUOUS'])

                FFTW_object = getattr(builders, self.func)(
                        input_array, s2, **_kwargs)

                internal_input_array = FFTW_object.input_array
                flags = internal_input_array.flags
                # We actually expect the _FFTWWrapper to be C_CONTIGUOUS
                self.assertTrue(flags['C_CONTIGUOUS'])

                # Now for the contiguous case (for both
                # FFTW and _FFTWWrapper)
                _kwargs['auto_contiguous'] = True
                FFTW_object = getattr(builders, self.func)(
                        input_array, s1, **_kwargs)

                internal_input_array = FFTW_object.input_array
                flags = internal_input_array.flags
                self.assertTrue(flags['C_CONTIGUOUS'] or
                    flags['F_CONTIGUOUS'])

                FFTW_object = getattr(builders, self.func)(
                        input_array, s2, **_kwargs)

                internal_input_array = FFTW_object.input_array
                flags = internal_input_array.flags
                # as above
                self.assertTrue(flags['C_CONTIGUOUS'])


    def test_auto_align_input(self):
        '''Check the auto_align_input flag for FFTW and _FFTWWrapper.'''
        dtype_tuple = input_dtypes[functions[self.func]]

        for dtype in dtype_tuple[0]:
            for test_shape, s, kwargs in self.test_data:
                _kwargs = kwargs.copy()
                s1 = None
                s2 = copy.copy(s)
                try:
                    for each_axis, length in enumerate(s):
                        s2[each_axis] += 2
                except TypeError:
                    s2 += 2

                input_array = dtype_tuple[1](test_shape, dtype)

                # Firstly check the unaligned case (for both
                # FFTW and _FFTWWrapper)
                _kwargs['auto_align_input'] = False
                FFTW_object = getattr(builders, self.func)(
                        input_array.copy(), s1, **_kwargs)

                self.assertFalse(FFTW_object.simd_aligned)

                FFTW_object = getattr(builders, self.func)(
                        input_array.copy(), s2, **_kwargs)

                self.assertFalse(FFTW_object.simd_aligned)

                # Now for the aligned case (for both
                # FFTW and _FFTWWrapper)
                _kwargs['auto_align_input'] = True
                FFTW_object = getattr(builders, self.func)(
                        input_array.copy(), s1, **_kwargs)

                self.assertTrue(FFTW_object.simd_aligned)

                self.assertNotIn('FFTW_UNALIGNED', FFTW_object.flags)
                FFTW_object = getattr(builders, self.func)(
                        input_array.copy(), s2, **_kwargs)

                self.assertTrue(FFTW_object.simd_aligned)

                self.assertNotIn('FFTW_UNALIGNED', FFTW_object.flags)

    def test_dtype_coercian(self):
        '''Input of the "wrong" kind should be coerced, still giving FFTW.'''
        # Make sure we input a dtype that needs to be coerced
        if functions[self.func] == 'r2c':
            dtype_tuple = input_dtypes['complex']
        else:
            dtype_tuple = input_dtypes['r2c']

        for dtype in dtype_tuple[0]:
            for test_shape, s, kwargs in self.test_data:
                s = None

                FFTW_object = self.validate_pyfftw_object(dtype_tuple[1],
                        test_shape, dtype, s, kwargs)

                self.assertIs(type(FFTW_object), FFTW)

    def test_persistent_padding(self):
        '''Test to confirm the padding is not touched after creation.
        '''
        dtype_tuple = input_dtypes[functions[self.func]]
        for dtype in dtype_tuple[0]:
            for test_shape, s, kwargs in self.test_data:

                n_add = 2
                # these slicers get the padding
                # from the internal input array
                padding_slicer = [slice(None)] * len(test_shape)
                axes = self.axes_from_kwargs(kwargs)
                try:
                    for each_axis, length in enumerate(s):
                        s[each_axis] += n_add
                        padding_slicer[axes[each_axis]] = (
                                slice(s[each_axis], None))

                except TypeError:
                    s += n_add
                    padding_slicer[axes[0]] = slice(s, None)
                padding_slicer = tuple(padding_slicer)
                # Get a valid object
                FFTW_object = self.validate_pyfftw_object(dtype_tuple[1],
                        test_shape, dtype, s, kwargs)

                internal_array = FFTW_object.input_array
                padding = internal_array[padding_slicer]

                # Fill the padding with garbage
                initial_padding = dtype_tuple[1](padding.shape, dtype)

                padding[:] = initial_padding

                # Now confirm that nothing is done to the padding
                FFTW_object()

                final_padding = FFTW_object.input_array[padding_slicer]

                self.assertTrue(numpy.all(final_padding == initial_padding))

    def test_planner_effort(self):
        '''Test the planner effort arg
        '''
        dtype_tuple = input_dtypes[functions[self.func]]
        test_shape = (16,)

        for dtype in dtype_tuple[0]:
            s = None
            if self.axes_kw == 'axis':
                kwargs = {'axis': -1}
            else:
                kwargs = {'axes': (-1,)}

            for each_effort in ('FFTW_ESTIMATE', 'FFTW_MEASURE',
                    'FFTW_PATIENT', 'FFTW_EXHAUSTIVE'):

                kwargs['planner_effort'] = each_effort

                FFTW_object = self.validate_pyfftw_object(
                        dtype_tuple[1], test_shape, dtype, s, kwargs)

                self.assertIn(each_effort, FFTW_object.flags)

            kwargs['planner_effort'] = 'garbage'

            self.assertRaisesRegex(ValueError, 'Invalid planner effort',
                    self.validate_pyfftw_object,
                    *(dtype_tuple[1], test_shape, dtype, s, kwargs))

    def test_threads_arg(self):
        '''Test the threads argument
        '''
        dtype_tuple = input_dtypes[functions[self.func]]
        test_shape = (16,)

        for dtype in dtype_tuple[0]:
            s = None
            if self.axes_kw == 'axis':
                kwargs = {'axis': -1}
            else:
                kwargs = {'axes': (-1,)}

            kwargs['threads'] = 2

            # Should just work
            self.validate_pyfftw_object(
                    dtype_tuple[1], test_shape, dtype, s, kwargs)

            kwargs['threads'] = 'bleh'

            # Should not work
            self.assertRaises(TypeError,
                    self.validate_pyfftw_object,
                    *(dtype_tuple[1], test_shape, dtype, s, kwargs))


    def test_overwrite_input(self):
        '''Test the overwrite_input flag
        '''
        dtype_tuple = input_dtypes[functions[self.func]]

        for dtype in dtype_tuple[0]:
            for test_shape, s, _kwargs in self.test_data:
                s = None

                kwargs = _kwargs.copy()
                FFTW_object = self.validate_pyfftw_object(dtype_tuple[1],
                        test_shape, dtype, s, kwargs)

                # irfft2/irfftn always carry FFTW_DESTROY_INPUT, so only
                # the other builders are checked without the flag.
                if self.func not in ('irfft2', 'irfftn'):
                    self.assertNotIn('FFTW_DESTROY_INPUT', FFTW_object.flags)

                    kwargs['overwrite_input'] = True

                    FFTW_object = self.validate_pyfftw_object(
                            dtype_tuple[1], test_shape, dtype, s, kwargs)

                self.assertIn('FFTW_DESTROY_INPUT', FFTW_object.flags)


    def test_input_maintained(self):
        '''Test to make sure the input is maintained
        '''
        dtype_tuple = input_dtypes[functions[self.func]]
        for dtype in dtype_tuple[0]:
            for test_shape, s, kwargs in self.test_data:

                input_array = dtype_tuple[1](test_shape, dtype)

                FFTW_object = getattr(
                        builders, self.func)(input_array, s, **kwargs)

                final_input_array = FFTW_object.input_array

                # numpy.alltrue was removed in NumPy 2.0; numpy.all is the
                # equivalent.
                self.assertTrue(
                        numpy.all(input_array == final_input_array))

    def test_avoid_copy(self):
        '''Test the avoid_copy flag
        '''
        dtype_tuple = input_dtypes[functions[self.func]]

        for dtype in dtype_tuple[0]:
            for test_shape, s, kwargs in self.test_data:
                _kwargs = kwargs.copy()

                _kwargs['avoid_copy'] = True

                s2 = copy.copy(s)
                try:
                    for each_axis, length in enumerate(s):
                        s2[each_axis] += 2
                except TypeError:
                    s2 += 2

                input_array = dtype_tuple[1](test_shape, dtype)

                self.assertRaisesRegex(ValueError,
                        'Cannot avoid copy.*transform shape.*',
                        getattr(builders, self.func),
                        input_array, s2, **_kwargs)

                non_contiguous_shape = [
                        each_dim * 2 for each_dim in test_shape]
                non_contiguous_slices = tuple(
                        [slice(None, None, 2)] * len(test_shape))

                misaligned_input_array = dtype_tuple[1](
                        non_contiguous_shape, dtype)[non_contiguous_slices]

                self.assertRaisesRegex(ValueError,
                        'Cannot avoid copy.*not contiguous.*',
                        getattr(builders, self.func),
                        misaligned_input_array, s, **_kwargs)

                # Offset by one from 16 byte aligned to guarantee it's not
                # 16 byte aligned
                _input_array = empty_aligned(
                        numpy.prod(test_shape)*input_array.itemsize+1,
                        dtype='int8', n=16)

                misaligned_input_array = _input_array[1:].view(
                         dtype=input_array.dtype).reshape(*test_shape)

                self.assertRaisesRegex(ValueError,
                        'Cannot avoid copy.*not aligned.*',
                        getattr(builders, self.func),
                        misaligned_input_array, s, **_kwargs)

                _input_array = byte_align(input_array.copy())
                FFTW_object = getattr(builders, self.func)(
                        _input_array, s, **_kwargs)

                # A catch all to make sure the internal array
                # is not a copy
                self.assertIs(FFTW_object.input_array, _input_array)


class BuildersTestIFFT(BuildersTestFFT):
    '''Rerun the BuildersTestFFT suite against ``builders.ifft``.'''

    func = 'ifft'

class BuildersTestRFFT(BuildersTestFFT):
    '''Rerun the BuildersTestFFT suite against ``builders.rfft``.'''

    func = 'rfft'

class BuildersTestIRFFT(BuildersTestFFT):
    '''Rerun the BuildersTestFFT suite against ``builders.irfft``.

    ``realinv`` makes the test data use the shortened complex input of a
    complex-to-real transform.
    '''

    realinv = True
    func = 'irfft'

class BuildersTestFFT2(BuildersTestFFT):
    '''Rerun the BuildersTestFFT suite against ``builders.fft2``.

    Uses the ``axes`` keyword, multi-axis test shapes and the invalid
    argument cases specific to the multi-dimensional builders. The inverse
    transform is covered separately by BuildersTestIFFT2.
    '''
    axes_kw = 'axes'
    # Must be the forward builder: with 'ifft2' here, fft2 was never tested
    # and ifft2 was tested twice.
    func = 'fft2'

    # Each entry is (input array shape, builder keyword arguments).
    # NumpyVersion compares versions numerically; a plain string comparison
    # is lexicographic (e.g. '1.9.0' >= '1.20.0' would be True).
    if numpy.lib.NumpyVersion(numpy.__version__) >= '1.20.0':
        # The 'backward' and 'forward' norm values are only tested here.
        test_shapes = (
            ((128, 64), {'axes': None}),
            ((128, 32), {'axes': None}),
            ((128, 32, 4), {'axes': (0, 2)}),
            ((59, 100), {'axes': (-2, -1)}),
            ((59, 100), {'axes': (-2, -1), 'norm': 'ortho'}),
            ((59, 100), {'axes': (-2, -1), 'norm': None}),
            ((59, 100), {'axes': (-2, -1), 'norm': 'backward'}),
            ((59, 100), {'axes': (-2, -1), 'norm': 'forward'}),
            ((64, 128, 16), {'axes': (0, 2)}),
            ((4, 6, 8, 4), {'axes': (0, 3)}),
            )
    else:
        test_shapes = (
            ((128, 64), {'axes': None}),
            ((128, 32), {'axes': None}),
            ((128, 32, 4), {'axes': (0, 2)}),
            ((59, 100), {'axes': (-2, -1)}),
            ((59, 100), {'axes': (-2, -1), 'norm': 'ortho'}),
            ((59, 100), {'axes': (-2, -1), 'norm': None}),
            ((64, 128, 16), {'axes': (0, 2)}),
            ((4, 6, 8, 4), {'axes': (0, 3)}),
            )

    # (input shape, positional args after the input array,
    #  expected exception type, regex the message must match)
    invalid_args = (
            ((100,), ((100, 200),), ValueError, 'Shape error'),
            ((100, 200), ((100, 200, 100),), ValueError, 'Shape error'),
            ((100,), ((100, 200), (-3, -2, -1)), ValueError, 'Shape error'),
            ((100, 200), (100, -1), TypeError, ''),
            ((100, 200), ((100, 200), (-3, -2)), IndexError, 'Invalid axes'),
            ((100, 200), ((100,), (-3,)), IndexError, 'Invalid axes'))


class BuildersTestIFFT2(BuildersTestFFT2):
    '''Rerun the 2D builder tests against ``builders.ifft2``.'''

    func = 'ifft2'

class BuildersTestRFFT2(BuildersTestFFT2):
    '''Rerun the 2D builder tests against ``builders.rfft2``.'''

    func = 'rfft2'

class BuildersTestIRFFT2(BuildersTestFFT2):
    '''Rerun the 2D builder tests against ``builders.irfft2``.

    ``realinv`` makes the test data use the shortened complex input of a
    complex-to-real transform.
    '''

    realinv = True
    func = 'irfft2'

class BuildersTestFFTN(BuildersTestFFT2):
    '''Rerun the multi-dimensional builder tests against ``builders.fftn``.

    Shares ``axes_kw`` and ``invalid_args`` with BuildersTestFFT2 but uses
    shapes with up to four transformed axes. The inverse transform is
    covered separately by BuildersTestIFFTN.
    '''
    # Must be the forward builder: with 'ifftn' here, fftn was never tested
    # and ifftn was tested twice.
    func = 'fftn'

    # NumpyVersion compares versions numerically; a plain string comparison
    # is lexicographic (e.g. '1.9.0' >= '1.20.0' would be True).
    if numpy.lib.NumpyVersion(numpy.__version__) >= '1.20.0':
        # The 'backward' and 'forward' norm values are only tested here.
        test_shapes = (
            ((128, 32, 4), {'axes': None}),
            ((64, 128, 16), {'axes': (0, 1, 2)}),
            ((4, 6, 8, 4), {'axes': (0, 3, 1)}),
            ((4, 6, 8, 4), {'axes': (0, 3, 1), 'norm': 'ortho'}),
            ((4, 6, 8, 4), {'axes': (0, 3, 1), 'norm': None}),
            ((4, 6, 8, 4), {'axes': (0, 3, 1), 'norm': 'backward'}),
            ((4, 6, 8, 4), {'axes': (0, 3, 1), 'norm': 'forward'}),
            ((4, 6, 8, 4), {'axes': (0, 3, 1, 2)}),
            )
    else:
        test_shapes = (
            ((128, 32, 4), {'axes': None}),
            ((64, 128, 16), {'axes': (0, 1, 2)}),
            ((4, 6, 8, 4), {'axes': (0, 3, 1)}),
            ((4, 6, 8, 4), {'axes': (0, 3, 1), 'norm': 'ortho'}),
            ((4, 6, 8, 4), {'axes': (0, 3, 1), 'norm': None}),
            ((4, 6, 8, 4), {'axes': (0, 3, 1, 2)}),
            )

class BuildersTestIFFTN(BuildersTestFFTN):
    '''Rerun the n-dimensional builder tests against ``builders.ifftn``.'''

    func = 'ifftn'

class BuildersTestRFFTN(BuildersTestFFTN):
    '''Rerun the n-dimensional builder tests against ``builders.rfftn``.'''

    func = 'rfftn'

class BuildersTestIRFFTN(BuildersTestFFTN):
    '''Rerun the n-dimensional builder tests against ``builders.irfftn``.

    ``realinv`` makes the test data use the shortened complex input of a
    complex-to-real transform.
    '''

    realinv = True
    func = 'irfftn'


class BuildersTestFFTWWrapper(unittest.TestCase):
    '''Tests for ``utils._FFTWWrapper``.

    This basically reimplements the FFTW.__call__ tests, with a few tweaks.
    The wrapper is built on ``internal_array`` with an
    ``input_array_slicer``/``FFTW_array_slicer`` pair. Updated inputs are
    therefore expected to be resliced into the internal array rather than
    used directly.
    '''

    def __init__(self, *args, **kwargs):

        super(BuildersTestFFTWWrapper, self).__init__(*args, **kwargs)

        # Older unittest versions only provide assertRaisesRegexp.
        if not hasattr(self, 'assertRaisesRegex'):
            self.assertRaisesRegex = self.assertRaisesRegexp

    def setUp(self):

        require(self, '64')

        # input_array_slicer cuts the (128, 512) input down to (128, 256).
        # FFTW_array_slicer selects the matching (128, 256) region of the
        # (256, 256) internal array.
        self.input_array_slicer = (slice(None), slice(256))
        self.FFTW_array_slicer = (slice(128), slice(None))

        self.input_array = empty_aligned((128, 512), dtype='complex128')
        self.output_array = empty_aligned((256, 256), dtype='complex128')

        self.internal_array = empty_aligned((256, 256), dtype='complex128')

        self.fft = utils._FFTWWrapper(self.internal_array,
                self.output_array,
                input_array_slicer=self.input_array_slicer,
                FFTW_array_slicer=self.FFTW_array_slicer)

        self.input_array[:] = (numpy.random.randn(*self.input_array.shape)
                + 1j*numpy.random.randn(*self.input_array.shape))

        # Zero-pad the sliced input into the internal array. This is
        # presumably the state the wrapper reaches when handed
        # self.input_array.
        self.internal_array[:] = 0
        self.internal_array[self.FFTW_array_slicer] = (
                self.input_array[self.input_array_slicer])

    def update_arrays(self, input_array, output_array):
        '''Does what the internal update arrays does for an FFTW
        object but with a reslicing.

        Copies ``input_array[input_array_slicer]`` into the wrapper's
        internal input array at ``FFTW_array_slicer``. It then calls the
        wrapper with ``output_array``, presumably so that a following
        ``execute()`` writes into ``output_array``.
        '''
        internal_input_array = self.fft.input_array

        internal_input_array[self.FFTW_array_slicer] = (
                input_array[self.input_array_slicer])

        self.fft(output_array=output_array)

    def test_call(self):
        '''Test a call to an instance of the class.
        '''

        # NOTE(review): self.fft() is called without an input, so it runs on
        # its internal array. Refreshing self.input_array here therefore
        # looks like it has no effect on the result — confirm intent.
        self.input_array[:] = (numpy.random.randn(*self.input_array.shape)
                + 1j*numpy.random.randn(*self.input_array.shape))

        output_array = self.fft()

        self.assertTrue(numpy.all(output_array == self.output_array))


    def test_call_with_positional_input_update(self):
        '''Test the class call with a positional input update.
        '''

        input_array = byte_align(
                (numpy.random.randn(*self.input_array.shape)
                    + 1j*numpy.random.randn(*self.input_array.shape)))

        output_array = self.fft(
                byte_align(input_array.copy())).copy()

        self.update_arrays(input_array, self.output_array)
        self.fft.execute()

        self.assertTrue(numpy.all(output_array == self.output_array))


    def test_call_with_keyword_input_update(self):
        '''Test the class call with a keyword input update.
        '''
        input_array = byte_align(
                numpy.random.randn(*self.input_array.shape)
                    + 1j*numpy.random.randn(*self.input_array.shape))

        output_array = self.fft(
            input_array=byte_align(input_array.copy())).copy()

        self.update_arrays(input_array, self.output_array)
        self.fft.execute()

        self.assertTrue(numpy.all(output_array == self.output_array))


    def test_call_with_keyword_output_update(self):
        '''Test the class call with a keyword output update.
        '''
        output_array = byte_align(
            (numpy.random.randn(*self.output_array.shape)
                + 1j*numpy.random.randn(*self.output_array.shape)))

        returned_output_array = self.fft(
                output_array=byte_align(output_array.copy())).copy()


        self.update_arrays(self.input_array, output_array)
        self.fft.execute()

        self.assertTrue(
                numpy.all(returned_output_array == output_array))

    def test_call_with_positional_updates(self):
        '''Test the class call with a positional array updates.
        '''

        input_array = byte_align((numpy.random.randn(*self.input_array.shape)
            + 1j*numpy.random.randn(*self.input_array.shape)))

        output_array = byte_align((numpy.random.randn(*self.output_array.shape)
            + 1j*numpy.random.randn(*self.output_array.shape)))

        returned_output_array = self.fft(
            byte_align(input_array.copy()),
            byte_align(output_array.copy())).copy()

        self.update_arrays(input_array, output_array)
        self.fft.execute()

        self.assertTrue(numpy.all(returned_output_array == output_array))

    def test_call_with_keyword_updates(self):
        '''Test the class call with keyword input and output updates.
        '''

        input_array = byte_align(
                (numpy.random.randn(*self.input_array.shape)
                    + 1j*numpy.random.randn(*self.input_array.shape)))

        output_array = byte_align(
                (numpy.random.randn(*self.output_array.shape)
                    + 1j*numpy.random.randn(*self.output_array.shape)))

        returned_output_array = self.fft(
                output_array=byte_align(output_array.copy()),
                input_array=byte_align(input_array.copy())).copy()

        self.update_arrays(input_array, output_array)
        self.fft.execute()

        self.assertTrue(numpy.all(returned_output_array == output_array))

    def test_call_with_different_input_dtype(self):
        '''Test the class call with an array with a different input dtype
        '''
        input_array = byte_align(numpy.complex64(
                numpy.random.randn(*self.input_array.shape)
                + 1j*numpy.random.randn(*self.input_array.shape)))

        output_array = self.fft(byte_align(input_array.copy())).copy()

        # Reference path: convert to the wrapper's dtype by hand.
        _input_array = numpy.asarray(input_array,
                dtype=self.input_array.dtype)

        self.update_arrays(_input_array, self.output_array)
        self.fft.execute()

        self.assertTrue(numpy.all(output_array == self.output_array))

    def test_call_with_list_input(self):
        '''Test the class call with a list rather than an array
        '''

        output_array = self.fft().copy()

        test_output_array = self.fft(self.input_array.tolist()).copy()

        self.assertTrue(numpy.all(output_array == test_output_array))


    def test_call_with_invalid_update(self):
        '''Test the class call with an invalid update.
        '''

        new_shape = self.input_array.shape + (2, )
        invalid_array = (numpy.random.randn(*new_shape)
                + 1j*numpy.random.randn(*new_shape))

        self.assertRaises(ValueError, self.fft, output_array=invalid_array)

        self.assertRaises(ValueError, self.fft, input_array=invalid_array)


    def test_call_with_invalid_output_striding(self):
        '''Test the class call with an invalid strided output update.
        '''
        # Add an extra dimension to bugger up the striding
        new_shape = self.output_array.shape + (2,)
        output_array = byte_align(numpy.random.randn(*new_shape)
                + 1j*numpy.random.randn(*new_shape))

        self.assertRaisesRegex(ValueError, 'Invalid output striding',
                self.fft, output_array=output_array[:,:,1])

    def test_call_with_different_striding(self):
        '''Test the input update with different strides to internal array.
        '''
        input_array_shape = self.input_array.shape + (2,)
        internal_array_shape = self.internal_array.shape

        internal_array = byte_align(
                numpy.random.randn(*internal_array_shape)
                + 1j*numpy.random.randn(*internal_array_shape))

        fft =  utils._FFTWWrapper(internal_array, self.output_array,
                input_array_slicer=self.input_array_slicer,
                FFTW_array_slicer=self.FFTW_array_slicer)

        test_output_array = fft().copy()

        new_input_array = empty_aligned(input_array_shape,
                                        dtype=internal_array.dtype)
        new_input_array[:] = 0

        # Fill a strided view (every other element along the last axis)
        # with the same data the internal array holds.
        new_input_array[:,:,0][self.input_array_slicer] = (
                internal_array[self.FFTW_array_slicer])

        new_output = fft(new_input_array[:,:,0]).copy()

        # Test the test!
        self.assertTrue(
                new_input_array[:,:,0].strides != internal_array.strides)

        self.assertTrue(numpy.all(test_output_array == new_output))

    def test_call_with_copy_with_missized_array_error(self):
        '''Force an input copy with a missized array.
        '''
        shape = list(self.input_array.shape + (2,))
        shape[0] += 1

        input_array = byte_align(numpy.random.randn(*shape)
                + 1j*numpy.random.randn(*shape))

        self.assertRaisesRegex(ValueError, 'Invalid input shape',
                self.fft, input_array=input_array[:,:,0])

    def test_call_with_normalisation_on(self):
        _input_array = empty_aligned(self.internal_array.shape,
                                     dtype='complex128')

        ifft = utils._FFTWWrapper(self.output_array, _input_array,
                direction='FFTW_BACKWARD',
                input_array_slicer=slice(None),
                FFTW_array_slicer=slice(None))

        self.fft(normalise_idft=True) # Shouldn't make any difference
        ifft(normalise_idft=True)

        self.assertTrue(numpy.allclose(
            self.input_array[self.input_array_slicer],
            _input_array[self.FFTW_array_slicer]))

    def test_call_with_normalisation_off(self):

        _input_array = empty_aligned(self.internal_array.shape,
                                     dtype='complex128')

        ifft = utils._FFTWWrapper(self.output_array, _input_array,
                direction='FFTW_BACKWARD',
                input_array_slicer=slice(None),
                FFTW_array_slicer=slice(None))

        self.fft(normalise_idft=True) # Shouldn't make any difference
        ifft(normalise_idft=False)

        # Without normalisation the inverse is scaled by N; undo it by hand.
        _input_array /= ifft.N

        self.assertTrue(numpy.allclose(
            self.input_array[self.input_array_slicer],
            _input_array[self.FFTW_array_slicer]))

    def test_call_with_normalisation_default(self):
        _input_array = empty_aligned(self.internal_array.shape,
                                     dtype='complex128')

        ifft = utils._FFTWWrapper(self.output_array, _input_array,
                direction='FFTW_BACKWARD',
                input_array_slicer=slice(None),
                FFTW_array_slicer=slice(None))

        self.fft()
        ifft()

        # Scaling is performed by default
        self.assertTrue(numpy.allclose(
            self.input_array[self.input_array_slicer],
            _input_array[self.FFTW_array_slicer]))

    def _check_norm_round_trip(self, norm):
        '''Run a forward then backward transform with ``norm`` and check
        that the original input is recovered.

        ``norm`` is translated into ``normalise_idft``/``ortho`` keyword
        arguments by ``utils._norm_args``. Those arguments are passed both
        to the backward wrapper's constructor and to the forward call.
        '''
        _input_array = empty_aligned(self.internal_array.shape,
                                     dtype='complex128')

        normdict = utils._norm_args(norm)

        ifft = utils._FFTWWrapper(self.output_array, _input_array,
                                  direction='FFTW_BACKWARD',
                                  input_array_slicer=slice(None),
                                  FFTW_array_slicer=slice(None),
                                  normalise_idft=normdict["normalise_idft"],
                                  ortho=normdict["ortho"])

        self.fft(normalise_idft=normdict["normalise_idft"],
                 ortho=normdict["ortho"])
        ifft()
        self.assertTrue(numpy.allclose(
            self.input_array[self.input_array_slicer],
            _input_array[self.FFTW_array_slicer]))

    def test_call_norm_ortho(self):
        self._check_norm_round_trip("ortho")

    def test_call_norm_backward(self):
        self._check_norm_round_trip("backward")

    def test_call_norm_none(self):
        self._check_norm_round_trip(None)

    def test_call_norm_forward(self):
        self._check_norm_round_trip("forward")


class BuildersTestUtilities(unittest.TestCase):
    '''Table-driven tests for the private helpers in ``builders._utils``.

    Each table test pairs an ``inputs`` tuple with an ``outputs`` tuple.
    Their lengths are checked explicitly because ``zip`` would otherwise
    silently drop unmatched cases if the two tables drift out of sync.
    '''

    def __init__(self, *args, **kwargs):

        super(BuildersTestUtilities, self).__init__(*args, **kwargs)

        # Older unittest versions only provide assertRaisesRegexp.
        if not hasattr(self, 'assertRaisesRegex'):
            self.assertRaisesRegex = self.assertRaisesRegexp

    def _check_table_lengths(self, inputs, outputs):
        '''Fail if a fixture table has a different number of inputs and
        expected outputs.'''
        self.assertEqual(len(inputs), len(outputs),
                'Fixture tables are out of sync: %d inputs, %d outputs'
                % (len(inputs), len(outputs)))

    def test_setup_input_slicers(self):
        # inputs are (a.shape, s)
        inputs = (
                ((4, 5), (4, 5)),
                ((4, 4), (3, 5)),
                ((4, 5), (3, 5)),
                )

        outputs = (
                ((slice(0, 4), slice(0, 5)), (slice(None), slice(None))),
                ((slice(0, 3), slice(0, 4)), (slice(None), slice(0, 4))),
                ((slice(0, 3), slice(0, 5)), (slice(None), slice(None))),
                )

        self._check_table_lengths(inputs, outputs)
        for _input, _output in zip(inputs, outputs):
            self.assertEqual(
                    utils._setup_input_slicers(*_input),
                    _output, msg='input: %r' % (_input,))

    def test_compute_array_shapes(self):
        # inputs are:
        # (a.shape, s, axes, inverse, real)
        inputs = (
                ((4, 5), (4, 5), (-2, -1), False, False),
                ((4, 5), (4, 5), (-1, -2), False, False),
                ((4, 5), (4, 5), (-1, -2), True, False),
                ((4, 5), (4, 5), (-1, -2), True, True),
                ((4, 5), (4, 5), (-2, -1), True, True),
                ((4, 5), (4, 5), (-2, -1), False, True),
                ((4, 5), (4, 5), (-1, -2), False, True),
                ((4, 5, 6), (4, 5), (-2, -1), False, False),
                ((4, 5, 6), (5, 6), (-2, -1), False, False),
                ((4, 5, 6), (3, 5), (-3, -1), False, False),
                ((4, 5, 6), (4, 5), (-2, -1), True, False),
                ((4, 5, 6), (3, 5), (-3, -1), True, False),
                ((4, 5, 6), (4, 5), (-2, -1), True, True),
                ((4, 5, 6), (3, 5), (-3, -1), True, True),
                ((4, 5, 6), (4, 5), (-2, -1), False, True),
                ((4, 5, 6), (3, 5), (-3, -1), False, True),
                )

        outputs = (
                ((4, 5), (4, 5)),
                ((5, 4), (5, 4)),
                ((5, 4), (5, 4)),
                ((3, 4), (5, 4)),
                ((4, 3), (4, 5)),
                ((4, 5), (4, 3)),
                ((5, 4), (3, 4)),
                ((4, 4, 5), (4, 4, 5)),
                ((4, 5, 6), (4, 5, 6)),
                ((3, 5, 5), (3, 5, 5)),
                ((4, 4, 5), (4, 4, 5)),
                ((3, 5, 5), (3, 5, 5)),
                ((4, 4, 3), (4, 4, 5)),
                ((3, 5, 3), (3, 5, 5)),
                ((4, 4, 5), (4, 4, 3)),
                ((3, 5, 5), (3, 5, 3)),
                )

        self._check_table_lengths(inputs, outputs)
        for _input, output in zip(inputs, outputs):
            shape, s, axes, inverse, real = _input
            a = numpy.empty(shape)

            self.assertEqual(
                    utils._compute_array_shapes(a, s, axes, inverse, real),
                    output, msg='input: %r' % (_input,))

    def test_compute_array_shapes_invalid_axes(self):

        a = numpy.zeros((3, 4))
        s = (3, 4)
        test_axes = ((1, 2, 3),)

        for each_axes in test_axes:

            args = (a, s, each_axes, False, False)
            self.assertRaisesRegex(IndexError, 'Invalid axes',
                    utils._compute_array_shapes, *args)

    def _call_cook_nd_args(self, arg_tuple):
        '''Call ``utils._cook_nd_args`` from a ``(a.shape, s, axes, invreal)``
        tuple.

        A zero array of shape ``a.shape`` is passed as ``a``. Any of ``s``,
        ``axes`` or ``invreal`` that is None is omitted from the call.
        '''
        a = numpy.zeros(arg_tuple[0])
        args = ('s', 'axes', 'invreal')
        arg_dict = {'a': a}
        for arg_name, arg in zip(args, arg_tuple[1:]):
            if arg is not None:
                arg_dict[arg_name] = arg

        return utils._cook_nd_args(**arg_dict)

    def test_cook_nd_args_normal(self):
        # inputs are (a.shape, s, axes, invreal)
        # None corresponds to no argument
        inputs = (
                ((2, 3), None, (-1,), False),
                ((2, 3), (5, 6), (-2, -1), False),
                ((2, 3), (5, 6), (-1, -2), False),
                ((2, 3), None, (-1, -2), False),
                ((2, 3, 5), (5, 6), (-1, -2), False),
                ((2, 3, 5), (5, 6), None, False),
                ((2, 3, 5), None, (-1, -2), False),
                ((2, 3, 5), None, (-1, -3), False))

        outputs = (
                ((3,), (-1,)),
                ((5, 6), (-2, -1)),
                ((5, 6), (-1, -2)),
                ((3, 2), (-1, -2)),
                ((5, 6), (-1, -2)),
                ((5, 6), (-2, -1)),
                ((5, 3), (-1, -2)),
                ((5, 2), (-1, -3))
                )

        self._check_table_lengths(inputs, outputs)
        for each_input, each_output in zip(inputs, outputs):
            self.assertEqual(self._call_cook_nd_args(each_input),
                    each_output, msg='input: %r' % (each_input,))

    def test_cook_nd_args_invreal(self):

        # inputs are (a.shape, s, axes, invreal)
        # None corresponds to no argument
        inputs = (
                ((2, 3), None, (-1,), True),
                ((2, 3), (5, 6), (-2, -1), True),
                ((2, 3), (5, 6), (-1, -2), True),
                ((2, 3), None, (-1, -2), True),
                ((2, 3, 5), (5, 6), (-1, -2), True),
                ((2, 3, 5), (5, 6), None, True),
                ((2, 3, 5), None, (-1, -2), True),
                ((2, 3, 5), None, (-1, -3), True))

        outputs = (
                ((4,), (-1,)),
                ((5, 6), (-2, -1)),
                ((5, 6), (-1, -2)),
                ((3, 2), (-1, -2)),
                ((5, 6), (-1, -2)),
                ((5, 6), (-2, -1)),
                ((5, 4), (-1, -2)),
                ((5, 2), (-1, -3))
                )

        self._check_table_lengths(inputs, outputs)
        for each_input, each_output in zip(inputs, outputs):
            self.assertEqual(self._call_cook_nd_args(each_input),
                    each_output, msg='input: %r' % (each_input,))


    def test_cook_nd_args_invalid_inputs(self):
        # inputs are (a.shape, s, axes, invreal)
        # None corresponds to no argument
        inputs = (
                ((2, 3), (1,), (-1, -2), None),
                ((2, 3), (2, 3, 4), (-3, -2, -1), None),
                )

        # all the inputs should yield an error
        for each_input in inputs:
            self.assertRaisesRegex(ValueError, 'Shape error',
                    self._call_cook_nd_args, each_input)

# Every TestCase class defined in this module, handed to run_test_suites.
test_cases = (
    BuildersTestFFTWWrapper, BuildersTestUtilities,
    BuildersTestFFT, BuildersTestIFFT,
    BuildersTestRFFT, BuildersTestIRFFT,
    BuildersTestFFT2, BuildersTestIFFT2,
    BuildersTestRFFT2, BuildersTestIRFFT2,
    BuildersTestFFTN, BuildersTestIFFTN,
    BuildersTestRFFTN, BuildersTestIRFFTN,
)

# To run only a subset, set test_set to a {class name: [test names]}
# mapping, for example:
#     test_set = {'BuildersTestRFFTN': ['test_dtype_coercian']}
# None presumably selects every test in test_cases (see run_test_suites).
test_set = None


if __name__ == '__main__':

    run_test_suites(test_cases, test_set)
back to top