# Revision d38084d993f6b218862f3aa0693aacc2e3b4b1b3 authored by TUNA Caglayan on 29 April 2021, 08:44:25 UTC, committed by TUNA Caglayan on 12 May 2021, 12:26:22 UTC
# 1 parent e54a8ff
# File: pytorch_backend.py
import warnings
from distutils.version import LooseVersion
try:
import torch
except ImportError as error:
message = ('Impossible to import PyTorch.\n'
'To use TensorLy with the PyTorch backend, '
'you must first install PyTorch!')
raise ImportError(message) from error
import numpy as np
from .core import Backend
class PyTorchBackend(Backend):
    """TensorLy backend whose tensors are PyTorch tensors.

    Notes
    -----
    ``solve``, ``eigh`` and ``svd`` below are the legacy implementations.
    The module-level registration code of this file registers the
    ``torch.linalg`` functions under the same names when PyTorch >= 1.8.0
    is installed.
    """
    backend_name = 'pytorch'

    @staticmethod
    def context(tensor):
        """Return the dtype, device and requires_grad flag of `tensor`.

        Returns
        -------
        dict
            Keys ``'dtype'``, ``'device'`` and ``'requires_grad'``, i.e. the
            keyword arguments accepted by :meth:`tensor`.
        """
        return {'dtype': tensor.dtype,
                'device': tensor.device,
                'requires_grad': tensor.requires_grad}

    @staticmethod
    def tensor(data, dtype=torch.float32, device='cpu', requires_grad=False):
        """Build a new PyTorch tensor from `data`.

        Parameters
        ----------
        data : array-like
            Values to store; NumPy arrays are copied first.
        dtype : torch.dtype, default is torch.float32
        device : str or torch.device, default is 'cpu'
        requires_grad : bool, default is False

        Returns
        -------
        torch.Tensor
        """
        if isinstance(data, np.ndarray):
            # NOTE(review): presumably the copy normalises the memory layout
            # (e.g. negative strides) before torch reads it -- confirm.
            data = data.copy()
        return torch.tensor(data, dtype=dtype, device=device,
                            requires_grad=requires_grad)

    @staticmethod
    def to_numpy(tensor):
        """Convert `tensor` to a NumPy array.

        PyTorch tensors are detached and moved to host memory first; NumPy
        arrays are returned as is; anything else goes through `np.asarray`.
        """
        if torch.is_tensor(tensor):
            # The previous code tested `tensor.cuda`, a bound method that is
            # always truthy, so `.cpu()` was in effect always called. Calling
            # it unconditionally keeps that behaviour and also covers tensors
            # living on non-CUDA accelerators.
            return tensor.detach().cpu().numpy()
        if isinstance(tensor, np.ndarray):
            return tensor
        return np.asarray(tensor)

    @staticmethod
    def shape(tensor):
        """Return the shape of `tensor`."""
        return tensor.shape

    @staticmethod
    def ndim(tensor):
        """Return the number of dimensions of `tensor`."""
        return tensor.dim()

    @staticmethod
    def arange(start, stop=None, step=1.0, *args, **kwargs):
        """Return evenly spaced values, NumPy-style.

        With a single bound, the range is ``[0, start)``. The bounds and the
        step are cast to ``float`` before being handed to `torch.arange`;
        extra arguments are forwarded unchanged.
        """
        if stop is None:
            # Single-argument form: the only bound given is the upper one.
            start, stop = 0., start
        return torch.arange(float(start), float(stop), float(step),
                            *args, **kwargs)

    @staticmethod
    def clip(tensor, a_min=None, a_max=None, inplace=False):
        """Limit the values of `tensor` to the interval ``[a_min, a_max]``.

        Parameters
        ----------
        tensor : torch.Tensor
        a_min : scalar or None
            Lower bound; ``None`` means no lower bound.
        a_max : scalar or None
            Upper bound; ``None`` means no upper bound.
        inplace : bool, default is False
            If True, the result is written into `tensor`.

        Returns
        -------
        torch.Tensor
            `tensor` itself when `inplace` is True, a new tensor otherwise.
        """
        if a_min is None and a_max is None:
            # Nothing to clip; torch.clamp rejects two missing bounds.
            return tensor if inplace else tensor.clone()
        # A missing bound is passed as None rather than replaced by the
        # tensor's own extremum: substituting max(tensor) for `a_max` made
        # clamp see min > max whenever `a_min` exceeded every element, and
        # the result was then filled with the old maximum instead of `a_min`.
        if inplace:
            return torch.clamp(tensor, min=a_min, max=a_max, out=tensor)
        return torch.clamp(tensor, min=a_min, max=a_max)

    @staticmethod
    def all(tensor):
        """Return a 0-dim boolean tensor, True iff every element is non-zero."""
        # Counting the non-zero entries (the previous implementation) is
        # truthy as soon as *any* element is non-zero.
        return torch.all(tensor != 0)

    def transpose(self, tensor, axes=None):
        """Permute the dimensions of `tensor`.

        Parameters
        ----------
        tensor : torch.Tensor
        axes : sequence of int, optional
            New order of the dimensions; when missing (or empty) the order of
            the dimensions is reversed.
        """
        axes = axes or list(range(self.ndim(tensor)))[::-1]
        return tensor.permute(*axes)

    @staticmethod
    def copy(tensor):
        """Return a clone of `tensor`."""
        return tensor.clone()

    @staticmethod
    def norm(tensor, order=None, axis=None):
        """Compute a norm of `tensor`.

        Parameters
        ----------
        tensor : torch.Tensor
        order : number, 'inf' or None
            ``'inf'`` gives the largest absolute value. Any other truthy
            value is passed to `torch.norm` as ``p``. A falsy value (``None``
            but also ``0``) leaves `torch.norm` on its default order.
        axis : int or None
            Dimension to reduce; ``None`` reduces over the whole tensor.
        """
        # Only the arguments actually provided are forwarded: the torch
        # functions used here do not take None for them.
        kwds = {}
        if axis is not None:
            kwds['dim'] = axis
        if order and order != 'inf':
            kwds['p'] = order

        if order == 'inf':
            res = torch.max(torch.abs(tensor), **kwds)
            if axis is not None:
                # With `dim`, torch.max returns (values, indices).
                return res[0]
            return res
        return torch.norm(tensor, **kwds)

    @staticmethod
    def mean(tensor, axis=None):
        """Mean of `tensor` over `axis`, or over all elements if `axis` is None."""
        if axis is None:
            return torch.mean(tensor)
        return torch.mean(tensor, dim=axis)

    @staticmethod
    def sum(tensor, axis=None):
        """Sum of `tensor` over `axis`, or over all elements if `axis` is None."""
        if axis is None:
            return torch.sum(tensor)
        return torch.sum(tensor, dim=axis)

    @staticmethod
    def flip(tensor, axis=None):
        """Reverse the order of the elements along `axis`.

        `axis` may be an int, a sequence of ints, or None to flip along every
        dimension.
        """
        if isinstance(axis, int):
            axis = [axis]
        if axis is None:
            return torch.flip(tensor, dims=list(range(tensor.ndim)))
        return torch.flip(tensor, dims=axis)

    @staticmethod
    def concatenate(tensors, axis=0):
        """Join `tensors` along the existing dimension `axis`."""
        return torch.cat(tensors, dim=axis)

    @staticmethod
    def argmin(input, axis=None):
        """Indices of the minimum of `input` along `axis` (see `torch.argmin`)."""
        return torch.argmin(input, dim=axis)

    @staticmethod
    def argmax(input, axis=None):
        """Indices of the maximum of `input` along `axis` (see `torch.argmax`)."""
        return torch.argmax(input, dim=axis)

    @staticmethod
    def stack(arrays, axis=0):
        """Join `arrays` along a new dimension inserted at `axis`."""
        return torch.stack(arrays, dim=axis)

    @staticmethod
    def sort(tensor, axis, descending = False):
        """Return the sorted values of `tensor` (the indices are discarded).

        Parameters
        ----------
        tensor : torch.Tensor
        axis : int or None
            Dimension to sort along; ``None`` sorts the flattened tensor.
        descending : bool, default is False
        """
        if axis is None:
            tensor = tensor.flatten()
            axis = -1
        return torch.sort(tensor, dim=axis, descending = descending).values

    @staticmethod
    def update_index(tensor, index, values):
        """Write `values` into `tensor` at `index`, in place.

        Delegates to ``tensor.index_put_``; nothing is returned.
        """
        tensor.index_put_(index, values)

    def solve(self, matrix1, matrix2):
        """Legacy only, deprecated from PyTorch 1.8.0

        Solve the linear system ``matrix1 @ x = matrix2`` with `torch.solve`.

        Notes
        -----
        Previously, this was implemented as follows::

            if self.ndim(matrix2) < 2:
                # Currently, gesv doesn't support vectors for matrix2
                # So we instead solve a least square problem...
                solution, _ = torch.gels(matrix2, matrix1)
            else:
                solution, _ = torch.gesv(matrix2, matrix1)
            return solution

        Deprecated from PyTorch 1.8.0

        NOTE(review): a 1-D `matrix2` is expanded to a column and the
        solution is returned with that extra dimension still in place --
        confirm callers expect shape ``(n, 1)`` rather than ``(n,)``.
        """
        if self.ndim(matrix2) < 2:
            # torch.solve doesn't support vectors for matrix2
            solution, _ = torch.solve(matrix2.unsqueeze(1), matrix1)
        else:
            solution, _ = torch.solve(matrix2, matrix1)
        return solution

    @staticmethod
    def eigh(tensor):
        """Legacy only, deprecated from PyTorch 1.8.0

        Eigenvalues and eigenvectors of `tensor`, computed with
        `torch.symeig`.
        """
        return torch.symeig(tensor, eigenvectors=True)

    @staticmethod
    def svd(matrix, full_matrices=False):
        """Legacy singular value decomposition based on `torch.svd`.

        Parameters
        ----------
        matrix : torch.Tensor
        full_matrices : bool, default is False
            Forwarded to `torch.svd` as ``some=not full_matrices``.

        Returns
        -------
        The result of ``torch.svd(matrix, some=..., compute_uv=True)``.

        NOTE(review): `torch.svd` and `torch.linalg.svd` may not use the same
        convention for the third factor (V vs. its conjugate transpose) --
        confirm against callers.
        """
        # torch.svd's `some` flag is the negation of NumPy's `full_matrices`.
        some = not full_matrices
        return torch.svd(matrix, some=some, compute_uv=True)
# Everything listed here is exposed straight from the torch namespace,
# without any wrapper.
for name in (
        'float64', 'float32', 'int64', 'int32', 'complex128', 'complex64',
        'is_tensor', 'ones', 'zeros', 'any', 'zeros_like', 'reshape', 'eye',
        'max', 'min', 'prod', 'abs', 'sqrt', 'sign', 'where', 'conj', 'diag',
        'finfo', 'einsum', 'log2', 'sin', 'cos'):
    PyTorchBackend.register_method(name, getattr(torch, name))

# `dot` is backed by torch.matmul.
PyTorchBackend.register_method('dot', torch.matmul)

# PyTorch 1.8.0 has a much better NumPy interface, but some users haven't
# updated yet, so both code paths are kept.
if LooseVersion(torch.__version__) >= LooseVersion('1.8.0'):
    # New PyTorch NumPy interface
    for name in ('kron', 'moveaxis'):
        PyTorchBackend.register_method(name, getattr(torch, name))
    for name in ('solve', 'qr', 'svd', 'eigh'):
        PyTorchBackend.register_method(name, getattr(torch.linalg, name))
else:
    # Old version, will be removed in the future
    warnings.warn(f'You are using an old version of PyTorch ({torch.__version__}). '
                  'We recommend upgrading to a newest one, e.g. >1.8.0.')
    PyTorchBackend.register_method('moveaxis', torch.movedim)
    PyTorchBackend.register_method('qr', torch.qr)
# Computing file changes ...