# Revision c6d4fdb44a004f804c489c8cb1739ba17d0b0939 authored by TUNA Caglayan on 20 October 2021, 10:09:49 UTC, committed by TUNA Caglayan on 20 October 2021, 10:09:49 UTC
# 1 parent 2b16764
# Raw File
# __init__.py
import warnings
from .core import Backend
import importlib
import os
import sys
import threading
from contextlib import contextmanager
import inspect

# Name of the backend used as a fallback when the current thread has no
# backend of its own; rebound by `set_backend` when `local_threadsafe` is False.
_DEFAULT_BACKEND = 'numpy'
# Maps each supported backend name to the name of the class that
# `register_backend` looks up in the `tensorly.backend.<name>_backend` module.
_KNOWN_BACKENDS = {'numpy': 'NumpyBackend',
                   'mxnet': 'MxnetBackend',
                   'pytorch': 'PyTorchBackend',
                   'tensorflow': 'TensorflowBackend',
                   'cupy': 'CupyBackend',
                   'jax': 'JaxBackend'}

# Backend instances already created, keyed by backend name
# (filled by `register_backend`).
_LOADED_BACKENDS = {}
# Thread-local storage; `set_backend` stores the active backend in its
# `backend` attribute.
_LOCAL_STATE = threading.local()

def initialize_backend():
    """Initialises the backend

    1) retrieve the default backend name from the `TENSORLY_BACKEND` environment variable
        if not found, use _DEFAULT_BACKEND
    2) if the retrieved name is not a key of `_KNOWN_BACKENDS`, emit a
        `UserWarning` and use _DEFAULT_BACKEND instead
    3) sets the backend to the retrieved backend name, as the default
        for all threads
    """
    backend_name = os.environ.get('TENSORLY_BACKEND', _DEFAULT_BACKEND)
    if backend_name not in _KNOWN_BACKENDS:
        # Warn rather than fail: a bad environment variable should not make
        # the package un-importable.
        msg = ("TENSORLY_BACKEND should be one of {}, got {}. Defaulting to {}").format(
            ', '.join(map(repr, _KNOWN_BACKENDS)),
            backend_name, _DEFAULT_BACKEND)
        warnings.warn(msg, UserWarning)
        backend_name = _DEFAULT_BACKEND

    set_backend(backend_name, local_threadsafe=False)

def register_backend(backend_name):
    """Import a known backend and store an instance of it.

    The module `tensorly.backend.<backend_name>_backend` is imported, the
    class named in `_KNOWN_BACKENDS[backend_name]` is instantiated, and the
    instance is stored in `_LOADED_BACKENDS` under the key `backend_name`.

    Parameters
    ----------
    backend_name : str, name of the backend to load

    Raises
    ------
    ValueError
        If `backend_name` is not one of the keys of `_KNOWN_BACKENDS`
    """
    # Reject unknown names up front so the happy path reads straight through.
    if backend_name not in _KNOWN_BACKENDS:
        raise ValueError(
            "Unknown backend name {0!r}, known backends are [{1}]".format(
                backend_name, ', '.join(map(repr, _KNOWN_BACKENDS))))

    module_path = 'tensorly.backend.{0}_backend'.format(backend_name)
    backend_module = importlib.import_module(module_path)
    backend_class = getattr(backend_module, _KNOWN_BACKENDS[backend_name])
    _LOADED_BACKENDS[backend_name] = backend_class()

def set_backend(backend, local_threadsafe=False):
    """Switch the active backend.

    Parameters
    ----------
    backend : tensorly.Backend or str
        Either a Backend instance, or the name of a backend; a name that is
        not loaded yet is first passed to `register_backend`.
    local_threadsafe : bool, optional, default is False
        If False, the backend also becomes the default for all threads.
        If True, only the calling thread's backend is changed.
    """
    global _DEFAULT_BACKEND

    if isinstance(backend, Backend):
        selected = backend
    else:
        # Anything that is not a Backend instance is treated as a name.
        if backend not in _LOADED_BACKENDS:
            register_backend(backend)
        selected = _LOADED_BACKENDS[backend]

    # Always record the choice for the calling thread.
    _LOCAL_STATE.backend = selected

    if local_threadsafe:
        return

    # Also make it the fallback used by threads without their own backend.
    _DEFAULT_BACKEND = selected.backend_name

def get_backend():
    """Return the `backend_name` attribute of the backend currently in use.

    The lookup goes through `_get_backend_method`, so the calling thread's
    backend is used when set and the default backend otherwise.
    """
    current_name = _get_backend_method('backend_name')
    return current_name

def _get_backend_method(key):
    """Look up attribute `key` on the active backend.

    The calling thread's backend is tried first. Any AttributeError raised
    by that lookup (no backend stored for this thread, or the backend has no
    attribute `key`) redirects the lookup to the default backend in
    `_LOADED_BACKENDS`.
    """
    try:
        thread_backend = _LOCAL_STATE.backend
        return getattr(thread_backend, key)
    except AttributeError:
        default_backend = _LOADED_BACKENDS[_DEFAULT_BACKEND]
        return getattr(default_backend, key)

def _get_backend_dir():
    """Return the public attribute names of the active backend.

    Uses the calling thread's backend when one has been set, and falls back
    to the default backend otherwise (the same fallback that
    `_get_backend_method` applies), so that listing attributes does not fail
    in a thread that never called `set_backend`.

    Returns
    -------
    list of str
        Names from `dir(backend)` that do not start with an underscore.
    """
    try:
        backend = _LOCAL_STATE.backend
    except AttributeError:
        # No backend stored for this thread: use the process-wide default.
        backend = _LOADED_BACKENDS[_DEFAULT_BACKEND]
    return [k for k in dir(backend) if not k.startswith('_')]

@contextmanager
def backend_context(backend, local_threadsafe=False):
    """Context manager to set the backend for TensorLy.

    The backend in use when the context is entered is restored when it is
    exited, even if the body raises.

    Parameters
    ----------
    backend : {'numpy', 'mxnet', 'pytorch', 'tensorflow', 'cupy', 'jax'}
        The name of the backend to use inside the context.
    local_threadsafe : bool, optional
        If True, the backend will not become the default backend for all threads.
        Note that this only affects threads where the backend hasn't already
        been explicitly set. If False (default) the backend is set for the
        entire session.

    Examples
    --------
    Temporarily switch from numpy to pytorch:

    >>> import tensorly as tl
    >>> tl.set_backend('numpy')
    >>> with tl.backend_context('pytorch'):
    ...     pass
    """
    _old_backend = get_backend()
    set_backend(backend, local_threadsafe=local_threadsafe)
    try:
        yield
    finally:
        # Restore with the same scope that was used on entry; otherwise a
        # thread-local context would overwrite the global default on exit.
        set_backend(_old_backend, local_threadsafe=local_threadsafe)

def override_module_dispatch(module_name, getter_fun, dir_fun):
    """Override the module's dispatch mechanism

        In Python >= 3.7, we use module's __getattr__ and __dir__
        On older versions, we override the sys.module[__name__].__class__

    Parameters
    ----------
    module_name : str
        Key of the module in `sys.modules`.
    getter_fun : callable
        Called with the attribute name when normal module attribute lookup
        fails; its return value is the attribute value.
    dir_fun : callable
        Called without arguments; returns the attribute names to list.
        On Python >= 3.7 it replaces the module's standard `dir()` listing;
        on older versions its names are added to the standard listing.

    Raises
    ------
    KeyError
        If `module_name` is not in `sys.modules`.
    """
    module = sys.modules[module_name]

    if sys.version_info >= (3, 7, 0):
        module.__getattr__ = getter_fun
        module.__dir__ = dir_fun
        return

    import types

    class BackendAttributeModuleType(types.ModuleType):
        """A module type to dispatch backend generic attributes."""
        def __getattr__(self, key):
            return getter_fun(key)

        def __dir__(self):
            out = set(super().__dir__())
            # Use the supplied callable instead of reading the thread-local
            # backend directly, so both code paths honour `dir_fun`.
            out.update(dir_fun())
            return list(out)

    module.__class__ = BackendAttributeModuleType


def dispatch(method):
    """Build a module-level function that forwards to the active backend.

    The returned function resolves `method.__name__` on the current backend
    at every call and invokes the result with the arguments it was given.
    Its metadata and signature are copied from `method`, with any `self`
    parameter removed from the signature.
    """
    method_name = method.__name__

    def inner(*args, **kwargs):
        backend_method = _get_backend_method(method_name)
        return backend_method(*args, **kwargs)

    # `functools.wraps` is deliberately avoided: some dispatched methods take
    # the backend (`self`) as a parameter, so the metadata is copied by hand
    # and the signature is filtered below.
    copied_attributes = ('__module__', '__name__', '__qualname__', '__doc__',
                         '__annotations__')
    for attribute in copied_attributes:
        try:
            setattr(inner, attribute, getattr(method, attribute))
        except AttributeError:
            # Attribute not available on `method`: keep `inner`'s own value.
            continue

    signature = inspect.signature(method)
    if 'self' in signature.parameters:
        without_self = [param for param in signature.parameters.values()
                        if param.name != 'self']
        signature = signature.replace(parameters=without_self)
    inner.__signature__ = signature

    return inner


# Generic methods, exposed as part of the public API
# Each name below is a thin wrapper built by `dispatch`: calling it looks up
# the method of the same name on the currently active backend.
# NOTE: several of these names (max, min, all, sum, abs) shadow Python
# builtins inside this module's namespace.
check_random_state = dispatch(Backend.check_random_state)
randn = dispatch(Backend.randn)
context = dispatch(Backend.context)
tensor = dispatch(Backend.tensor)
is_tensor = dispatch(Backend.is_tensor)
shape = dispatch(Backend.shape)
ndim = dispatch(Backend.ndim)
to_numpy = dispatch(Backend.to_numpy)
copy = dispatch(Backend.copy)
concatenate = dispatch(Backend.concatenate)
stack = dispatch(Backend.stack)
reshape = dispatch(Backend.reshape)
transpose = dispatch(Backend.transpose)
moveaxis = dispatch(Backend.moveaxis)
arange = dispatch(Backend.arange)
ones = dispatch(Backend.ones)
zeros = dispatch(Backend.zeros)
zeros_like = dispatch(Backend.zeros_like)
eye = dispatch(Backend.eye)
where = dispatch(Backend.where)
clip = dispatch(Backend.clip)
max = dispatch(Backend.max)
min = dispatch(Backend.min)
argmax = dispatch(Backend.argmax)
argmin = dispatch(Backend.argmin)
all = dispatch(Backend.all)
mean = dispatch(Backend.mean)
sum = dispatch(Backend.sum)
prod = dispatch(Backend.prod)
sign = dispatch(Backend.sign)
abs = dispatch(Backend.abs)
sqrt = dispatch(Backend.sqrt)
norm = dispatch(Backend.norm)
dot = dispatch(Backend.dot)
matmul = dispatch(Backend.matmul)
kron = dispatch(Backend.kron)
solve = dispatch(Backend.solve)
lstsq = dispatch(Backend.lstsq)
qr = dispatch(Backend.qr)
kr = dispatch(Backend.kr)
partial_svd = dispatch(Backend.partial_svd)
randomized_svd = dispatch(Backend.randomized_svd)
randomized_range_finder = dispatch(Backend.randomized_range_finder)
sort = dispatch(Backend.sort)
conj = dispatch(Backend.conj)
eps = dispatch(Backend.eps)
finfo = dispatch(Backend.finfo)
# `index` is bound directly to the attribute of the base `Backend` class
# rather than wrapped with `dispatch`.
index = Backend.index
index_update = dispatch(Backend.index_update)
log2 = dispatch(Backend.log2)
sin = dispatch(Backend.sin)
cos = dispatch(Backend.cos)

# Initialise the backend to the default one
# (runs at import time; reads the TENSORLY_BACKEND environment variable)
initialize_backend()

# dispatch non-callables (e.g. dtypes, index)
# Names not defined in this module are resolved on the active backend through
# `_get_backend_method`; `dir()` on this module is served by `_get_backend_dir`.
override_module_dispatch(__name__, 
                         _get_backend_method,
                         _get_backend_dir)
# back to top