Revision c37a55579df7095a5ee66a62a5b99281bdf01418 authored by TUNA Caglayan on 01 September 2021, 09:32:34 UTC, committed by TUNA Caglayan on 01 September 2021, 09:32:34 UTC
1 parent 4772692
Raw File
cupy_backend.py
try:
    import cupy as cp
except ImportError as error:
    message = ('Impossible to import cupy.\n'
               'To use TensorLy with the cupy backend, '
               'you must first install cupy!')
    raise ImportError(message) from error

import warnings
import numpy as np

from .core import Backend


class CupyBackend(Backend):
    backend_name = 'cupy'

    @staticmethod
    def context(tensor):
        return {'dtype': tensor.dtype}

    @staticmethod
    def tensor(data, dtype=cp.float32):
        return cp.array(data, dtype=dtype)

    @staticmethod
    def is_tensor(tensor):
        return isinstance(tensor, cp.ndarray)

    @staticmethod
    def to_numpy(tensor):
        if isinstance(tensor, cp.ndarray):
            return cp.asnumpy(tensor)
        return tensor

    @staticmethod
    def shape(tensor):
        return tensor.shape

    @staticmethod
    def ndim(tensor):
        return tensor.ndim

    @staticmethod
    def clip(tensor, a_min=None, a_max=None):
        return cp.clip(tensor, a_min, a_max)

    @staticmethod
    def lstsq(a, b):
        x, residuals, _, _ = cp.linalg.lstsq(a, b, rcond=None)
        return x, residuals

    @staticmethod
    def sort(tensor, axis, descending = False):
        if descending:
            return cp.flip(cp.sort(tensor, axis=axis), axis = axis)
        else:
            return cp.sort(tensor, axis=axis)


for name in ['float64', 'float32', 'int64', 'int32', 'complex128', 'complex64', 'reshape', 'moveaxis',
             'transpose', 'copy', 'ones', 'zeros', 'zeros_like', 'eye', 'trace', 'any',
             'arange', 'where', 'dot', 'kron', 'concatenate', 'max', 'flip', 'matmul',
             'min', 'all', 'mean', 'sum', 'cumsum', 'prod', 'sign', 'abs', 'sqrt', 'stack',
             'conj', 'diag', 'einsum', 'log2', 'tensordot']:
    CupyBackend.register_method(name, getattr(cp, name))

for name in ['svd', 'qr', 'eigh', 'solve']:
    CupyBackend.register_method(name, getattr(cp.linalg, name))
back to top