Revision 1b5f59ee166415f66c9e7b9e730ace0bed263f4a authored by Jean Kossaifi on 30 December 2022, 15:24:18 UTC, committed by Jean Kossaifi on 30 December 2022, 15:24:18 UTC
1 parent d277364
__init__.py
"""
The :mod:`tensorly.tenalg` module contains utilities for Tensor Algebra
operations such as khatri-rao or kronecker product, n-mode product, etc.
"""
import sys
import importlib
import threading
from ..backend import BackendManager, dynamically_dispatched_class_attribute
from .base_tenalg import TenalgBackend
from .svd import SVD_FUNS, svd_interface, truncated_svd
class TenalgBackendManager(BackendManager):
_functions = [
"mode_dot",
"multi_mode_dot",
"kronecker",
"khatri_rao",
"inner",
"outer",
"batched_outer",
"higher_order_moment",
"_tt_matrix_to_tensor",
"unfolding_dot_khatri_rao",
"tensordot",
]
_attributes = []
available_backend_names = ["core", "einsum"]
_default_backend = "core"
_loaded_backends = dict()
_backend = None
_THREAD_LOCAL_DATA = threading.local()
_ENV_DEFAULT_VAR = "TENSORLY_TENALG_BACKEND"
@classmethod
def use_dynamic_dispatch(cls):
# Define class methods and attributes that dynamically dispatch to the backend
for name in cls._functions:
try:
delattr(cls, name)
except AttributeError:
pass
setattr(
cls,
name,
staticmethod(
cls.dispatch_backend_method(
name, getattr(cls.current_backend(), name)
)
),
)
for name in cls._attributes:
try:
delattr(cls, name)
except AttributeError:
pass
setattr(cls, name, dynamically_dispatched_class_attribute(name))
@classmethod
def load_backend(cls, backend_name):
"""Registers a new backend by importing the corresponding module
and adding the correspond `Backend` class in Backend._LOADED_BACKEND
under the key `backend_name`
Parameters
----------
backend_name : str, name of the backend to load
Raises
------
ValueError
If `backend_name` does not correspond to one listed
in `_KNOWN_BACKEND`
"""
if backend_name not in cls.available_backend_names:
msg = f"Unknown backend name {backend_name!r}, known backends are {cls.available_backend_names}"
raise ValueError(msg)
if backend_name not in TenalgBackend._available_tenalg_backends:
importlib.import_module(f"tensorly.tenalg.{backend_name}_tenalg")
if backend_name in TenalgBackend._available_tenalg_backends:
backend = TenalgBackend._available_tenalg_backends[backend_name]()
# backend = getattr(module, )()
cls._loaded_backends[backend_name] = backend
return backend
# Initialise the backend to the default one
TenalgBackendManager.initialize_backend()
TenalgBackendManager.use_dynamic_dispatch()
sys.modules[__name__].__class__ = TenalgBackendManager
![swh spinner](/static/img/swh-spinner.gif)
Computing file changes ...