Revision c4b2c08fcdc2664e886be357161da815abb8f2bc authored by Jean Kossaifi on 27 August 2017, 20:04:05 UTC, committed by Jean Kossaifi on 27 August 2017, 20:04:05 UTC
1 parent 72f040b
__init__.py
import sys
import importlib
import os
# Choose the active backend. A value that is already set wins: this happens
# when this module is re-executed by ``importlib.reload``, which keeps the
# existing module globals. Otherwise use the TENSORLY_BACKEND environment
# variable, falling back to ``default_backend``.
default_backend = 'mxnet'
_BACKEND = globals().get('_BACKEND')
if _BACKEND is None:
    _BACKEND = os.environ.get('TENSORLY_BACKEND', default_backend)
def set_backend(backend_name):
    """Switch the active backend and re-export its public names from this package.

    Overwrites the module-level ``_BACKEND``, re-executes the ``backend``
    submodule with :func:`importlib.reload`, then copies the reloaded module's
    public names into this package's namespace. This refreshes what
    ``from .backend import *`` brought in at import time.

    Parameters
    ----------
    backend_name : str
        Name of the backend to activate. This is the same kind of value as
        ``default_backend`` or the ``TENSORLY_BACKEND`` environment variable.

    Raises
    ------
    Exception
        Any exception raised while reloading the backend module is propagated
        unchanged. ``_BACKEND`` has already been set to ``backend_name`` by
        then.
    """
    global _BACKEND
    _BACKEND = backend_name

    # NOTE(review): presumably the backend module reads ``_BACKEND`` when it
    # is executed in order to pick its implementation; confirm in
    # tensorly/backend.
    importlib.reload(backend)

    # Reproduce the semantics of ``from .backend import *``: honour
    # ``__all__`` when the module defines it, otherwise export every name
    # that does not start with an underscore.
    if hasattr(backend, '__all__'):
        public = {name: getattr(backend, name) for name in backend.__all__}
    else:
        public = {name: value
                  for name, value in backend.__dict__.items()
                  if not name.startswith('_')}

    # NOTE(review): names exported by a previously active backend that the
    # new one no longer provides are not removed here; confirm that is
    # acceptable.
    globals().update(public)
from .backend import *
from .base import unfold, fold
from .base import tensor_to_vec, vec_to_tensor
from .base import partial_unfold, partial_fold
from .base import partial_tensor_to_vec, partial_vec_to_tensor
from .kruskal_tensor import kruskal_to_tensor, kruskal_to_unfolded, kruskal_to_vec
from .tucker_tensor import tucker_to_tensor, tucker_to_unfolded, tucker_to_vec
# Package version string.
__version__ = "0.2.0"
Computing file changes ...