https://github.com/tensorly/tensorly
Tip revision: a26ffe09a1468b328bd0e26e432e5fb1cc334efc authored by Jean Kossaifi on 08 March 2023, 06:36:00 UTC
Merge pull request #491 from braun-steven/feature/logsumexp
Merge pull request #491 from braun-steven/feature/logsumexp
Tip revision: a26ffe0
mxnet_backend.py
try:
import mxnet as mx
from mxnet import numpy as np
except ImportError as error:
message = (
"Cannot import MXNet.\n"
"To use TensorLy with the MXNet backend, "
"you must first install MXNet!"
)
raise ImportError(message) from error
import warnings
import numpy
from .core import Backend, backend_basic_math, backend_array
mx.npx.set_np()
class MxnetBackend(Backend, backend_name="mxnet"):
    """TensorLy backend built on MXNet's NumPy-compatible ``mxnet.numpy`` module."""

    @staticmethod
    def context(tensor):
        """Return a dict holding the ``dtype`` of `tensor`."""
        return dict(dtype=tensor.dtype)

    @staticmethod
    def tensor(data, dtype=None, **kwargs):
        """Build an MXNet array from `data`.

        When no `dtype` is given and `data` is a NumPy array, the dtype of
        `data` is reused. Extra keyword arguments are accepted but not used.
        """
        inherit_dtype = dtype is None and isinstance(data, numpy.ndarray)
        return np.array(data, dtype=data.dtype if inherit_dtype else dtype)

    @staticmethod
    def is_tensor(tensor):
        """Return True if `tensor` is an MXNet ``ndarray``."""
        return isinstance(tensor, np.ndarray)

    @staticmethod
    def to_numpy(tensor):
        """Convert `tensor` to a NumPy array.

        MXNet arrays are converted with ``asnumpy()``, NumPy arrays are
        returned unchanged, and anything else goes through ``numpy.array``.
        """
        if isinstance(tensor, np.ndarray):
            return tensor.asnumpy()
        if isinstance(tensor, numpy.ndarray):
            return tensor
        return numpy.array(tensor)

    @staticmethod
    def ndim(tensor):
        """Return the ``ndim`` attribute of `tensor`."""
        return tensor.ndim

    @staticmethod
    def clip(tensor, a_min=None, a_max=None):
        """Forward `tensor`, `a_min` and `a_max` to MXNet's ``np.clip``."""
        return np.clip(tensor, a_min, a_max)

    @staticmethod
    def conj(x, *args, **kwargs):
        """Return `x` unchanged.

        WARNING: this is an identity function. This backend currently does
        not support complex tensors, so no conjugation is performed.
        """
        return x

    def svd(self, X, full_matrices=True):
        """Compute the singular value decomposition of `X`.

        When `full_matrices` is exactly ``True``, `X` is converted to NumPy,
        decomposed with ``numpy.linalg.svd`` and the three factors are
        converted back using the context of the input. Otherwise MXNet's
        ``np.linalg.svd`` is applied to `X` directly.

        In both cases, a matrix with more rows than columns is decomposed
        through its transpose and the outer factors are swapped and
        transposed back.

        Returns
        -------
        tuple
            The three factors ``(U, S, V)``.
        """
        # MXNet doesn't provide an option for full_matrices=True, so that
        # case is delegated to NumPy.
        via_numpy = full_matrices is True
        if via_numpy:
            ctx = self.context(X)
            X = self.to_numpy(X)
            decompose = numpy.linalg.svd
        else:
            decompose = np.linalg.svd

        n_rows, n_cols = X.shape[0], X.shape[1]
        if n_rows > n_cols:
            # Decompose the transpose, then map the factors back to X.
            left, S, right = decompose(X.T)
            U, V = right.T, left.T
        else:
            U, S, V = decompose(X)

        if via_numpy:
            U = self.tensor(U, **ctx)
            S = self.tensor(S, **ctx)
            V = self.tensor(V, **ctx)
        return U, S, V

    @staticmethod
    def lstsq(a, b):
        """Solve a least-squares problem with MXNet's ``np.linalg.lstsq``.

        Only the first two entries of the result (the solution and the
        residuals) are returned; the remaining two are discarded.
        """
        solution = np.linalg.lstsq(a, b, rcond=None)
        x, residuals, _, _ = solution
        return x, residuals

    @staticmethod
    def logsumexp(x, axis=0):
        """Compute ``log(sum(exp(x)))`` along `axis`.

        The maximum along `axis` is subtracted before exponentiating and
        added back afterwards; the reduced axis is squeezed out of the result.
        """
        shift = np.max(x, axis=axis, keepdims=True)
        summed = np.sum(np.exp(x - shift), axis=axis, keepdims=True)
        return np.squeeze(shift + np.log(summed), axis=axis)
# Register attributes that are looked up by name directly on MXNet's
# NumPy-compatible module: the shared name lists from ``.core`` followed by
# the extra names listed below.
for name in [
    *backend_basic_math,
    *backend_array,
    # dtype names
    "int64", "int32", "float64", "float32",
    # constants
    "pi", "e", "inf", "nan",
    # array functions
    "moveaxis", "copy", "transpose", "arange", "trace", "concatenate",
    "max", "sign", "flip", "mean", "sum", "argmin", "argmax", "stack",
    "diag", "log2", "tensordot", "exp", "argsort", "sort", "dot", "shape",
]:
    MxnetBackend.register_method(name, getattr(np, name))

# These routines are taken from the ``linalg`` submodule instead.
for name in ("solve", "qr", "eigh"):
    MxnetBackend.register_method(name, getattr(np.linalg, name))