Revision c6d4fdb44a004f804c489c8cb1739ba17d0b0939 authored by TUNA Caglayan on 20 October 2021, 10:09:49 UTC, committed by TUNA Caglayan on 20 October 2021, 10:09:49 UTC
1 parent 2b16764
Raw File
mxnet_backend.py
try:
    import mxnet as mx
    from mxnet import numpy as np
except ImportError as error:
    message = ('Cannot import MXNet.\n'
               'To use TensorLy with the MXNet backend, '
               'you must first install MXNet!')
    raise ImportError(message) from error

import warnings
import numpy
from .core import Backend

# Module-level side effect, executed once when this backend module is
# imported. NOTE(review): per the MXNet docs, ``npx.set_np()`` switches MXNet
# into its NumPy-compatible shape/array mode process-wide (affecting any other
# MXNet code in the same process); the ``mxnet.numpy`` calls in this module
# presumably rely on it — confirm.
mx.npx.set_np()


class MxnetBackend(Backend):
    """TensorLy backend backed by ``mxnet.numpy`` arrays.

    Only operations needing backend-specific handling are defined here; the
    remaining operations are attached with ``register_method`` at module
    import time, directly from ``mxnet.numpy`` and ``mxnet.numpy.linalg``.
    """

    backend_name = 'mxnet'

    @staticmethod
    def context(tensor):
        """Return the keyword arguments describing ``tensor``.

        Only the dtype is recorded, so that ``tensor(data, **context(t))``
        builds a new tensor with the same dtype as ``t``.
        """
        return {'dtype': tensor.dtype}

    @staticmethod
    def tensor(data, dtype=None):
        """Create an ``mxnet.numpy`` array from ``data``.

        Parameters
        ----------
        data : array-like
            Anything accepted by ``mxnet.numpy.array``.
        dtype : optional
            Target dtype. If omitted and ``data`` is a host ``numpy.ndarray``,
            that array's dtype is forwarded explicitly (presumably so MXNet
            does not substitute its own default dtype).
        """
        if dtype is None and isinstance(data, numpy.ndarray):
            dtype = data.dtype
        return np.array(data, dtype=dtype)

    @staticmethod
    def is_tensor(tensor):
        """Return True if ``tensor`` is an ``mxnet.numpy.ndarray``."""
        return isinstance(tensor, np.ndarray)

    @staticmethod
    def to_numpy(tensor):
        """Return ``tensor`` as a host ``numpy.ndarray``.

        MXNet arrays are converted with ``asnumpy()``, NumPy arrays are
        returned unchanged, and anything else goes through ``numpy.array``.
        """
        if isinstance(tensor, np.ndarray):
            return tensor.asnumpy()
        elif isinstance(tensor, numpy.ndarray):
            return tensor
        else:
            return numpy.array(tensor)

    @staticmethod
    def shape(tensor):
        """Return the shape of ``tensor``."""
        return tensor.shape

    @staticmethod
    def ndim(tensor):
        """Return the number of dimensions of ``tensor``."""
        return tensor.ndim

    @staticmethod
    def dot(a, b):
        """Dot product of ``a`` and ``b`` via ``mxnet.numpy.dot``."""
        return np.dot(a, b)

    @staticmethod
    def clip(tensor, a_min=None, a_max=None):
        """Clip ``tensor`` to ``[a_min, a_max]`` via ``mxnet.numpy.clip``."""
        return np.clip(tensor, a_min, a_max)

    @staticmethod
    def conj(x, *args, **kwargs):
        """WARNING: IDENTITY FUNCTION (does nothing)

            This backend currently does not support complex tensors
        """
        return x

    def svd(self, X, full_matrices=True):
        """Singular value decomposition ``X = U @ diag(S) @ V``.

        Parameters
        ----------
        X : 2-D tensor
            Matrix to decompose.
        full_matrices : bool, default True
            If truthy, the full decomposition is computed on the host with
            ``numpy.linalg.svd`` (MXNet does not provide a full-matrices
            option) and the factors are converted back to MXNet arrays with
            ``X``'s dtype. If falsy, ``mxnet.numpy.linalg.svd`` is used
            directly; without a full-matrices option its ``U``/``V`` are
            presumably the reduced factors — confirm against the MXNet docs.

        Returns
        -------
        U, S, V
            Factors such that ``X = U @ diag(S) @ V``.
        """
        # Truthiness (not ``is True``): any truthy flag, e.g. ``1`` or
        # ``numpy.bool_(True)``, must select the full decomposition.
        if full_matrices:
            ctx = self.context(X)
            U, S, V = self._svd_transposing_tall(numpy.linalg.svd,
                                                 self.to_numpy(X))
            return (self.tensor(U, **ctx),
                    self.tensor(S, **ctx),
                    self.tensor(V, **ctx))

        return self._svd_transposing_tall(np.linalg.svd, X)

    @staticmethod
    def _svd_transposing_tall(svd_func, X):
        """Apply ``svd_func`` to ``X``, going through ``X.T`` when ``X`` is tall.

        ``svd_func`` must return ``(U, S, V)`` with ``A = U @ diag(S) @ V``.
        For a tall ``X`` (more rows than columns) the decomposition of ``X.T``
        is computed and transposed back, since ``X = V.T @ diag(S) @ U.T``.
        NOTE(review): MXNet's svd presumably only accepts inputs with
        rows <= columns, which is why the MXNet path needs this.
        """
        if X.shape[0] > X.shape[1]:
            U, S, V = svd_func(X.T)
            return V.T, S, U.T
        return svd_func(X)

    @staticmethod
    def lstsq(a, b):
        """Least-squares solution of ``a @ x = b``.

        Returns ``(x, residuals)``; the remaining two outputs of
        ``mxnet.numpy.linalg.lstsq`` (rank and singular values in NumPy's
        convention) are discarded.
        """
        x, residuals, _, _ = np.linalg.lstsq(a, b, rcond=None)
        return x, residuals

    @staticmethod
    def sort(tensor, axis, descending=False):
        """Sort ``tensor`` along ``axis``; reverse the result if ``descending``."""
        ascending = np.sort(tensor, axis=axis)
        if descending:
            return np.flip(ascending, axis=axis)
        return ascending

    @staticmethod
    def argsort(tensor, axis, descending=False):
        """Return the indices that sort ``tensor`` along ``axis``.

        Descending order is obtained by arg-sorting the negated values.
        NOTE(review): ties may be ordered differently than flipping an
        ascending argsort would order them, and negation is presumably wrong
        for unsigned dtypes — confirm whether that matters to callers.
        """
        if descending:
            return np.argsort(-1 * tensor, axis=axis)
        return np.argsort(tensor, axis=axis)

# Operations exposed unchanged from ``mxnet.numpy``: each one is attached to
# the backend under the same name it has in that namespace.
for name in (
    'int64', 'int32', 'float64', 'float32',
    'reshape', 'moveaxis', 'where', 'copy', 'transpose',
    'arange', 'ones', 'zeros', 'trace', 'any',
    'zeros_like', 'eye', 'concatenate', 'max', 'min', 'flip', 'matmul',
    'all', 'mean', 'sum', 'cumsum', 'count_nonzero', 'prod',
    'sign', 'abs', 'sqrt', 'argmin', 'argmax',
    'stack', 'diag', 'einsum', 'log2', 'tensordot', 'sin', 'cos',
):
    MxnetBackend.register_method(name, getattr(np, name))

# Linear-algebra routines are looked up in ``mxnet.numpy.linalg`` instead.
for name in ('solve', 'qr', 'eigh'):
    MxnetBackend.register_method(name, getattr(np.linalg, name))
back to top