# Source: https://github.com/tensorly/tensorly
# File: backend_specific.py
from functools import wraps
from .. import get_backend, available_backend_names
from inspect import getfullargspec
def register_backend_specific(backend, backend_func):
    """A decorator to create a backend specific function.

    The decorated (backend agnostic) function is replaced by a thin dispatcher
    that calls ``backend_func`` whenever the active backend equals ``backend``
    and falls back to the decorated function otherwise.

    Parameters
    ----------
    backend : str
        String that specifies which backend implementation to overload. Must be an
        entry in the ``tensorly.available_backend_names``-list.
    backend_func : callable
        Function implementation for the given backend. The function signature must
        match the signature of the backend agnostic implementation exactly
        (including which arguments are keyword-/positional-only and the default
        arguments) to ensure that the documentation is correct.

    Returns
    -------
    func
        Decorator that creates a backend specific function

    Raises
    ------
    ValueError
        If ``backend`` is not in ``available_backend_names``. Raised immediately
        when ``register_backend_specific`` is called.
    TypeError
        If the full argument specification of the decorated function differs from
        that of ``backend_func``. Raised by the returned decorator when it is
        applied to a function.

    Examples
    --------
    >>> from tensorly.backend import backend_context
    >>> from tensorly.utils.backend_specific import register_backend_specific
    >>> def _numpy_func(x):
    ...     return 2
    >>> @register_backend_specific("numpy", _numpy_func)
    ... def func(x):
    ...     return x
    >>> with backend_context("numpy"):
    ...     print(func(5))
    2
    >>> with backend_context("pytorch"):
    ...     print(func(5))
    5
    """
    if backend not in available_backend_names:
        raise ValueError(
            f"Unknown backend name '{backend}', "
            f"known backends are {available_backend_names}"
        )

    def wrapper(f):
        # ``wraps`` copies the metadata of ``f`` onto the dispatcher, so the two
        # implementations must share one argument specification; otherwise the
        # documented signature would not describe ``backend_func``.
        if getfullargspec(f) != getfullargspec(backend_func):
            raise TypeError(
                "Signature of backend specific implementation does not match the "
                "signature of the backend agnostic implementation. All arguments and "
                "default values must match, otherwise the documentation will be "
                "incorrect."
            )

        @wraps(f)
        def wrapped(*args, **kwargs):
            # The active backend is looked up on every call, not at decoration
            # time, so switching backends later changes which implementation runs.
            if get_backend() == backend:
                return backend_func(*args, **kwargs)
            return f(*args, **kwargs)

        return wrapped

    return wrapper