https://github.com/mikgroup/sigpy
Tip revision: bf7156311861e84d7319d22b8acabc3f73582e20 authored by Jonathan Martin on 16 October 2023, 01:33:45 UTC
Create python-package.yml
Create python-package.yml
Tip revision: bf71563
conv.py
# -*- coding: utf-8 -*-
"""Convolution functions with multi-dimension, and multi-channel support.
"""
import numpy as np
import scipy.signal as signal
from sigpy import backend, config, util
__all__ = ["convolve", "convolve_data_adjoint", "convolve_filter_adjoint"]
def convolve(data, filt, mode="full", strides=None, multi_channel=False):
r"""Convolution that supports multi-dimensional and multi-channel inputs.
This function follows the signal processing definition of convolution.
Note that the cuDNN version only supports inputs with D=1, 2 or 3.
Args:
data (array): data array of shape:
:math:`[..., m_1, ..., m_D]` if multi_channel is False,
:math:`[..., c_i, m_1, ..., m_D]` otherwise.
filt (array): filter array of shape:
:math:`[n_1, ..., n_D]` if multi_channel is False
:math:`[c_o, c_i, n_1, ..., n_D]` otherwise.
mode (str): {'full', 'valid'}.
strides (None or tuple of ints): convolution strides of length D.
multi_channel (bool): specify if input/output has multiple channels.
Returns:
array: output array of shape:
:math:`[..., p_1, ..., p_D]` if multi_channel is False,
:math:`[..., c_o, p_1, ..., p_D]` otherwise.
"""
xp = backend.get_array_module(data)
if xp == np:
output = _convolve(
data, filt, mode=mode, strides=strides, multi_channel=multi_channel
)
else: # pragma: no cover
if config.cudnn_enabled:
if np.issubdtype(data.dtype, np.floating):
output = _convolve_cuda(
data,
filt,
mode=mode,
strides=strides,
multi_channel=multi_channel,
)
else:
output = _complex(
_convolve_cuda,
data,
filt,
mode=mode,
strides=strides,
multi_channel=multi_channel,
)
else:
raise RuntimeError(
"cudnn must be installed to perform convolution on GPU."
)
return output
def convolve_data_adjoint(
output, filt, data_shape, mode="full", strides=None, multi_channel=False
):
"""Adjoint convolution operation with respect to data.
Note that the cuDNN version only supports inputs with D=1, 2 or 3.
Args:
output (array): output array of shape
:math:`[..., p_1, ..., p_D]` if multi_channel is False,
:math:`[..., c_o, p_1, ..., p_D]` otherwise.
filt (array): filter array of shape
:math:`[n_1, ..., n_D]` if multi_channel is False
:math:`[c_o, c_i, n_1, ..., n_D]` otherwise.
mode (str): {'full', 'valid'}.
strides (None or tuple of ints): convolution strides of length D.
multi_channel (bool): specify if input/output has multiple channels.
multi_channel (bool): specify if data/output has multiple channels.
mode (str): {'full', 'valid'}.
Returns:
array: data array of shape
:math:`[..., m_1, ..., m_D]` if multi_channel is False,
:math:`[..., c_i, m_1, ..., m_D]` otherwise.
"""
data_shape = tuple(data_shape)
xp = backend.get_array_module(output)
if xp == np:
data = _convolve_data_adjoint(
output,
filt,
data_shape,
mode=mode,
strides=strides,
multi_channel=multi_channel,
)
else: # pragma: no cover
if config.cudnn_enabled:
if np.issubdtype(output.dtype, np.floating):
data = _convolve_data_adjoint_cuda(
output,
filt,
data_shape,
mode=mode,
strides=strides,
multi_channel=multi_channel,
)
else:
data = _complex(
_convolve_data_adjoint_cuda,
output,
filt.conj(),
data_shape,
mode=mode,
strides=strides,
multi_channel=multi_channel,
)
else:
raise RuntimeError(
"cudnn must be installed to perform convolution on GPU."
)
return data
def convolve_filter_adjoint(
output, data, filt_shape, mode="full", strides=None, multi_channel=False
):
"""Adjoint convolution operation with respect to filter.
Note that the cuDNN version only supports inputs with D=1, 2 or 3.
Args:
output (array): output array of shape:
:math:`[..., p_1, ..., p_D]` if multi_channel is False,
:math:`[..., c_o, p_1, ..., p_D]` otherwise.
data (array): data array of shape:
:math:`[..., m_1, ..., m_D]` if multi_channel is False,
:math:`[..., c_i, m_1, ..., m_D]` otherwise.
mode (str): {'full', 'valid'}.
strides (None or tuple of ints): convolution strides of length D.
multi_channel (bool): specify if input/output has multiple channels.
Returns:
array: filter array of shape:
:math:`[n_1, ..., n_D]` if multi_channel is False
:math:`[c_o, c_i, n_1, ..., n_D]` otherwise.
"""
filt_shape = tuple(filt_shape)
xp = backend.get_array_module(data)
if xp == np:
filt = _convolve_filter_adjoint(
output,
data,
filt_shape,
mode=mode,
strides=strides,
multi_channel=multi_channel,
)
else: # pragma: no cover
if config.cudnn_enabled:
if np.issubdtype(output.dtype, np.floating):
filt = _convolve_filter_adjoint_cuda(
output,
data,
filt_shape,
mode=mode,
strides=strides,
multi_channel=multi_channel,
)
else:
filt = _complex(
_convolve_filter_adjoint_cuda,
output,
data.conj(),
filt_shape,
mode=mode,
strides=strides,
multi_channel=multi_channel,
)
else:
raise RuntimeError(
"cudnn must be installed to perform convolution on GPU."
)
return filt
def _get_convolve_params(data_shape, filt_shape, mode, strides, multi_channel):
D = len(filt_shape) - 2 * multi_channel
m = tuple(data_shape[-D:])
n = tuple(filt_shape[-D:])
b = tuple(data_shape[: -D - multi_channel])
B = util.prod(b)
if multi_channel:
if filt_shape[-D - 1] != data_shape[-D - 1]:
raise ValueError(
"Data channel mismatch, "
"got {} from data and {} from filt.".format(
data_shape[-D - 1], filt_shape[-D - 1]
)
)
c_i = filt_shape[-D - 1]
c_o = filt_shape[-D - 2]
else:
c_i = 1
c_o = 1
if strides is None:
s = (1,) * D
else:
if len(strides) != D:
raise ValueError("Strides must have length {}.".format(D))
s = tuple(strides)
if mode == "full":
p = tuple(
(m_d + n_d - 1 + s_d - 1) // s_d for m_d, n_d, s_d in zip(m, n, s)
)
elif mode == "valid":
if any(m_d >= n_d for m_d, n_d in zip(m, n)) and any(
m_d < n_d for m_d, n_d in zip(m, n)
):
raise ValueError(
"In valid mode, either data or filter must be "
"at least as large as the other in every axis."
)
p = tuple(
(m_d - n_d + 1 + s_d - 1) // s_d for m_d, n_d, s_d in zip(m, n, s)
)
else:
raise ValueError("Invalid mode, got {}".format(mode))
return D, b, B, m, n, s, c_i, c_o, p
def _convolve(data, filt, mode="full", strides=None, multi_channel=False):
D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
data.shape, filt.shape, mode, strides, multi_channel
)
# Normalize shapes.
data = data.reshape((B, c_i) + m)
filt = filt.reshape((c_o, c_i) + n)
output = np.zeros((B, c_o) + p, dtype=data.dtype)
slc = tuple(slice(None, None, s_d) for s_d in s)
for k in range(B):
for j in range(c_o):
for i in range(c_i):
output[k, j] += signal.convolve(
data[k, i], filt[j, i], mode=mode
)[slc]
# Reshape.
if multi_channel:
output = output.reshape(b + (c_o,) + p)
else:
output = output.reshape(b + p)
return output
def _convolve_data_adjoint(
output, filt, data_shape, mode="full", strides=None, multi_channel=False
):
D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
data_shape, filt.shape, mode, strides, multi_channel
)
# Normalize shapes.
output = output.reshape((B, c_o) + p)
filt = filt.reshape((c_o, c_i) + n)
data = np.zeros((B, c_i) + m, dtype=output.dtype)
slc = tuple(slice(None, None, s_d) for s_d in s)
if mode == "full":
output_kj = np.zeros(
[m_d + n_d - 1 for m_d, n_d in zip(m, n)], dtype=output.dtype
)
adjoint_mode = "valid"
elif mode == "valid":
output_kj = np.zeros(
[max(m_d, n_d) - min(m_d, n_d) + 1 for m_d, n_d in zip(m, n)],
dtype=output.dtype,
)
if all(m_d >= n_d for m_d, n_d in zip(m, n)):
adjoint_mode = "full"
else:
adjoint_mode = "valid"
for k in range(B):
for j in range(c_o):
for i in range(c_i):
output_kj[slc] = output[k, j]
data[k, i] += signal.correlate(
output_kj, filt[j, i], mode=adjoint_mode
)
# Reshape.
data = data.reshape(data_shape)
return data
def _convolve_filter_adjoint(
output, data, filt_shape, mode="full", strides=None, multi_channel=False
):
D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
data.shape, filt_shape, mode, strides, multi_channel
)
# Normalize shapes.
data = data.reshape((B, c_i) + m)
output = output.reshape((B, c_o) + p)
slc = tuple(slice(None, None, s_d) for s_d in s)
if mode == "full":
output_kj = np.zeros(
[m_d + n_d - 1 for m_d, n_d in zip(m, n)], dtype=output.dtype
)
adjoint_mode = "valid"
elif mode == "valid":
output_kj = np.zeros(
[max(m_d, n_d) - min(m_d, n_d) + 1 for m_d, n_d in zip(m, n)],
dtype=output.dtype,
)
if all(m_d >= n_d for m_d, n_d in zip(m, n)):
adjoint_mode = "valid"
else:
adjoint_mode = "full"
filt = np.zeros((c_o, c_i) + n, dtype=output.dtype)
for k in range(B):
for j in range(c_o):
for i in range(c_i):
output_kj[slc] = output[k, j]
filt[j, i] += signal.correlate(
output_kj, data[k, i], mode=adjoint_mode
)
# Reshape.
filt = filt.reshape(filt_shape)
return filt
if config.cudnn_enabled: # pragma: no cover
from cupy import cudnn
def _complex(func, data1, data2, *kargs, **kwargs):
"""Helper function to convert func to support complex floats."""
xp = backend.get_array_module(data1)
data1r = xp.real(data1)
data1i = xp.imag(data1)
data2r = xp.real(data2)
data2i = xp.imag(data2)
outputr = func(data1r, data2r, *kargs, **kwargs)
outputr -= func(data1i, data2i, *kargs, **kwargs)
outputi = func(data1i, data2r, *kargs, **kwargs)
outputi += func(data1r, data2i, *kargs, **kwargs)
output = outputr + 1j * outputi
output = output.astype(data1.dtype, copy=False)
return output
def _convolve_cuda(
data, filt, mode="full", strides=None, multi_channel=False
):
xp = backend.get_array_module(data)
D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
data.shape, filt.shape, mode, strides, multi_channel
)
if D == 1:
return _convolve_cuda(
xp.expand_dims(data, -1),
xp.expand_dims(filt, -1),
mode=mode,
strides=list(strides) + [1] if strides is not None else None,
multi_channel=multi_channel,
).squeeze(-1)
elif D > 3:
raise ValueError(
f"cuDNN convolution only supports 1, 2, or 3D, got {D}."
)
dilations = (1,) * D
groups = 1
auto_tune = True
tensor_core = "auto"
if mode == "full":
pads = tuple(n_d - 1 for n_d in n)
elif mode == "valid":
pads = (0,) * D
data = data.reshape((B, c_i) + m)
filt = filt.reshape((c_o, c_i) + n)
output = xp.empty((B, c_o) + p, dtype=data.dtype)
filt = util.flip(filt, axes=range(-D, 0))
cudnn.convolution_forward(
data,
filt,
None,
output,
pads,
s,
dilations,
groups,
auto_tune=auto_tune,
tensor_core=tensor_core,
)
# Reshape.
if multi_channel:
output = output.reshape(b + (c_o,) + p)
else:
output = output.reshape(b + p)
return output
def _convolve_data_adjoint_cuda(
output,
filt,
data_shape,
mode="full",
strides=None,
multi_channel=False,
):
xp = backend.get_array_module(output)
D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
data_shape, filt.shape, mode, strides, multi_channel
)
if D == 1:
return _convolve_data_adjoint_cuda(
xp.expand_dims(output, -1),
xp.expand_dims(filt, -1),
list(data_shape) + [1],
mode=mode,
strides=list(strides) + [1] if strides is not None else None,
multi_channel=multi_channel,
).squeeze(-1)
elif D > 3:
raise ValueError(
f"cuDNN convolution only supports 1, 2 or 3D, got {D}."
)
dilations = (1,) * D
groups = 1
auto_tune = True
tensor_core = "auto"
deterministic = False
if mode == "full":
pads = tuple(n_d - 1 for n_d in n)
elif mode == "valid":
pads = (0,) * D
output = output.reshape((B, c_o) + p)
filt = filt.reshape((c_o, c_i) + n)
data = xp.empty((B, c_i) + m, dtype=output.dtype)
filt = util.flip(filt, axes=range(-D, 0))
cudnn.convolution_backward_data(
filt,
output,
None,
data,
pads,
s,
dilations,
groups,
deterministic=deterministic,
auto_tune=auto_tune,
tensor_core=tensor_core,
)
# Reshape.
data = data.reshape(data_shape)
return data
def _convolve_filter_adjoint_cuda(
output,
data,
filt_shape,
mode="full",
strides=None,
multi_channel=False,
):
xp = backend.get_array_module(data)
D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
data.shape, filt_shape, mode, strides, multi_channel
)
if D == 1:
return _convolve_filter_adjoint_cuda(
xp.expand_dims(output, -1),
xp.expand_dims(data, -1),
list(filt_shape) + [1],
mode=mode,
strides=list(strides) + [1] if strides is not None else None,
multi_channel=multi_channel,
).squeeze(-1)
elif D > 3:
raise ValueError(
f"cuDNN convolution only supports 1, 2 or 3D, got {D}."
)
dilations = (1,) * D
groups = 1
auto_tune = True
tensor_core = "auto"
deterministic = False
if mode == "full":
pads = tuple(n_d - 1 for n_d in n)
elif mode == "valid":
pads = (0,) * D
data = data.reshape((B, c_i) + m)
output = output.reshape((B, c_o) + p)
filt = xp.empty((c_o, c_i) + n, dtype=output.dtype)
cudnn.convolution_backward_filter(
data,
output,
filt,
pads,
s,
dilations,
groups,
deterministic=deterministic,
auto_tune=auto_tune,
tensor_core=tensor_core,
)
filt = util.flip(filt, axes=range(-D, 0))
filt = filt.reshape(filt_shape)
return filt