# mxnet_backend.py -- MXNet backend for TensorLy
try:
    import mxnet as mx
except ImportError as error:
    message = ('Impossible to import MXNet.\n'
               'To use TensorLy with the MXNet backend, '
               'you must first install MXNet!')
    raise ImportError(message) from error
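
# Installation sketch (package names below are the standard PyPI distributions,
# assumed here rather than taken from this file):
#
#     pip install mxnet tensorly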

import math
import warnings

import numpy
from mxnet import nd
from mxnet.ndarray import reshape, dot, transpose, stack

from .core import Backend


class MxnetBackend(Backend):
    backend_name = 'mxnet'

    @staticmethod
    def context(tensor):
        return {'ctx': tensor.context, 'dtype': tensor.dtype}

    @staticmethod
    def tensor(data, ctx=mx.cpu(), dtype=numpy.float32):
        if dtype is None and isinstance(data, numpy.ndarray):
            dtype = data.dtype
        return nd.array(data, ctx=ctx, dtype=dtype)

    @staticmethod
    def is_tensor(tensor):
        return isinstance(tensor, nd.NDArray)

    @staticmethod
    def to_numpy(tensor):
        if isinstance(tensor, nd.NDArray):
            return tensor.asnumpy()
        elif isinstance(tensor, numpy.ndarray):
            return tensor
        else:
            return numpy.array(tensor)

    @staticmethod
    def shape(tensor):
        return tensor.shape

    @staticmethod
    def ndim(tensor):
        return tensor.ndim

    @staticmethod
    def reshape(tensor, shape):
        # MXNet NDArrays cannot be 0-dimensional: represent scalars with shape (1,)
        if not shape:
            shape = [1]
        return nd.reshape(tensor, shape)

    def solve(self, matrix1, matrix2):
        # solve in NumPy, then move the result back to the original context and dtype
        ctx = self.context(matrix1)
        matrix1 = self.to_numpy(matrix1)
        matrix2 = self.to_numpy(matrix2)
        res = numpy.linalg.solve(matrix1, matrix2)
        return self.tensor(res, **ctx)

    @staticmethod
    def min(tensor, *args, **kwargs):
        if isinstance(tensor, nd.NDArray):
            return nd.min(tensor, *args, **kwargs).asscalar()
        else:
            return numpy.min(tensor, *args, **kwargs)

    @staticmethod
    def max(tensor, *args, **kwargs):
        if isinstance(tensor, nd.NDArray):
            return nd.max(tensor, *args, **kwargs).asscalar()
        else:
            return numpy.max(tensor, *args, **kwargs)

    @staticmethod
    def argmax(data=None, axis=None):
        res = nd.argmax(data, axis)
        if res.shape == (1,):
            return res.astype('int32').asscalar()
        else:
            return res

    @staticmethod
    def argmin(data=None, axis=None):
        res = nd.argmin(data, axis)
        if res.shape == (1,):
            return res.astype('int32').asscalar()
        else:
            return res

    @staticmethod
    def abs(tensor, **kwargs):
        if isinstance(tensor, nd.NDArray):
            return nd.abs(tensor, **kwargs)
        else:
            return numpy.abs(tensor, **kwargs)

    @staticmethod
    def norm(tensor, order=2, axis=None):
        # handle difference in default axis notation
        if axis is None:
            axis = ()

        if order == 'inf':
            res = nd.max(nd.abs(tensor), axis=axis)
        elif order == 1:
            res = nd.sum(nd.abs(tensor), axis=axis)
        elif order == 2:
            res = nd.sqrt(nd.sum(tensor**2, axis=axis))
        else:
            res = nd.sum(nd.abs(tensor)**order, axis=axis)**(1 / order)

        if res.shape == (1,):
            return res.asscalar()

        return res

    def qr(self, matrix):
        s1, s2 = matrix.shape
        if s2 <= s1:
            # gelqf computes an LQ factorization and requires rows <= columns,
            # so factorize the transpose and transpose back to obtain Q, R.
            # NOTE - should be replaced with geqrf when available
            Q, L = nd.linalg.gelqf(matrix.T)
            return Q.T, L.T

        # gelqf cannot handle matrices with more columns than rows;
        # fall back to NumPy for wide matrices.
        warnings.warn('nd.linalg.gelqf() only supports matrices with at least as '
                      'many rows as columns. Falling back to numpy.linalg.qr.')
        ctx = self.context(matrix)
        Q, R = numpy.linalg.qr(self.to_numpy(matrix))
        return self.tensor(Q, **ctx), self.tensor(R, **ctx)

    @staticmethod
    def clip(tensor, a_min=None, a_max=None, inplace=False):
        if a_min is not None and a_max is not None:
            result = nd.maximum(nd.minimum(tensor, a_max), a_min)
        elif a_min is not None:
            result = nd.maximum(tensor, a_min)
        elif a_max is not None:
            result = nd.minimum(tensor, a_max)
        else:
            return tensor

        if inplace:
            # write the clipped values back into the original NDArray
            tensor[:] = result
            return tensor
        return result

    @staticmethod
    def all(tensor):
        # all elements are non-zero iff the smallest entry of (tensor != 0) is 1
        return nd.min(tensor != 0).asscalar()

    @staticmethod
    def conj(x, *args, **kwargs):
        """WARNING: IDENTITY FUNCTION (does nothing)

            This backend currently does not support complex tensors
        """
        return x

    def moveaxis(self, tensor, source, target):
        axes = list(range(self.ndim(tensor)))
        if source < 0: source = axes[source]
        if target < 0: target = axes[target]
        try:
            axes.pop(source)
        except IndexError:
            raise ValueError('Source should verify 0 <= source < tensor.ndim. '
                             'Got %d' % source)
        try:
            axes.insert(target, source)
        except IndexError:
            raise ValueError('Destination should verify 0 <= destination < tensor.ndim. '
                             'Got %d' % target)
        return transpose(tensor, axes)

    @staticmethod
    def mean(tensor, axis=None, **kwargs):
        if axis is None:
            axis = ()
        res = nd.mean(tensor, axis=axis, **kwargs)
        if res.shape == (1,):
            return res.asscalar()
        else:
            return res

    @staticmethod
    def sum(tensor, axis=None, **kwargs):
        if axis is None:
            axis = ()
        res = nd.sum(tensor, axis=axis, **kwargs)
        if res.shape == (1,):
            return res.asscalar()
        else:
            return res

    @staticmethod
    def sqrt(tensor, *args, **kwargs):
        if isinstance(tensor, nd.NDArray):
            return nd.sqrt(tensor, *args, **kwargs)
        else:
            return math.sqrt(tensor)

    @staticmethod
    def copy(tensor):
        return tensor.copy()

    @staticmethod
    def concatenate(tensors, axis):
        return nd.concat(*tensors, dim=axis)

    def stack(self, arrays, axis=0):
        res = nd.stack(*arrays, axis=axis)
        # Ugly fix for stacking zero-order tensors that are of shape (1, ) in MXNet
        if self.ndim(res) == 2 and self.shape(res) == (len(arrays), 1):
            return res.squeeze()
        return res

    def symeig_svd(self, matrix, n_eigenvecs=None, **kwargs):
        """Computes a truncated SVD on `matrix` using symeig

            Uses symeig on matrix.T.dot(matrix) or its transpose

        Parameters
        ----------
        matrix : 2D-array
        n_eigenvecs : int, optional, default is None
            if specified, number of eigenvalues/eigenvectors to return
        **kwargs : optional
            kwargs are absorbed so that the signature matches the other SVD functions

        Returns
        -------
        U : 2D-array
            of shape (matrix.shape[0], n_eigenvecs)
            contains the left singular vectors
        S : 1D-array
            of shape (n_eigenvecs, )
            contains the singular values of `matrix`
        V : 2D-array
            of shape (n_eigenvecs, matrix.shape[1])
            contains the right singular vectors
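
        Examples
        --------
        A minimal usage sketch (shapes only; assumes an instance of this backend,
        e.g. ``backend = MxnetBackend()``):

        >>> M = backend.tensor([[1., 0., 0.], [0., 2., 0.], [0., 0., 3.]])
        >>> U, S, V = backend.symeig_svd(M, n_eigenvecs=2)
        >>> U.shape, S.shape, V.shape
        ((3, 2), (2,), (2, 3))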
        """
        # Check that matrix is... a matrix!
        if self.ndim(matrix) != 2:
            raise ValueError('matrix should be a matrix (2D array). '
                             'matrix.ndim is %d != 2' % self.ndim(matrix))

        dim_1, dim_2 = self.shape(matrix)
        if dim_1 <= dim_2:
            min_dim = dim_1
            max_dim = dim_2
        else:
            min_dim = dim_2
            max_dim = dim_1

        if n_eigenvecs is None:
            n_eigenvecs = max_dim

        if min_dim <= n_eigenvecs:
            if n_eigenvecs > max_dim:
                warnings.warn('Trying to compute SVD with n_eigenvecs={0}, which '
                              'is larger than max(matrix.shape)={1}. Setting '
                              'n_eigenvecs to {1}'.format(n_eigenvecs, max_dim))
                n_eigenvecs = max_dim
            # decompose the larger of the two Gram matrices so that enough eigenvectors are available
            dim_1, dim_2 = dim_2, dim_1

        if dim_1 < dim_2:
            U, S = nd.linalg.syevd(dot(matrix, transpose(matrix)))
            S = self.sqrt(S)
            V = dot(transpose(matrix), U / reshape(S, (1, -1)))
        else:
            V, S = nd.linalg.syevd(dot(transpose(matrix), matrix))
            S = self.sqrt(S)
            U = dot(matrix, V) / reshape(S, (1, -1))

        # syevd returns eigenvalues in ascending order: reverse to descending
        U, S, V = U[:, ::-1], S[::-1], transpose(V)[::-1, :]
        return U[:, :n_eigenvecs], S[:n_eigenvecs], V[:n_eigenvecs, :]

    @property
    def SVD_FUNS(self):
        return {'numpy_svd': self.partial_svd,
                'symeig_svd': self.symeig_svd}
    
    @staticmethod
    def sort(tensor, axis, descending=False):
        return mx.ndarray.sort(tensor, axis=axis, is_ascend=not descending)
    

# register plain numpy dtypes (e.g. MxnetBackend.float32) on the backend
for name in ['float64', 'float32', 'int64', 'int32']:
    MxnetBackend.register_method(name, getattr(numpy, name))

# forward these functions directly to their mxnet.nd equivalents
for name in ['arange', 'zeros', 'zeros_like', 'ones', 'eye', 'dot',
             'transpose', 'where', 'sign', 'prod', 'diag']:
    MxnetBackend.register_method(name, getattr(nd, name))
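
# Usage sketch (relies on TensorLy's public backend-selection API, which is
# assumed here rather than defined in this module):
#
#     import tensorly as tl
#     tl.set_backend('mxnet')
#     T = tl.tensor([[1., 2.], [3., 4.]])  # an mxnet.nd.NDArray on CPU
#     tl.norm(T)                           # dispatched to MxnetBackend.norm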