https://github.com/tensorly/tensorly
Tip revision: a26ffe09a1468b328bd0e26e432e5fb1cc334efc authored by Jean Kossaifi on 08 March 2023, 06:36:00 UTC
Merge pull request #491 from braun-steven/feature/logsumexp
Merge pull request #491 from braun-steven/feature/logsumexp
Tip revision: a26ffe0
pytorch_backend.py
import warnings
from packaging.version import Version
try:
import torch
except ImportError as error:
message = (
"Impossible to import PyTorch.\n"
"To use TensorLy with the PyTorch backend, "
"you must first install PyTorch!"
)
raise ImportError(message) from error
import numpy as np
from .core import (
Backend,
backend_types,
backend_basic_math,
backend_array,
)
# True when the installed PyTorch is at least 1.9.0. ``PyTorchBackend.lstsq``
# uses ``torch.linalg.lstsq`` when this flag is set and falls back to the
# legacy ``torch.lstsq`` otherwise.
linalg_lstsq_avail = Version(torch.__version__) >= Version("1.9.0")
class PyTorchBackend(Backend, backend_name="pytorch"):
    """TensorLy backend implemented on top of PyTorch.

    The methods below adapt NumPy-style signatures (``axis``, ``keepdims``,
    ...) to the corresponding ``torch`` calls. Functions that need no
    adaptation are registered at module level, after the class body.
    """

    @staticmethod
    def context(tensor):
        """Return the ``dtype``, ``device`` and ``requires_grad`` of ``tensor``.

        The keys of the returned dict match the keyword arguments of
        :meth:`tensor`.
        """
        return {
            "dtype": tensor.dtype,
            "device": tensor.device,
            "requires_grad": tensor.requires_grad,
        }

    @staticmethod
    def tensor(data, dtype=torch.float32, device="cpu", requires_grad=False):
        """Create a new ``torch.Tensor`` from ``data``.

        Parameters
        ----------
        data : array-like
            Values of the tensor. NumPy arrays are copied first.
        dtype : torch.dtype, default is torch.float32
        device : str or torch.device, default is "cpu"
        requires_grad : bool, default is False
        """
        if isinstance(data, np.ndarray):
            # NOTE(review): presumably guards against array layouts that
            # torch cannot ingest directly (e.g. negative strides) -- confirm.
            data = data.copy()
        return torch.tensor(
            data, dtype=dtype, device=device, requires_grad=requires_grad
        )

    @staticmethod
    def to_numpy(tensor):
        """Convert ``tensor`` to a NumPy array.

        Torch tensors are detached from the autograd graph and moved to the
        CPU before conversion. NumPy arrays are returned unchanged and any
        other input goes through ``np.asarray``.
        """
        if torch.is_tensor(tensor):
            # The previous check ``if tensor.cuda:`` tested a bound method,
            # which is always truthy, so ``.cpu()`` was in fact always called.
            # Both ``detach()`` and ``cpu()`` are cheap when nothing has to be
            # done, so they are simply applied unconditionally.
            return tensor.detach().cpu().numpy()
        elif isinstance(tensor, np.ndarray):
            return tensor
        else:
            return np.asarray(tensor)

    @staticmethod
    def shape(tensor):
        """Return the shape of ``tensor`` as a plain tuple."""
        return tuple(tensor.shape)

    @staticmethod
    def ndim(tensor):
        """Return the number of dimensions of ``tensor``."""
        return tensor.dim()

    @staticmethod
    def arange(start, stop=None, step=1.0, *args, **kwargs):
        """Return evenly spaced values, like ``numpy.arange``.

        With a single argument the range is ``[0, start)``. All bounds are
        converted to ``float`` before being handed to ``torch.arange``; extra
        arguments are forwarded unchanged.
        """
        if stop is None:
            # Positional form, consistent with the two-bound branch below
            # (mixing keyword bounds with ``*args`` is fragile).
            return torch.arange(0.0, float(start), float(step), *args, **kwargs)
        else:
            return torch.arange(float(start), float(stop), float(step), *args, **kwargs)

    @staticmethod
    def clip(tensor, a_min=None, a_max=None, inplace=False):
        """Clamp the values of ``tensor`` to ``[a_min, a_max]``.

        When ``inplace`` is True the result is written into ``tensor`` itself
        (through the ``out=`` argument of ``torch.clip``).
        """
        if inplace:
            return torch.clip(tensor, a_min, a_max, out=tensor)
        else:
            return torch.clip(tensor, a_min, a_max)

    @staticmethod
    def all(tensor):
        """Return a 0-d boolean tensor: True iff every element is non-zero.

        The previous implementation returned ``torch.sum(tensor != 0)``, i.e.
        the *number* of non-zero elements, which is truthy as soon as a single
        element is non-zero ("any" semantics rather than "all").
        """
        return torch.all(tensor != 0)

    def transpose(self, tensor, axes=None):
        """Permute the dimensions of ``tensor``.

        When ``axes`` is None (or empty) the order of the dimensions is
        reversed.
        """
        axes = axes or list(range(self.ndim(tensor)))[::-1]
        return tensor.permute(*axes)

    @staticmethod
    def copy(tensor):
        """Return a copy of ``tensor`` (``tensor.clone()``)."""
        return tensor.clone()

    @staticmethod
    def norm(tensor, order=None, axis=None):
        """Compute a norm of ``tensor``.

        Parameters
        ----------
        tensor : torch.Tensor
        order : int, float or "inf", optional
            ``"inf"`` gives the maximum absolute value. Any other truthy value
            is forwarded to ``torch.norm`` as ``p``. A falsy value (``None``
            but also ``0``) leaves ``p`` at the ``torch.norm`` default.
        axis : int or tuple of int, optional
            Dimension(s) to reduce; the whole tensor is reduced when None.
        """
        # torch does not accept ``None`` for these keyword arguments, so only
        # the ones that were actually provided are passed on.
        kwds = {}
        if axis is not None:
            kwds["dim"] = axis
        if order and order != "inf":
            kwds["p"] = order

        if order == "inf":
            res = torch.max(torch.abs(tensor), **kwds)
            if axis is not None:
                return res[0]  # ignore indices output
            return res
        return torch.norm(tensor, **kwds)

    @staticmethod
    def dot(a, b):
        """Dot product of ``a`` and ``b``.

        * both operands with more than 2 dimensions: contract the last axis
          of ``a`` with the second-to-last axis of ``b``;
        * either operand 0-dimensional: plain multiplication;
        * otherwise: ``torch.matmul``.
        """
        if a.ndim > 2 and b.ndim > 2:
            return torch.tensordot(a, b, dims=([-1], [-2]))
        if not a.ndim or not b.ndim:
            return a * b
        return torch.matmul(a, b)

    @staticmethod
    def tensordot(a, b, axes=2, **kwargs):
        """Tensor contraction; ``axes`` is forwarded as ``dims`` to torch."""
        return torch.tensordot(a, b, dims=axes, **kwargs)

    @staticmethod
    def mean(tensor, axis=None):
        """Mean of ``tensor`` over ``axis`` (over all elements when None)."""
        if axis is None:
            return torch.mean(tensor)
        else:
            return torch.mean(tensor, dim=axis)

    @staticmethod
    def sum(tensor, axis=None, keepdims=False):
        """Sum of ``tensor`` over ``axis`` (over every dimension when None)."""
        if axis is None:
            axis = tuple(range(tensor.ndim))
        return torch.sum(tensor, dim=axis, keepdim=keepdims)

    @staticmethod
    def max(tensor, axis=None):
        """Maximum of ``tensor`` over ``axis`` (global maximum when None).

        Only the values are returned; the indices produced by ``torch.max``
        along a dimension are discarded.
        """
        if axis is None:
            return torch.max(tensor)
        else:
            return torch.max(tensor, dim=axis)[0]

    @staticmethod
    def flip(tensor, axis=None):
        """Reverse the order of the elements along ``axis``.

        ``axis`` may be an int, a sequence of ints, or None to flip every
        dimension.
        """
        if isinstance(axis, int):
            axis = [axis]

        if axis is None:
            return torch.flip(tensor, dims=list(range(tensor.ndim)))
        else:
            return torch.flip(tensor, dims=axis)

    @staticmethod
    def concatenate(tensors, axis=0):
        """Join ``tensors`` along the existing dimension ``axis``."""
        return torch.cat(tensors, dim=axis)

    @staticmethod
    def argmin(input, axis=None):
        """Indices of the minimum values along ``axis``."""
        return torch.argmin(input, dim=axis)

    @staticmethod
    def argsort(input, axis=None):
        """Indices that sort ``input`` along ``axis``.

        When ``axis`` is None the flattened tensor is sorted, mirroring
        :meth:`sort`. ``torch.argsort`` itself requires an integer ``dim``,
        so forwarding the default ``None`` used to raise.
        """
        if axis is None:
            return torch.argsort(input.flatten(), dim=-1)
        return torch.argsort(input, dim=axis)

    @staticmethod
    def argmax(input, axis=None):
        """Indices of the maximum values along ``axis``."""
        return torch.argmax(input, dim=axis)

    @staticmethod
    def stack(arrays, axis=0):
        """Stack ``arrays`` along a new dimension inserted at ``axis``."""
        return torch.stack(arrays, dim=axis)

    @staticmethod
    def diag(tensor, k=0):
        """``torch.diag`` with the diagonal offset named ``k`` as in NumPy."""
        return torch.diag(tensor, diagonal=k)

    @staticmethod
    def sort(tensor, axis):
        """Return the sorted values of ``tensor`` along ``axis``.

        With ``axis=None`` the tensor is flattened before sorting.
        """
        if axis is None:
            tensor = tensor.flatten()
            axis = -1

        return torch.sort(tensor, dim=axis).values

    @staticmethod
    def update_index(tensor, index, values):
        """Write ``values`` into ``tensor`` at ``index``, in place.

        Delegates to ``Tensor.index_put_``; nothing is returned.
        """
        tensor.index_put_(index, values)

    def solve(self, matrix1, matrix2):
        """Solve the linear system ``matrix1 @ x = matrix2`` (legacy).

        Legacy only: ``torch.solve`` is deprecated from PyTorch 1.8.0. For
        PyTorch >= 1.8.0 the registration code at the bottom of this module
        registers ``torch.linalg.solve`` under the same name.

        Notes
        -----
        Previously, this was implemented as follows::

            if self.ndim(matrix2) < 2:
                # Currently, gesv doesn't support vectors for matrix2
                # So we instead solve a least square problem...
                solution, _ = torch.gels(matrix2, matrix1)
            else:
                solution, _ = torch.gesv(matrix2, matrix1)
            return solution
        """
        if self.ndim(matrix2) < 2:
            # torch.solve doesn't support vectors for matrix2: promote the
            # right-hand side to a single column (the result is returned as
            # computed, it is not squeezed back).
            solution, _ = torch.solve(matrix2.unsqueeze(1), matrix1)
        else:
            solution, _ = torch.solve(matrix2, matrix1)
        return solution

    @staticmethod
    def lstsq(a, b):
        """Least-squares solution of ``a @ x = b``.

        Returns
        -------
        x : torch.Tensor
            The least-squares solution.
        residuals : torch.Tensor
            Residuals reported by the solver. In the legacy path this is the
            squared norm of the trailing rows of the ``torch.lstsq`` output,
            or an empty tensor when ``a`` does not have full column rank.
        """
        if linalg_lstsq_avail:
            # "gelsd" is a CPU-only LAPACK driver; torch documents "gels" as
            # the only driver accepted for CUDA inputs.
            driver = "gels" if a.is_cuda else "gelsd"
            x, residuals, _, _ = torch.linalg.lstsq(a, b, rcond=None, driver=driver)
            return x, residuals
        else:
            n = a.shape[1]
            sol = torch.lstsq(b, a)[0]
            # The first n rows hold the solution, the remaining rows the
            # residual components.
            x = sol[:n]
            residuals = torch.norm(sol[n:], dim=0) ** 2
            return x, residuals if torch.matrix_rank(a) == n else torch.tensor(
                [], device=x.device
            )

    @staticmethod
    def eigh(tensor):
        """Eigen-decomposition of a symmetric matrix via ``torch.symeig``.

        Legacy only, deprecated from PyTorch 1.8.0.
        """
        return torch.symeig(tensor, eigenvectors=True)

    @staticmethod
    def sign(tensor):
        """Sign of each element.

        Uses ``torch.sgn`` because ``torch.sign`` does not support complex
        numbers.
        """
        return torch.sgn(tensor)

    @staticmethod
    def svd(matrix, full_matrices=True):
        """Singular value decomposition via the legacy ``torch.svd``.

        Returns ``u``, ``s`` and the conjugate transpose of the ``v`` computed
        by ``torch.svd``. ``full_matrices`` is the negation of torch's
        ``some`` flag.
        """
        some = not full_matrices
        u, s, v = torch.svd(matrix, some=some, compute_uv=True)
        return u, s, v.transpose(-2, -1).conj()

    @staticmethod
    def logsumexp(tensor, axis=0):
        """Log of the sum of exponentials of ``tensor`` along ``axis``."""
        return torch.logsumexp(tensor, dim=axis)
# Register the functions that can be taken from ``torch`` under the same name
for name in (
    backend_types
    + backend_basic_math
    + backend_array
    + ["nan", "is_tensor", "trace", "conj", "finfo", "log2", "digamma"]
):
    PyTorchBackend.register_method(name, getattr(torch, name))

# PyTorch 1.8.0 has a much better NumPy interface but some users haven't
# updated yet, so both interfaces are supported.
if Version(torch.__version__) >= Version("1.8.0"):
    # New PyTorch NumPy interface: these registrations take the place of the
    # legacy ``solve``/``svd``/``eigh`` methods defined on the class.
    for name in ("kron", "moveaxis"):
        PyTorchBackend.register_method(name, getattr(torch, name))
    for name in ("solve", "qr", "svd", "eigh"):
        PyTorchBackend.register_method(name, getattr(torch.linalg, name))
else:
    # Old version, will be removed in the future
    warnings.warn(
        f"You are using an old version of PyTorch ({torch.__version__}). "
        "We recommend upgrading to a newest one, e.g. >1.8.0."
    )
    PyTorchBackend.register_method("moveaxis", torch.movedim)
    PyTorchBackend.register_method("qr", torch.qr)