# jax_backend.py -- Jax backend for TensorLy
# Source: https://github.com/tensorly/tensorly (revision a26ffe0, merge of PR #491, feature/logsumexp)
from packaging.version import Version

try:
    import jax
    from jax.config import config

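    # Jax defaults to 32-bit arrays; enabling x64 gives double-precision
    # dtypes, matching NumPy's defaults.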
    config.update("jax_enable_x64", True)
    import jax.numpy as np
    import jax.scipy.special
except ImportError as error:
    message = (
        "Impossible to import Jax.\n"
        "To use TensorLy with the Jax backend, "
        "you must first install Jax!"
    )
    raise ImportError(message) from error

import numpy
import copy

from .core import (
    Backend,
    backend_types,
    backend_basic_math,
    backend_array,
)


class JaxBackend(Backend, backend_name="jax"):
    @staticmethod
    def context(tensor):
        return {"dtype": tensor.dtype}

    @staticmethod
    def tensor(data, dtype=None, **kwargs):
        return np.array(data, dtype=dtype)

    @staticmethod
    def is_tensor(tensor):
        return isinstance(tensor, np.ndarray)

    @staticmethod
    def to_numpy(tensor):
        return numpy.asarray(tensor)

    def copy(self, tensor):
        # See https://github.com/tensorly/tensorly/pull/397
        # and https://github.com/google/jax/issues/3473
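        # Jax arrays are immutable; re-creating the array in the same context
        # side-steps the copy.copy issues discussed in the links above.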
        return self.tensor(tensor.copy(), **self.context(tensor))
        # return copy.copy(tensor)

    @staticmethod
    def ndim(tensor):
        return tensor.ndim

    @staticmethod
    def lstsq(a, b):
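        # numpy_resid=True asks Jax to return residuals following NumPy's
        # convention.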
        x, residuals, _, _ = np.linalg.lstsq(a, b, rcond=None, numpy_resid=True)
        return x, residuals

    def kr(self, matrices, weights=None, mask=None):
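        # Khatri-Rao (column-wise Kronecker) product computed with a single
        # einsum call. For three factors the contraction string built below
        # is "az,bz,cz->abcz", and the result is reshaped to (-1, n_columns).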
        n_columns = matrices[0].shape[1]
        n_factors = len(matrices)

        start = ord("a")
        common_dim = "z"
        target = "".join(chr(start + i) for i in range(n_factors))
        source = ",".join(i + common_dim for i in target)
        operation = source + "->" + target + common_dim

        if weights is not None:
            matrices = [
                m if i else m * self.reshape(weights, (1, -1))
                for i, m in enumerate(matrices)
            ]

        m = mask.reshape((-1, 1)) if mask is not None else 1
        return np.einsum(operation, *matrices).reshape((-1, n_columns)) * m

    @staticmethod
    def logsumexp(tensor, axis=0):
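        # Numerically stable log(sum(exp(tensor))) along the given axis.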
        return jax.scipy.special.logsumexp(tensor, axis=axis)


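# The remaining array functions can be taken directly from jax.numpy, which
# mirrors the NumPy API, so they are registered without wrapping.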
for name in (
    backend_types
    + backend_basic_math
    + backend_array
    + [
        "nan",
        "moveaxis",
        "transpose",
        "arange",
        "flip",
        "trace",
        "kron",
        "concatenate",
        "max",
        "mean",
        "sum",
        "argmin",
        "argmax",
        "stack",
        "sign",
        "conj",
        "diag",
        "clip",
        "log2",
        "tensordot",
        "argsort",
        "sort",
        "dot",
        "shape",
    ]
):
    JaxBackend.register_method(name, getattr(np, name))

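# Linear-algebra routines are registered from jax.numpy.linalg.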
for name in ["solve", "qr", "svd", "eigh"]:
    JaxBackend.register_method(name, getattr(np.linalg, name))

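# Starting with Jax 0.3.0, the functional tensor.at[indices].set(values)
# syntax replaces jax.ops.index_update.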
if Version(jax.__version__) >= Version("0.3.0"):

    def index_update(tensor, indices, values):
        return tensor.at[indices].set(values)

    JaxBackend.register_method("index_update", index_update)
else:
    # Older Jax versions expose index_update directly in jax.ops.
    JaxBackend.register_method("index_update", jax.ops.index_update)

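# Random number generation is delegated to jax.random.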
for name in ["gamma"]:
    JaxBackend.register_method(name, getattr(jax.random, name))
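
# Usage sketch: selecting this backend from TensorLy (assumes TensorLy and
# Jax are installed; ``set_backend``, ``get_backend`` and ``tensor`` are
# TensorLy's public API):
#
#     import tensorly as tl
#
#     tl.set_backend("jax")
#     t = tl.tensor([[1.0, 2.0], [3.0, 4.0]])
#     print(tl.get_backend())  # "jax"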