https://github.com/mikgroup/sigpy
Raw File
Tip revision: 66c0a646f74e6c130835689a93099b9909f33d6f authored by frankong on 06 July 2018, 23:31:34 UTC
removed non-ASCII character.
Tip revision: 66c0a64
conv.py
import numpy as np
from sigpy import fft, util, config


__all__ = ['convolve', 'correlate',
           'cudnn_convolve', 'cudnn_convolve_backward_filter',
           'cudnn_convolve_backward_data']


def convolve(input1, input2, axes=None, mode='full'):
    """Multi-dimensional convolution computed through FFTs.

    Both inputs are resized to a common transform shape, multiplied in the
    Fourier domain and transformed back. Real floating-point dtypes use the
    array module's ``rfftn`` / ``irfftn``; every other dtype goes through
    ``sigpy.fft``.

    Args:
        input1 (array): Input 1.
        input2 (array): Input 2. Checked against input1 with
            ``util._check_same_dtype``.
        axes (None or array of ints): Axes to perform convolution.
        mode (str): {'full', 'valid'}

    Returns:
        array: Convolved result. Along every convolved axis, with n1 and n2
            the lengths of the two inputs, the output length is n1 + n2 - 1
            if mode='full', and max(n1, n2) - min(n1, n2) + 1 if
            mode='valid'.

    Raises:
        ValueError: If mode is neither 'full' nor 'valid'.

    """
    # The shape helper treats every mode other than 'full' as 'valid', while
    # the cropping step below only runs for 'valid'. Reject anything else so
    # a misspelled mode cannot silently return an uncropped circular result.
    if mode not in ('full', 'valid'):
        raise ValueError(
            "mode must be 'full' or 'valid', got {!r}".format(mode))

    util._check_same_dtype(input1, input2)
    axes = util._normalize_axes(axes, max(input1.ndim, input2.ndim))
    tshape1, tshape2, toshape, oshape, shift, scale = _get_convolve_shapes(
        input1.shape, input2.shape, axes, mode)
    dtype = input1.dtype
    max_ndim = len(oshape)
    device = util.get_device(input1)
    xp = device.xp

    with device:
        input1_pad = util.resize(input1, tshape1, oshift=[0] * max_ndim)
        input2_pad = util.resize(input2, tshape2, oshift=[0] * max_ndim)

        if np.issubdtype(dtype, np.floating):
            # Real input: use the half-spectrum transforms and cast the
            # result back to the input dtype.
            input1_fft = xp.fft.rfftn(input1_pad, axes=axes, norm='ortho')
            input2_fft = xp.fft.rfftn(input2_pad, axes=axes, norm='ortho')
            output = xp.fft.irfftn(input1_fft * input2_fft,
                                   [toshape[a] for a in axes],
                                   axes=axes, norm='ortho').astype(dtype)
        else:
            input1_fft = fft.fft(input1_pad, axes=axes, center=False)
            input2_fft = fft.fft(input2_pad, axes=axes, center=False)
            output = fft.ifft(input1_fft * input2_fft,
                              axes=axes, center=False)

        if mode == 'valid':
            # Crop the transform-sized result down to the valid region.
            output = util.resize(output, oshape, ishift=shift)

        # scale is the product of sqrt(transform length) over the convolved
        # axes; it compensates for the orthonormal FFT normalization.
        output *= scale

        return output


def correlate(input1, input2, axes=None, mode='full'):
    """Multi-dimensional cross-correlation.

    Implemented as a convolution of input1 with the flipped, complex
    conjugated input2.

    Args:
        input1 (array): Input 1.
        input2 (array): Input 2.
        axes (None or array of ints): Axes to perform cross-correlation.
        mode (str): {'full', 'valid'}

    Returns:
        array: Correlated result; its shape is whatever ``convolve`` returns
            for the same ``axes`` and ``mode``.

    """
    device = util.get_device(input1)
    xp = device.xp

    with device:
        # Reshape input2 to the expanded shape paired with input1's shape
        # before flipping (presumably so that `axes` index both arrays the
        # same way -- confirm against util._expand_shapes).
        expanded_shape2 = util._expand_shapes(input1.shape, input2.shape)[1]
        reversed2 = util.flip(input2.reshape(expanded_shape2), axes=axes)
        kernel = xp.conj(reversed2)

        return convolve(input1, kernel, axes=axes, mode=mode)


if config.cudnn_enabled:
    import cupy as cp
    from cupy import cudnn
    from cupy.cuda import cudnn as libcudnn

    def _get_cudnn_convolve_y_shape(x_shape, W_shape, mode):
        c_O, c_I = W_shape[:2]
        b = x_shape[0]

        if mode == 'full':
            return (b, c_O) + tuple(m + n - 1 for m, n in zip(x_shape[2:], W_shape[2:]))
        else:
            return (b, c_O) + tuple(m - n + 1 for m, n in zip(x_shape[2:], W_shape[2:]))

    def _get_cudnn_convolve_x_shape(W_shape, y_shape, mode):
        c_O, c_I = W_shape[:2]
        b = y_shape[0]

        if mode == 'full':
            return (b, c_I) + tuple(p - n + 1 for p, n in zip(y_shape[2:], W_shape[2:]))
        else:
            return (b, c_I) + tuple(p + n - 1 for p, n in zip(y_shape[2:], W_shape[2:]))

    def _get_cudnn_convolve_W_shape(x_shape, y_shape, mode):
        c_I = x_shape[1]
        c_O = y_shape[1]
        b = y_shape[0]

        if mode == 'full':
            return (c_O, c_I) + tuple(p - m + 1 for p, m in zip(y_shape[2:], x_shape[2:]))
        else:
            return (c_O, c_I) + tuple(m - p + 1 for p, m in zip(y_shape[2:], x_shape[2:]))

    def cudnn_convolve(
            x, W, mode='full', strides=None,
            workspace_size=1024 * 1024 * 1024,
            conv_mode=libcudnn.CUDNN_CONVOLUTION,
            pref=libcudnn.CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
    ):
        """Multi-channel convolution using the cuDNN forward routine.

        Complex inputs are split into real and imaginary parts and computed
        with four real-valued calls to this function.

        Args:
            x (array): Input of shape (b, c_I, m_1, m_2, ..., m_N).
            W (array): Filter of shape (c_O, c_I, n_1, n_2, ..., n_N).
            mode (str): 'full' pads each spatial axis by n_i - 1; any other
                value uses no padding.
            strides (None or tuple of ints): Strides for the convolution
                descriptor. Defaults to 1 along every spatial axis.
            workspace_size (int): Number of bytes allocated for the cuDNN
                workspace and passed as the workspace limit.
            conv_mode: cuDNN convolution mode for the convolution descriptor.
            pref: cuDNN preference used to select the forward algorithm.

        Returns:
            array: Output y of shape (b, c_O, p_1, p_2, ..., p_N).

        """
        assert x.shape[1] == W.shape[1]
        assert x.ndim == W.ndim
        util._check_same_dtype(W, x)

        ndim = W.ndim - 2
        dtype = x.dtype
        device = util.get_device(x)
        xp = device.xp

        if np.issubdtype(dtype, np.complexfloating):
            # (xr + i xi) * (Wr + i Wi) expanded into four real convolutions.
            # Every user-supplied option, including workspace_size and pref,
            # is forwarded so the real-valued calls honor the caller's
            # settings instead of falling back to the defaults.
            real_kwargs = dict(mode=mode, strides=strides,
                               workspace_size=workspace_size,
                               conv_mode=conv_mode, pref=pref)
            with device:
                xr = xp.real(x)
                xi = xp.imag(x)
                Wr = xp.real(W)
                Wi = xp.imag(W)

                yr = cudnn_convolve(xr, Wr, **real_kwargs)
                yr -= cudnn_convolve(xi, Wi, **real_kwargs)
                yi = cudnn_convolve(xr, Wi, **real_kwargs)
                yi += cudnn_convolve(xi, Wr, **real_kwargs)

                return (yr + 1j * yi).astype(dtype)

        if strides is None:
            strides = (1, ) * ndim

        if mode == 'full':
            pads = tuple(n - 1 for n in W.shape[2:])
        else:
            pads = (0, ) * ndim

        y_shape = _get_cudnn_convolve_y_shape(x.shape, W.shape, mode)

        with device:
            x = cp.ascontiguousarray(x)
            W = cp.ascontiguousarray(W)
            y = util.empty(y_shape, dtype=dtype, device=device)

            handle = cudnn.get_handle()
            x_desc = cudnn.create_tensor_descriptor(x)
            y_desc = cudnn.create_tensor_descriptor(y)
            W_desc = cudnn.create_filter_descriptor(W)

            conv_desc = cudnn.create_convolution_descriptor(
                pads, strides, dtype, mode=conv_mode)

            # One byte per element ('b'), so the buffer is workspace_size
            # bytes long.
            workspace = util.empty(workspace_size, dtype='b', device=device)
            algo = libcudnn.getConvolutionForwardAlgorithm(
                handle, x_desc.value, W_desc.value,
                conv_desc.value, y_desc.value, pref,
                workspace_size)

            # Host-side scaling factors (alpha=1, beta=0), stored in double
            # precision for double inputs and single precision otherwise.
            oz_dtype = 'd' if x.dtype == 'd' else 'f'
            one = np.array(1, dtype=oz_dtype).ctypes
            zero = np.array(0, dtype=oz_dtype).ctypes
            libcudnn.convolutionForward(
                handle, one.data, x_desc.value, x.data.ptr,
                W_desc.value, W.data.ptr, conv_desc.value,
                algo, workspace.data.ptr, workspace_size, zero.data,
                y_desc.value, y.data.ptr)

            return y

    def cudnn_convolve_backward_filter(
            x, y, mode='full', strides=None,
            workspace_size=1024 * 1024 * 1024,
            conv_mode=libcudnn.CUDNN_CONVOLUTION,
            pref=libcudnn.CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
    ):
        """Run the cuDNN backward-filter pass to compute W from x and y.

        Complex inputs are split into real and imaginary parts, with the
        imaginary part of x negated, and computed with four real-valued
        calls to this function.

        Args:
            x (array): Input of shape (b, c_I, m_1, m_2, ..., m_N).
            y (array): Output of shape (b, c_O, p_1, p_2, ..., p_N).
            mode (str): 'full' pads each spatial axis by n_i - 1; any other
                value uses no padding.
            strides (None or tuple of ints): Strides for the convolution
                descriptor. Defaults to 1 along every spatial axis.
            workspace_size (int): Number of bytes allocated for the cuDNN
                workspace and passed as the workspace limit.
            conv_mode: cuDNN convolution mode for the convolution descriptor.
            pref: cuDNN preference used to select the backward-filter
                algorithm.

        Returns:
            array: Filter W of shape (c_O, c_I, n_1, n_2, ..., n_N).

        """
        assert x.shape[0] == y.shape[0]
        assert x.ndim == y.ndim
        util._check_same_dtype(x, y)

        ndim = x.ndim - 2
        dtype = x.dtype
        device = util.get_device(x)
        xp = device.xp

        if np.issubdtype(dtype, np.complexfloating):
            # Forward workspace_size as well as the other options so the
            # real-valued calls do not silently fall back to the default
            # workspace allocation.
            real_kwargs = dict(mode=mode, strides=strides,
                               workspace_size=workspace_size,
                               conv_mode=conv_mode, pref=pref)
            with device:
                xr = xp.real(x)
                xi = -xp.imag(x)  # conjugate of x
                yr = xp.real(y)
                yi = xp.imag(y)

                Wr = cudnn_convolve_backward_filter(xr, yr, **real_kwargs)
                Wr -= cudnn_convolve_backward_filter(xi, yi, **real_kwargs)

                Wi = cudnn_convolve_backward_filter(xr, yi, **real_kwargs)
                Wi += cudnn_convolve_backward_filter(xi, yr, **real_kwargs)

                return (Wr + 1j * Wi).astype(dtype)

        W_shape = _get_cudnn_convolve_W_shape(x.shape, y.shape, mode)
        if strides is None:
            strides = (1, ) * ndim

        if mode == 'full':
            pads = tuple(n - 1 for n in W_shape[2:])
        else:
            pads = (0, ) * ndim

        with device:
            x = cp.ascontiguousarray(x)
            W = util.empty(W_shape, dtype=dtype, device=device)
            y = cp.ascontiguousarray(y)

            handle = cudnn.get_handle()
            x_desc = cudnn.create_tensor_descriptor(x)
            y_desc = cudnn.create_tensor_descriptor(y)
            W_desc = cudnn.create_filter_descriptor(W)

            conv_desc = cudnn.create_convolution_descriptor(
                pads, strides, dtype, mode=conv_mode)

            # One byte per element ('b'), so the buffer is workspace_size
            # bytes long.
            workspace = util.empty(workspace_size, dtype='b', device=device)
            algo = libcudnn.getConvolutionBackwardFilterAlgorithm(
                handle, x_desc.value, y_desc.value,
                conv_desc.value, W_desc.value, pref,
                workspace_size)

            # Host-side scaling factors (alpha=1, beta=0), stored in double
            # precision for double inputs and single precision otherwise.
            oz_dtype = 'd' if x.dtype == 'd' else 'f'
            one = np.array(1, dtype=oz_dtype).ctypes
            zero = np.array(0, dtype=oz_dtype).ctypes
            libcudnn.convolutionBackwardFilter_v3(
                handle, one.data, x_desc.value, x.data.ptr,
                y_desc.value, y.data.ptr, conv_desc.value,
                algo, workspace.data.ptr, workspace_size, zero.data,
                W_desc.value, W.data.ptr)

            return W

    def cudnn_convolve_backward_data(W, y,
                                     mode='full', strides=None,
                                     workspace_size=1024 * 1024 * 1024,
                                     conv_mode=libcudnn.CUDNN_CONVOLUTION,
                                     pref=libcudnn.CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT):
        """Run the cuDNN backward-data pass to compute x from W and y.

        Complex inputs are split into real and imaginary parts, with the
        imaginary part of W negated, and computed with four real-valued
        calls to this function.

        Args:
            W (array): Filter of shape (c_O, c_I, n_1, n_2, ..., n_N).
            y (array): Output of shape (b, c_O, p_1, p_2, ..., p_N).
            mode (str): 'full' pads each spatial axis by n_i - 1; any other
                value uses no padding.
            strides (None or tuple of ints): Strides for the convolution
                descriptor. Defaults to 1 along every spatial axis.
            workspace_size (int): Number of bytes allocated for the cuDNN
                workspace and passed as the workspace limit.
            conv_mode: cuDNN convolution mode for the convolution descriptor.
            pref: cuDNN preference used to select the backward-data
                algorithm.

        Returns:
            array: Input x of shape (b, c_I, m_1, m_2, ..., m_N).

        """
        assert W.shape[0] == y.shape[1]
        assert y.ndim == W.ndim
        util._check_same_dtype(W, y)

        ndim = W.ndim - 2
        dtype = W.dtype
        device = util.get_device(W)
        xp = device.xp

        if np.issubdtype(dtype, np.complexfloating):
            # Forward workspace_size as well as the other options so the
            # real-valued calls do not silently fall back to the default
            # workspace allocation.
            real_kwargs = dict(mode=mode, strides=strides,
                               workspace_size=workspace_size,
                               conv_mode=conv_mode, pref=pref)
            with device:
                Wr = xp.real(W)
                Wi = -xp.imag(W)  # conjugate of W
                yr = xp.real(y)
                yi = xp.imag(y)

                xr = cudnn_convolve_backward_data(Wr, yr, **real_kwargs)
                xr -= cudnn_convolve_backward_data(Wi, yi, **real_kwargs)

                xi = cudnn_convolve_backward_data(Wr, yi, **real_kwargs)
                xi += cudnn_convolve_backward_data(Wi, yr, **real_kwargs)

                return (xr + 1j * xi).astype(dtype)

        x_shape = _get_cudnn_convolve_x_shape(W.shape, y.shape, mode)
        if strides is None:
            strides = (1, ) * ndim
        if mode == 'full':
            pads = tuple(n - 1 for n in W.shape[2:])
        else:
            pads = (0, ) * ndim

        with device:
            x = util.empty(x_shape, dtype=dtype, device=device)
            W = cp.ascontiguousarray(W)
            y = cp.ascontiguousarray(y)

            handle = cudnn.get_handle()
            x_desc = cudnn.create_tensor_descriptor(x)
            y_desc = cudnn.create_tensor_descriptor(y)
            W_desc = cudnn.create_filter_descriptor(W)

            conv_desc = cudnn.create_convolution_descriptor(
                pads, strides, dtype, mode=conv_mode)

            # One byte per element ('b'), so the buffer is workspace_size
            # bytes long.
            workspace = util.empty(workspace_size, dtype='b', device=device)
            algo = libcudnn.getConvolutionBackwardDataAlgorithm(
                handle, W_desc.value, y_desc.value,
                conv_desc.value, x_desc.value, pref,
                workspace_size)

            # Host-side scaling factors (alpha=1, beta=0), stored in double
            # precision for double inputs and single precision otherwise.
            oz_dtype = 'd' if x.dtype == 'd' else 'f'
            one = np.array(1, dtype=oz_dtype).ctypes
            zero = np.array(0, dtype=oz_dtype).ctypes
            libcudnn.convolutionBackwardData_v3(
                handle, one.data, W_desc.value, W.data.ptr,
                y_desc.value, y.data.ptr, conv_desc.value,
                algo, workspace.data.ptr, workspace_size, zero.data,
                x_desc.value, x.data.ptr)

            return x


def _get_convolve_shapes(ishape1, ishape2, axes, mode):
    """Compute the shapes, crop offsets and scaling used by an FFT convolution.

    Args:
        ishape1 (sequence of ints): Shape of input 1.
        ishape2 (sequence of ints): Shape of input 2.
        axes (None or array of ints): Axes to convolve over, normalized with
            ``util._normalize_axes``.
        mode (str): 'full' selects full-convolution sizes; any other value
            selects the 'valid' sizes.

    Returns:
        tuple: ``(tshape1, tshape2, toshape, oshape, shift, scale)`` where
            tshape1 / tshape2 are the resized shapes of the two inputs,
            toshape is the transform output shape, oshape is the final
            output shape, shift is the per-axis difference between toshape
            and oshape, and scale is the product of the square roots of the
            transform lengths along the convolved axes.

    """
    shape1 = list(ishape1)
    shape2 = list(ishape2)

    ndim = max(len(shape1), len(shape2))
    axes = util._normalize_axes(axes, ndim)

    # Left-pad the shorter shape with ones so both have ndim entries.
    tshape1 = [1] * (ndim - len(shape1)) + shape1
    tshape2 = [1] * (ndim - len(shape2)) + shape2
    toshape = [max(pair) for pair in zip(tshape1, tshape2)]
    oshape = list(toshape)
    shift = [0] * ndim

    scale = 1
    for a in axes:
        n1, n2 = tshape1[a], tshape2[a]
        if mode == 'full':
            fft_len = n1 + n2 - 1
            out_len = fft_len
        else:
            fft_len = max(n1, n2)
            out_len = max(n1, n2) - min(n1, n2) + 1

        oshape[a] = out_len
        tshape1[a] = fft_len
        tshape2[a] = fft_len
        toshape[a] = fft_len
        shift[a] = fft_len - out_len
        scale *= fft_len**0.5

    return tshape1, tshape2, toshape, oshape, shift, scale
back to top