"""
Core tensor operations with PyTorch.
"""

# Author: Jean Kossaifi
# License: BSD 3 clause


# First check whether PyTorch is installed
try:
    import torch
except ImportError as error:
    message = ('Could not import PyTorch.\n'
               'To use TensorLy with the PyTorch backend, '
               'you must first install PyTorch!')
    raise ImportError(message) from error

from distutils.version import LooseVersion

if LooseVersion(torch.__version__) < LooseVersion('0.4.0'):
    raise ImportError('You are using version="{}" of PyTorch. '
                      'Please update to "0.4.0" or higher.'.format(torch.__version__))

import numpy
import scipy.linalg
import scipy.sparse.linalg
from numpy import testing
from . import numpy_backend

from torch import ones, zeros, zeros_like, reshape
from torch import max, min, where
from torch import sum, mean, abs, sqrt, sign, prod
from torch import matmul as dot
from torch import qr

# Square root for Python scalars (0-order tensors)
from math import sqrt as scalar_sqrt

# Equivalent functions in PyTorch
maximum = max

def context(tensor):
    """Returns the context of a tensor
    """
    return {'dtype':tensor.dtype, 'device':tensor.device, 'requires_grad':tensor.requires_grad}


def tensor(data, dtype=torch.float32, device='cpu', requires_grad=False):
    """Creates a tensor from `data` with the given dtype, device and requires_grad
    """
    if isinstance(data, numpy.ndarray):
        return torch.tensor(data.copy(), dtype=dtype, device=device, requires_grad=requires_grad)
    return torch.tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
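
# A minimal usage sketch (illustrative only, not executed on import):
# `context` extracts the dtype/device/requires_grad of an existing tensor
# so that new tensors can be created with the same settings:
#
#     >>> t = tensor([[1., 2.], [3., 4.]])
#     >>> ctx = context(t)
#     >>> ctx['dtype'], str(ctx['device'])
#     (torch.float32, 'cpu')
#     >>> u = tensor([5., 6.], **ctx)  # same dtype and device as `t`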

def to_numpy(tensor):
    """Convert a tensor to numpy format

    Parameters
    ----------
    tensor : Tensor

    Returns
    -------
    ndarray
    """
    if torch.is_tensor(tensor) and tensor.is_cuda:
        return tensor.cpu().numpy()
    elif torch.is_tensor(tensor):
        return tensor.numpy()
    if isinstance(tensor, numpy.ndarray):
        return tensor

    try:
        return numpy.array(tensor)
    except ValueError:
        raise ValueError('Could not convert object of type {} into a Numpy '
                         'NDArray'.format(type(tensor)))

def assert_array_equal(a, b, **kwargs):
    testing.assert_array_equal(to_numpy(a), to_numpy(b), **kwargs)

def assert_array_almost_equal(a, b, decimal=3, **kwargs):
    testing.assert_array_almost_equal(to_numpy(a), to_numpy(b), decimal=decimal, **kwargs)

assert_raises = testing.assert_raises

def assert_equal(actual, desired, err_msg='', verbose=True):
    if isinstance(actual, torch.Tensor):
        actual = to_numpy(actual)
    if isinstance(desired, torch.Tensor):
        desired = to_numpy(desired)
    testing.assert_equal(actual, desired, err_msg=err_msg, verbose=verbose)

assert_ = testing.assert_

def shape(tensor):
    return tensor.size()

def ndim(tensor):
    return tensor.dim()

def arange(start, stop=None, step=1.0):
    if stop is None:
        return torch.arange(start=0., end=float(start), step=float(step))
    else:
        return torch.arange(float(start), float(stop), float(step))

def clip(tensor, a_min=None, a_max=None, inplace=False):
    if a_max is None:
        a_max = torch.max(tensor)
    if a_min is None:
        a_min = torch.min(tensor)
    if inplace:
        return torch.clamp(tensor, a_min, a_max, out=tensor)
    else:
        return torch.clamp(tensor, a_min, a_max)
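
# A small sketch of `clip` (illustrative only): values are clamped to
# [a_min, a_max], mirroring numpy.clip via torch.clamp:
#
#     >>> clip(tensor([-1., 0.5, 2.]), a_min=0., a_max=1.)
#     tensor([0.0000, 0.5000, 1.0000])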

def all(tensor):
    # Mirror numpy.all: True only if every element is non-zero
    return torch.sum((tensor == 0).long()).item() == 0

def transpose(tensor):
    axes = list(range(ndim(tensor)))[::-1]
    return tensor.permute(*axes)

def copy(tensor):
    return tensor.clone()

def moveaxis(tensor, source, target):
    axes = list(range(ndim(tensor)))
    try:
        axes.pop(source)
    except IndexError:
        raise ValueError('Source should verify 0 <= source < tensor.ndim, '
                         'got %d.' % source)
    # list.insert never raises IndexError, so check the destination explicitly
    if not -ndim(tensor) <= target < ndim(tensor):
        raise ValueError('Destination should verify 0 <= destination < tensor.ndim, '
                         'got %d.' % target)
    axes.insert(target, source)
    return tensor.permute(*axes)
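
# Example of `moveaxis` semantics (illustrative only): moving axis 0 of a
# (2, 3, 4) tensor to the last position yields shape (3, 4, 2), as with
# numpy.moveaxis:
#
#     >>> t = torch.randn(2, 3, 4)
#     >>> moveaxis(t, 0, 2).shape
#     torch.Size([3, 4, 2])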

def kron(matrix1, matrix2):
    """Kronecker product"""
    s1, s2 = shape(matrix1)
    s3, s4 = shape(matrix2)
    return reshape(
        reshape(matrix1, (s1, 1, s2, 1))*reshape(matrix2, (1, s3, 1, s4)),
        (s1*s3, s2*s4))
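
# A quick check of `kron` (illustrative only), matching numpy.kron's
# block structure [[a00*B, a01*B], [a10*B, a11*B]]:
#
#     >>> a = tensor([[1., 2.], [3., 4.]])
#     >>> b = tensor([[0., 1.], [1., 0.]])
#     >>> kron(a, b)
#     tensor([[0., 1., 0., 2.],
#             [1., 0., 2., 0.],
#             [0., 3., 0., 4.],
#             [3., 0., 4., 0.]])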


def solve(matrix1, matrix2):
    """Solves the linear system matrix1 x = matrix2"""
    solution, _ = torch.gesv(matrix2, matrix1)
    return solution
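
# Usage sketch for `solve` (illustrative only): solves A x = b, so for a
# diagonal A the solution is just element-wise division:
#
#     >>> A = tensor([[2., 0.], [0., 4.]])
#     >>> b = tensor([[2.], [8.]])
#     >>> solve(A, b)
#     tensor([[1.],
#             [2.]])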


def norm(tensor, order=2, axis=None):
    """Computes the l-`order` norm of tensor.

    Parameters
    ----------
    tensor : ndarray
    order : int or 'inf', default is 2
    axis : int, optional

    Returns
    -------
    float or tensor
        If `axis` is provided returns a tensor.
    """
    # torch.norm and torch.max do not accept `None` for their keyword
    # arguments, so only build the ones that were actually specified
    kwds = {}
    if axis is not None:
        kwds['dim'] = axis
    if order and order != 'inf':
        kwds['p'] = order

    if order == 'inf':
        res = torch.max(torch.abs(tensor), **kwds)
        if axis is not None:
            return res[0]  # ignore indices output
        return res
    return torch.norm(tensor, **kwds)
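
# Small sanity examples for `norm` (illustrative only): the default is the
# l2 norm, and order='inf' gives the maximum absolute value:
#
#     >>> t = tensor([3., -4.])
#     >>> norm(t)          # sqrt(3**2 + 4**2)
#     tensor(5.)
#     >>> norm(t, order='inf')
#     tensor(4.)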

def mean(tensor, axis=None):
    if axis is None:
        return torch.mean(tensor)
    else:
        return torch.mean(tensor, dim=axis)

def sum(tensor, axis=None):
    if axis is None:
        return torch.sum(tensor)
    else:
        return torch.sum(tensor, dim=axis)

def concatenate(tensors, axis=0):
    return torch.cat(tensors, dim=axis)

def kr(matrices):
    """Khatri-Rao product of a list of matrices

        This can be seen as a column-wise kronecker product.

    Parameters
    ----------
    matrices : ndarray list
        list of matrices with the same number of columns, i.e.::

            for i in range(len(matrices)):
                matrices[i].shape = (n_i, m)

    Returns
    -------
    khatri_rao_product: matrix of shape ``(prod(n_i), m)``
        where ``prod(n_i) = prod([m.shape[0] for m in matrices])``
        i.e. the product of the number of rows of all the matrices in the product.

    Notes
    -----
    Mathematically:

    .. math::
         \\text{If every matrix } U_k \\text{ is of size } (I_k \\times R),\\\\
         \\text{Then } \\left(U_1 \\bigodot \\cdots \\bigodot U_n \\right) \\text{ is of size } (\\prod_{k=1}^n I_k \\times R)
    """
    if len(matrices) < 2:
        raise ValueError('kr requires a list of at least 2 matrices, but {} given.'.format(len(matrices)))

    n_col = shape(matrices[0])[1]
    res = matrices[0]
    for e in matrices[1:]:
        s1, s2 = shape(res)
        s3, s4 = shape(e)
        if not s2 == s4 == n_col:
            raise ValueError('All matrices should have the same number of columns.')
        res = reshape(reshape(res, (s1, 1, s2))*reshape(e, (1, s3, s4)),
                      (-1, n_col))
    return res
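
# Column-wise sketch of `kr` (illustrative only): each column of the result
# is the Kronecker product of the corresponding columns of the inputs, so
# two (2, 2) matrices give a (4, 2) result:
#
#     >>> a = tensor([[1., 2.], [3., 4.]])
#     >>> b = tensor([[0., 1.], [1., 0.]])
#     >>> kr([a, b])
#     tensor([[0., 2.],
#             [1., 0.],
#             [0., 4.],
#             [3., 0.]])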


def partial_svd(matrix, n_eigenvecs=None):
    """Computes a fast partial SVD on `matrix`

        If `n_eigenvecs` is specified and smaller than the smallest dimension,
        a sparse eigendecomposition is used on either matrix.dot(matrix.T)
        or matrix.T.dot(matrix)

    Parameters
    ----------
    matrix : 2D-array
    n_eigenvecs : int, optional, default is None
        if specified, number of eigen[vectors-values] to return

    Returns
    -------
    U : 2D-array
        of shape (matrix.shape[0], n_eigenvecs)
        contains the left singular vectors
    S : 1D-array
        of shape (n_eigenvecs, )
        contains the singular values of `matrix`
    V : 2D-array
        of shape (n_eigenvecs, matrix.shape[1])
        contains the right singular vectors
    """
    # Check that matrix is... a matrix!
    if ndim(matrix) != 2:
        raise ValueError('`matrix` should be a matrix (2D array). matrix.ndim is {} != 2'.format(
            ndim(matrix)))

    # Choose what to do depending on the params
    dim_1, dim_2 = matrix.shape
    if dim_1 <= dim_2:
        min_dim = dim_1
    else:
        min_dim = dim_2

    if n_eigenvecs is None or n_eigenvecs >= min_dim:
        # Default to the standard (full) SVD
        try:
            U, S, V = torch.svd(matrix, some=False)
            U, S, V = U[:, :n_eigenvecs], S[:n_eigenvecs], V.t()[:n_eigenvecs, :]
            return U, S, V

        except RuntimeError:  # Probably ran out of memory
            ctx = context(matrix)
            matrix = to_numpy(matrix)
            U, S, V = scipy.linalg.svd(matrix)
            U, S, V = U[:, :n_eigenvecs], S[:n_eigenvecs], V[:n_eigenvecs, :]
            return tensor(U, **ctx), tensor(S, **ctx), tensor(V, **ctx)

    else:
        ctx = context(matrix)
        matrix = to_numpy(matrix)
        # We can perform a partial SVD:
        # first choose whether to use X * X.T or X.T * X
        if dim_1 < dim_2:
            S, U = scipy.sparse.linalg.eigsh(numpy.dot(matrix, matrix.T.conj()), k=n_eigenvecs, which='LM')
            S = numpy.sqrt(S)
            V = numpy.dot(matrix.T.conj(), U * 1/S.reshape((1, -1)))
        else:
            S, V = scipy.sparse.linalg.eigsh(numpy.dot(matrix.T.conj(), matrix), k=n_eigenvecs, which='LM')
            S = numpy.sqrt(S)
            U = numpy.dot(matrix, V) * 1/S.reshape((1, -1))

        # WARNING: here, V is still the transpose of what it should be
        U, S, V = U[:, ::-1], S[::-1], V[:, ::-1]
        return tensor(U, **ctx), tensor(S, **ctx), tensor(V.T.conj(), **ctx)
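
# A reconstruction sketch for `partial_svd` (illustrative only): with
# n_eigenvecs equal to the smaller dimension, U @ diag(S) @ V recovers the
# input up to numerical error:
#
#     >>> M = tensor([[1., 0., 0.], [0., 2., 0.]])
#     >>> U, S, V = partial_svd(M, n_eigenvecs=2)
#     >>> U.shape, S.shape, V.shape
#     (torch.Size([2, 2]), torch.Size([2]), torch.Size([2, 3]))
#     >>> dot(U * S, V)  # approximately M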