# https://github.com/mikgroup/sigpy
# Raw File
# Tip revision: bfd564bc9942f06dab6288923a2f9e7516237bdd authored by Jonathan Martin on 16 October 2023, 22:09:11 UTC
# bump version 0.1.25 -> 0.1.26
# Tip revision: bfd564b
# conv.py
# -*- coding: utf-8 -*-
"""Convolution functions with multi-dimensional and multi-channel support.

"""
import numpy as np
import scipy.signal as signal

from sigpy import backend, config, util

__all__ = ["convolve", "convolve_data_adjoint", "convolve_filter_adjoint"]


def convolve(data, filt, mode="full", strides=None, multi_channel=False):
    r"""Convolve data with a filter, with batch, stride and channel support.

    The signal processing definition of convolution is used. When the array
    module of ``data`` is NumPy the computation is delegated to
    ``_convolve``; otherwise the cuDNN implementation is used.
    Note that the cuDNN version only supports inputs with D=1, 2 or 3.

    Args:
        data (array): data array of shape:
            :math:`[..., m_1, ..., m_D]` if multi_channel is False,
            :math:`[..., c_i, m_1, ..., m_D]` otherwise.
        filt (array): filter array of shape:
            :math:`[n_1, ..., n_D]` if multi_channel is False
            :math:`[c_o, c_i, n_1, ..., n_D]` otherwise.
        mode (str): {'full', 'valid'}.
        strides (None or tuple of ints): convolution strides of length D.
        multi_channel (bool): specify if input/output has multiple channels.

    Returns:
        array: output array of shape:
            :math:`[..., p_1, ..., p_D]` if multi_channel is False,
            :math:`[..., c_o, p_1, ..., p_D]` otherwise.

    Raises:
        RuntimeError: if ``data`` does not live in NumPy and
            ``config.cudnn_enabled`` is false.

    """
    options = {
        "mode": mode,
        "strides": strides,
        "multi_channel": multi_channel,
    }

    if backend.get_array_module(data) == np:
        return _convolve(data, filt, **options)
    else:  # pragma: no cover
        if not config.cudnn_enabled:
            raise RuntimeError(
                "cudnn must be installed to perform convolution on GPU."
            )

        if np.issubdtype(data.dtype, np.floating):
            return _convolve_cuda(data, filt, **options)

        # Non-floating dtypes are routed through the real/imaginary
        # decomposition helper.
        return _complex(_convolve_cuda, data, filt, **options)


def convolve_data_adjoint(
    output, filt, data_shape, mode="full", strides=None, multi_channel=False
):
    """Apply the adjoint of convolution with respect to the data.

    Note that the cuDNN version only supports inputs with D=1, 2 or 3.

    Args:
        output (array): output array of shape
            :math:`[..., p_1, ..., p_D]` if multi_channel is False,
            :math:`[..., c_o, p_1, ..., p_D]` otherwise.
        filt (array): filter array of shape
            :math:`[n_1, ..., n_D]` if multi_channel is False
            :math:`[c_o, c_i, n_1, ..., n_D]` otherwise.
        data_shape (tuple of ints): shape of the data array to return;
            converted to a tuple before use.
        mode (str): {'full', 'valid'}.
        strides (None or tuple of ints): convolution strides of length D.
        multi_channel (bool): specify if data/output has multiple channels.

    Returns:
        array: data array of shape
            :math:`[..., m_1, ..., m_D]` if multi_channel is False,
            :math:`[..., c_i, m_1, ..., m_D]` otherwise.

    Raises:
        RuntimeError: if ``output`` does not live in NumPy and
            ``config.cudnn_enabled`` is false.

    """
    data_shape = tuple(data_shape)
    options = {
        "mode": mode,
        "strides": strides,
        "multi_channel": multi_channel,
    }

    if backend.get_array_module(output) == np:
        return _convolve_data_adjoint(output, filt, data_shape, **options)
    else:  # pragma: no cover
        if not config.cudnn_enabled:
            raise RuntimeError(
                "cudnn must be installed to perform convolution on GPU."
            )

        if np.issubdtype(output.dtype, np.floating):
            return _convolve_data_adjoint_cuda(
                output, filt, data_shape, **options
            )

        # Non-floating dtypes: the filter is conjugated before being
        # handed to the real/imaginary decomposition helper.
        return _complex(
            _convolve_data_adjoint_cuda,
            output,
            filt.conj(),
            data_shape,
            **options,
        )


def convolve_filter_adjoint(
    output, data, filt_shape, mode="full", strides=None, multi_channel=False
):
    """Apply the adjoint of convolution with respect to the filter.

    Note that the cuDNN version only supports inputs with D=1, 2 or 3.

    Args:
        output (array): output array of shape:
            :math:`[..., p_1, ..., p_D]` if multi_channel is False,
            :math:`[..., c_o, p_1, ..., p_D]` otherwise.
        data (array): data array of shape:
            :math:`[..., m_1, ..., m_D]` if multi_channel is False,
            :math:`[..., c_i, m_1, ..., m_D]` otherwise.
        filt_shape (tuple of ints): shape of the filter array to return;
            converted to a tuple before use.
        mode (str): {'full', 'valid'}.
        strides (None or tuple of ints): convolution strides of length D.
        multi_channel (bool): specify if input/output has multiple channels.

    Returns:
        array: filter array of shape:
            :math:`[n_1, ..., n_D]` if multi_channel is False
            :math:`[c_o, c_i, n_1, ..., n_D]` otherwise.

    Raises:
        RuntimeError: if ``data`` does not live in NumPy and
            ``config.cudnn_enabled`` is false.

    """
    filt_shape = tuple(filt_shape)
    options = {
        "mode": mode,
        "strides": strides,
        "multi_channel": multi_channel,
    }

    # The backend is chosen from ``data``; the dtype test below looks at
    # ``output``.
    if backend.get_array_module(data) == np:
        return _convolve_filter_adjoint(output, data, filt_shape, **options)
    else:  # pragma: no cover
        if not config.cudnn_enabled:
            raise RuntimeError(
                "cudnn must be installed to perform convolution on GPU."
            )

        if np.issubdtype(output.dtype, np.floating):
            return _convolve_filter_adjoint_cuda(
                output, data, filt_shape, **options
            )

        # Non-floating dtypes: the data is conjugated before being handed
        # to the real/imaginary decomposition helper.
        return _complex(
            _convolve_filter_adjoint_cuda,
            output,
            data.conj(),
            filt_shape,
            **options,
        )


def _get_convolve_params(data_shape, filt_shape, mode, strides, multi_channel):
    """Validate convolution arguments and derive the shape parameters.

    Args:
        data_shape (tuple of ints): data shape, ``b + m`` if multi_channel
            is False, ``b + (c_i,) + m`` otherwise.
        filt_shape (tuple of ints): filter shape, ``n`` if multi_channel
            is False, ``(c_o, c_i) + n`` otherwise.
        mode (str): {'full', 'valid'}.
        strides (None or tuple of ints): strides of length D; None means
            a stride of 1 along every axis.
        multi_channel (bool): whether channel axes are present.

    Returns:
        tuple: ``(D, b, B, m, n, s, c_i, c_o, p)`` where D is the number of
        convolved axes, b the leading batch shape, B the product of b (as
        computed by ``util.prod``), m / n the data / filter sizes along the
        convolved axes, s the strides, c_i / c_o the input / output channel
        counts (both 1 when multi_channel is False) and p the strided
        output size along the convolved axes.

    Raises:
        ValueError: on a channel mismatch between data and filter, strides
            of the wrong length, an unknown mode, or (in valid mode) when
            neither the data nor the filter is at least as large as the
            other along every axis.

    """
    D = len(filt_shape) - 2 * multi_channel
    m = tuple(data_shape[-D:])
    n = tuple(filt_shape[-D:])
    b = tuple(data_shape[: -D - multi_channel])
    B = util.prod(b)

    if multi_channel:
        if filt_shape[-D - 1] != data_shape[-D - 1]:
            raise ValueError(
                "Data channel mismatch, "
                "got {} from data and {} from filt.".format(
                    data_shape[-D - 1], filt_shape[-D - 1]
                )
            )

        c_i = filt_shape[-D - 1]
        c_o = filt_shape[-D - 2]
    else:
        c_i = 1
        c_o = 1

    if strides is None:
        s = (1,) * D
    else:
        if len(strides) != D:
            raise ValueError("Strides must have length {}.".format(D))

        s = tuple(strides)

    if mode == "full":
        # ceil((m_d + n_d - 1) / s_d): dense full size, then strided.
        p = tuple(
            (m_d + n_d - 1 + s_d - 1) // s_d for m_d, n_d, s_d in zip(m, n, s)
        )
    elif mode == "valid":
        # Valid mode is defined as soon as one operand covers the other in
        # every axis; axes of equal size satisfy both directions.
        data_covers_filt = all(m_d >= n_d for m_d, n_d in zip(m, n))
        filt_covers_data = all(n_d >= m_d for m_d, n_d in zip(m, n))
        if not (data_covers_filt or filt_covers_data):
            raise ValueError(
                "In valid mode, either data or filter must be "
                "at least as large as the other in every axis."
            )

        # The dense valid size is |m_d - n_d| + 1 regardless of which
        # operand is larger (the CPU adjoints use max - min + 1); using
        # m_d - n_d alone gives a non-positive size for a larger filter.
        p = tuple(
            (abs(m_d - n_d) + 1 + s_d - 1) // s_d
            for m_d, n_d, s_d in zip(m, n, s)
        )
    else:
        raise ValueError("Invalid mode, got {}".format(mode))

    return D, b, B, m, n, s, c_i, c_o, p


def _convolve(data, filt, mode="full", strides=None, multi_channel=False):
    """NumPy/SciPy implementation of strided multi-channel convolution.

    Data and filter are reshaped to ``(B, c_i) + m`` and ``(c_o, c_i) + n``.
    For every batch entry and output channel, the ``scipy.signal.convolve``
    results of all input channels are subsampled by the strides and summed.
    The accumulator uses the dtype of ``data``.

    Args:
        data (array): data array.
        filt (array): filter array.
        mode (str): {'full', 'valid'}.
        strides (None or tuple of ints): convolution strides of length D.
        multi_channel (bool): whether channel axes are present.

    Returns:
        array: output of shape ``b + p``, or ``b + (c_o,) + p`` when
        multi_channel is True.

    """
    D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
        data.shape, filt.shape, mode, strides, multi_channel
    )

    # Canonical (batch, channel, spatial...) layouts.
    batched = data.reshape((B, c_i) + m)
    kernels = filt.reshape((c_o, c_i) + n)
    result = np.zeros((B, c_o) + p, dtype=batched.dtype)
    stride_slc = tuple(slice(None, None, step) for step in s)

    for batch_idx, channels in enumerate(batched):
        for out_idx, out_kernels in enumerate(kernels):
            # Sum the contribution of every input channel.
            for channel, kernel in zip(channels, out_kernels):
                dense = signal.convolve(channel, kernel, mode=mode)
                result[batch_idx, out_idx] += dense[stride_slc]

    # Restore the caller's batch layout.
    out_shape = b + (c_o,) + p if multi_channel else b + p
    return result.reshape(out_shape)


def _convolve_data_adjoint(
    output, filt, data_shape, mode="full", strides=None, multi_channel=False
):
    """NumPy/SciPy adjoint of ``_convolve`` with respect to the data.

    Each strided output slice is scattered back onto a zero-filled dense
    grid and correlated with the filter via ``scipy.signal.correlate``;
    contributions from all output channels are accumulated per input
    channel.

    Args:
        output (array): output array, reshaped to ``(B, c_o) + p``.
        filt (array): filter array, reshaped to ``(c_o, c_i) + n``.
        data_shape (tuple of ints): shape of the returned data array.
        mode (str): {'full', 'valid'}, the mode of the forward convolution.
        strides (None or tuple of ints): convolution strides of length D.
        multi_channel (bool): whether channel axes are present.

    Returns:
        array: data array of shape ``data_shape`` with the dtype of
        ``output``.

    """
    D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
        data_shape, filt.shape, mode, strides, multi_channel
    )

    # Normalize shapes.
    output = output.reshape((B, c_o) + p)
    filt = filt.reshape((c_o, c_i) + n)
    data = np.zeros((B, c_i) + m, dtype=output.dtype)
    slc = tuple(slice(None, None, s_d) for s_d in s)

    # Dense (stride-1) output buffer and the correlation mode that maps it
    # back to the data size m.
    if mode == "full":
        output_kj = np.zeros(
            [m_d + n_d - 1 for m_d, n_d in zip(m, n)], dtype=output.dtype
        )
        adjoint_mode = "valid"
    elif mode == "valid":
        output_kj = np.zeros(
            [max(m_d, n_d) - min(m_d, n_d) + 1 for m_d, n_d in zip(m, n)],
            dtype=output.dtype,
        )
        if all(m_d >= n_d for m_d, n_d in zip(m, n)):
            adjoint_mode = "full"
        else:
            adjoint_mode = "valid"

    for k in range(B):
        for j in range(c_o):
            # The zero-filled dense output depends only on (k, j), so fill
            # it once instead of once per input channel. Only the strided
            # positions are ever written, so the zeros in between persist.
            output_kj[slc] = output[k, j]
            for i in range(c_i):
                data[k, i] += signal.correlate(
                    output_kj, filt[j, i], mode=adjoint_mode
                )

    # Reshape.
    return data.reshape(data_shape)


def _convolve_filter_adjoint(
    output, data, filt_shape, mode="full", strides=None, multi_channel=False
):
    """NumPy/SciPy adjoint of ``_convolve`` with respect to the filter.

    Each strided output slice is scattered back onto a zero-filled dense
    grid and correlated with the data via ``scipy.signal.correlate``;
    contributions from all batch entries are accumulated per
    (output channel, input channel) pair.

    Args:
        output (array): output array, reshaped to ``(B, c_o) + p``.
        data (array): data array, reshaped to ``(B, c_i) + m``.
        filt_shape (tuple of ints): shape of the returned filter array.
        mode (str): {'full', 'valid'}, the mode of the forward convolution.
        strides (None or tuple of ints): convolution strides of length D.
        multi_channel (bool): whether channel axes are present.

    Returns:
        array: filter array of shape ``filt_shape`` with the dtype of
        ``output``.

    """
    D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
        data.shape, filt_shape, mode, strides, multi_channel
    )

    # Normalize shapes.
    data = data.reshape((B, c_i) + m)
    output = output.reshape((B, c_o) + p)
    slc = tuple(slice(None, None, s_d) for s_d in s)

    # Dense (stride-1) output buffer and the correlation mode that maps it
    # back to the filter size n.
    if mode == "full":
        output_kj = np.zeros(
            [m_d + n_d - 1 for m_d, n_d in zip(m, n)], dtype=output.dtype
        )
        adjoint_mode = "valid"
    elif mode == "valid":
        output_kj = np.zeros(
            [max(m_d, n_d) - min(m_d, n_d) + 1 for m_d, n_d in zip(m, n)],
            dtype=output.dtype,
        )
        if all(m_d >= n_d for m_d, n_d in zip(m, n)):
            adjoint_mode = "valid"
        else:
            adjoint_mode = "full"

    filt = np.zeros((c_o, c_i) + n, dtype=output.dtype)
    for k in range(B):
        for j in range(c_o):
            # The zero-filled dense output depends only on (k, j), so fill
            # it once instead of once per input channel. Only the strided
            # positions are ever written, so the zeros in between persist.
            output_kj[slc] = output[k, j]
            for i in range(c_i):
                filt[j, i] += signal.correlate(
                    output_kj, data[k, i], mode=adjoint_mode
                )

    # Reshape.
    return filt.reshape(filt_shape)


if config.cudnn_enabled:  # pragma: no cover
    from cupy import cudnn

    def _complex(func, data1, data2, *kargs, **kwargs):
        """Evaluate ``func`` on complex inputs using four real-valued calls.

        The two leading arguments are split into real and imaginary parts
        and recombined following the complex product rule
        ``(a + jb)(c + jd) = (ac - bd) + j(ad + bc)``. This is only
        meaningful when ``func`` is linear in each of those two arguments.

        Args:
            func (function): function taking two arrays followed by
                ``*kargs`` and ``**kwargs``.
            data1 (array): first array argument.
            data2 (array): second array argument.
            *kargs: extra positional arguments forwarded to ``func``.
            **kwargs: extra keyword arguments forwarded to ``func``.

        Returns:
            array: combined result, cast to the dtype of ``data1``.

        """
        xp = backend.get_array_module(data1)
        re1, im1 = xp.real(data1), xp.imag(data1)
        re2, im2 = xp.real(data2), xp.imag(data2)

        # Real part: re1 * re2 - im1 * im2.
        real_part = func(re1, re2, *kargs, **kwargs)
        real_part -= func(im1, im2, *kargs, **kwargs)

        # Imaginary part: im1 * re2 + re1 * im2.
        imag_part = func(im1, re2, *kargs, **kwargs)
        imag_part += func(re1, im2, *kargs, **kwargs)

        combined = real_part + 1j * imag_part
        return combined.astype(data1.dtype, copy=False)

    def _convolve_cuda(
        data, filt, mode="full", strides=None, multi_channel=False
    ):
        """cuDNN-backed counterpart of ``_convolve``.

        One-dimensional problems are lifted to two dimensions by appending
        a singleton axis (with stride 1) and squeezing it from the result.
        The filter is flipped along its D trailing axes before being
        passed to ``cudnn.convolution_forward``.
        NOTE(review): the flip presumably compensates for cuDNN computing
        cross-correlation -- confirm against the CuPy documentation.

        Args:
            data (array): data array.
            filt (array): filter array.
            mode (str): {'full', 'valid'}.
            strides (None or tuple of ints): convolution strides of length D.
            multi_channel (bool): whether channel axes are present.

        Returns:
            array: output of shape ``b + p``, or ``b + (c_o,) + p`` when
            multi_channel is True, with the dtype of ``data``.

        Raises:
            ValueError: if D is greater than 3.

        """
        xp = backend.get_array_module(data)

        D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
            data.shape, filt.shape, mode, strides, multi_channel
        )

        if D > 3:
            raise ValueError(
                f"cuDNN convolution only supports 1, 2, or 3D, got {D}."
            )

        if D == 1:
            # Lift the 1D problem to 2D with a trailing singleton axis.
            lifted = _convolve_cuda(
                xp.expand_dims(data, -1),
                xp.expand_dims(filt, -1),
                mode=mode,
                strides=None if strides is None else list(strides) + [1],
                multi_channel=multi_channel,
            )
            return lifted.squeeze(-1)

        # Full mode pads n_d - 1 along each axis; valid mode uses no padding.
        pads = tuple(n_d - 1 for n_d in n) if mode == "full" else (0,) * D

        batched = data.reshape((B, c_i) + m)
        kernels = filt.reshape((c_o, c_i) + n)
        result = xp.empty((B, c_o) + p, dtype=batched.dtype)
        kernels = util.flip(kernels, axes=range(-D, 0))
        cudnn.convolution_forward(
            batched,
            kernels,
            None,
            result,
            pads,
            s,
            (1,) * D,  # dilations
            1,  # groups
            auto_tune=True,
            tensor_core="auto",
        )

        # Restore the caller's batch layout.
        out_shape = b + (c_o,) + p if multi_channel else b + p
        return result.reshape(out_shape)

    def _convolve_data_adjoint_cuda(
        output,
        filt,
        data_shape,
        mode="full",
        strides=None,
        multi_channel=False,
    ):
        """cuDNN-backed counterpart of ``_convolve_data_adjoint``.

        One-dimensional problems are lifted to two dimensions by appending
        a singleton axis (with stride 1) and squeezing it from the result.
        The filter is flipped along its D trailing axes before being
        passed to ``cudnn.convolution_backward_data``.

        Args:
            output (array): output array.
            filt (array): filter array.
            data_shape (tuple or list of ints): shape of the returned data.
            mode (str): {'full', 'valid'}.
            strides (None or tuple of ints): convolution strides of length D.
            multi_channel (bool): whether channel axes are present.

        Returns:
            array: data array of shape ``data_shape`` with the dtype of
            ``output``.

        Raises:
            ValueError: if D is greater than 3.

        """
        xp = backend.get_array_module(output)

        D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
            data_shape, filt.shape, mode, strides, multi_channel
        )

        if D > 3:
            raise ValueError(
                f"cuDNN convolution only supports 1, 2 or 3D, got {D}."
            )

        if D == 1:
            # Lift the 1D problem to 2D with a trailing singleton axis.
            lifted = _convolve_data_adjoint_cuda(
                xp.expand_dims(output, -1),
                xp.expand_dims(filt, -1),
                list(data_shape) + [1],
                mode=mode,
                strides=None if strides is None else list(strides) + [1],
                multi_channel=multi_channel,
            )
            return lifted.squeeze(-1)

        # Full mode pads n_d - 1 along each axis; valid mode uses no padding.
        pads = tuple(n_d - 1 for n_d in n) if mode == "full" else (0,) * D

        grads = output.reshape((B, c_o) + p)
        kernels = filt.reshape((c_o, c_i) + n)
        result = xp.empty((B, c_i) + m, dtype=grads.dtype)
        kernels = util.flip(kernels, axes=range(-D, 0))
        cudnn.convolution_backward_data(
            kernels,
            grads,
            None,
            result,
            pads,
            s,
            (1,) * D,  # dilations
            1,  # groups
            deterministic=False,
            auto_tune=True,
            tensor_core="auto",
        )

        # Restore the caller's layout.
        return result.reshape(data_shape)

    def _convolve_filter_adjoint_cuda(
        output,
        data,
        filt_shape,
        mode="full",
        strides=None,
        multi_channel=False,
    ):
        """cuDNN-backed counterpart of ``_convolve_filter_adjoint``.

        One-dimensional problems are lifted to two dimensions by appending
        a singleton axis (with stride 1) and squeezing it from the result.
        The filter produced by ``cudnn.convolution_backward_filter`` is
        flipped along its D trailing axes before being returned.

        Args:
            output (array): output array.
            data (array): data array.
            filt_shape (tuple or list of ints): shape of the returned filter.
            mode (str): {'full', 'valid'}.
            strides (None or tuple of ints): convolution strides of length D.
            multi_channel (bool): whether channel axes are present.

        Returns:
            array: filter array of shape ``filt_shape`` with the dtype of
            ``output``.

        Raises:
            ValueError: if D is greater than 3.

        """
        xp = backend.get_array_module(data)

        D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
            data.shape, filt_shape, mode, strides, multi_channel
        )

        if D > 3:
            raise ValueError(
                f"cuDNN convolution only supports 1, 2 or 3D, got {D}."
            )

        if D == 1:
            # Lift the 1D problem to 2D with a trailing singleton axis.
            lifted = _convolve_filter_adjoint_cuda(
                xp.expand_dims(output, -1),
                xp.expand_dims(data, -1),
                list(filt_shape) + [1],
                mode=mode,
                strides=None if strides is None else list(strides) + [1],
                multi_channel=multi_channel,
            )
            return lifted.squeeze(-1)

        # Full mode pads n_d - 1 along each axis; valid mode uses no padding.
        pads = tuple(n_d - 1 for n_d in n) if mode == "full" else (0,) * D

        batched = data.reshape((B, c_i) + m)
        grads = output.reshape((B, c_o) + p)
        kernels = xp.empty((c_o, c_i) + n, dtype=grads.dtype)
        cudnn.convolution_backward_filter(
            batched,
            grads,
            kernels,
            pads,
            s,
            (1,) * D,  # dilations
            1,  # groups
            deterministic=False,
            auto_tune=True,
            tensor_core="auto",
        )

        # Flip the spatial axes, then restore the caller's layout.
        flipped = util.flip(kernels, axes=range(-D, 0))
        return flipped.reshape(filt_shape)
# back to top