https://github.com/rballester/tntorch
Raw File
Tip revision: 3af563a42794ba169e7902198d1edd919617a958 authored by Rafael Ballester on 16 March 2023, 15:48:54 UTC
Updated doc (ranks_cp actually must be an integer, not a list)
Tip revision: 3af563a
tools.py
import tntorch as tn
import torch
import numpy as np
import time
import scipy.fftpack


"""
Array-like manipulations
"""


def squeeze(t, dim=None):
    """
    Drops singleton (size-1) dimensions by indexing them at position 0.

    :param t: input :class:`Tensor`
    :param dim: int or list of ints naming the dimensions to drop. If None (default), every dimension of size 1 is dropped

    :return: another :class:`Tensor` without the selected singleton dimensions
    """

    if dim is None:
        dim = np.where([size == 1 for size in t.shape])[0]
    dims = dim if hasattr(dim, '__len__') else [dim]

    # Every requested dimension must really be a singleton
    assert np.all(np.array(t.shape)[dims] == 1)

    selector = [slice(None)] * len(t.shape)
    for axis in dims:
        selector[axis] = 0  # an integer index removes that dimension
    return t[tuple(selector)]


def unsqueeze(t, dim):
    """
    Adds dummy (size-1) dimensions to a tensor.

    :param t: input :class:`Tensor`
    :param dim: int or list of ints: positions, in the *output* tensor, of the new singleton dimensions

    :return: a :class:`Tensor` with singleton dimensions at the positions given by `dim`
    """

    dims = [dim] if not hasattr(dim, '__len__') else dim

    # A None entry in an index tuple inserts a new axis of size 1
    selector = [slice(None)] * (t.dim() + len(dims))
    for position in dims:
        selector[position] = None
    return t[tuple(selector)]


def cat(*ts, dim):
    """
    Concatenate two or more tensors along a given dim, similarly to PyTorch's `cat()`.

    :param ts: several :class:`Tensor` arguments, or a single list of them
    :param dim: an int

    :return: a :class:`Tensor` of the same shape as all tensors in the list, except along `dim` where it has the sum of shapes

    :raises ValueError: if the tensors differ in size along any dim other than `dim`
    """

    if hasattr(ts[0], '__len__'):
        ts = ts[0]
    if len(ts) == 1:
        return ts[0].clone()
    other_dims = np.delete(range(ts[0].dim()), dim)
    if any(any(t.shape[n] != ts[0].shape[n] for n in other_dims) for t in ts[1:]):
        raise ValueError('To concatenate tensors, all must have the same shape along all but the given dim')

    shapes = np.array([t.shape[dim] for t in ts])
    # offsets[i]:offsets[i+1] is the slab that tensor i occupies along `dim` in the result
    offsets = np.concatenate([np.array([0]), np.cumsum(shapes)])
    total = int(offsets[-1])
    result = None
    for i, source in enumerate(ts):
        t = source.clone()
        start, stop = int(offsets[i]), int(offsets[i + 1])
        # Embed this tensor's factor for `dim` into a zero block of the full length; each embedded
        # tensor is then zero outside its own slab, so summing them stacks the inputs along `dim`.
        # The block inherits dtype and device from the factor it replaces (GPU / non-default dtypes).
        if t.Us[dim] is None:
            core = source.cores[dim]
            if core.dim() == 2:
                block = torch.zeros(total, core.shape[-1], dtype=core.dtype, device=core.device)
            else:
                block = torch.zeros(core.shape[0], total, core.shape[-1], dtype=core.dtype, device=core.device)
            block[..., start:stop, :] = core
            t.cores[dim] = block
        else:
            U = source.Us[dim]
            block = torch.zeros(total, U.shape[-1], dtype=U.dtype, device=U.device)
            block[start:stop, :] = U
            t.Us[dim] = block
        if result is None:
            result = t
        else:
            result += t
    return result


def transpose(t):
    """
    Reverses the order of the dimensions of a tensor, e.g. :math:`I_1 \\times I_2 \\times I_3` becomes :math:`I_3 \\times I_2 \\times I_1`.

    :param t: input tensor

    :return: another :class:`Tensor`, indexed by dimensions in inverse order
    """

    cores, Us, idxs = [], [], []
    for n in reversed(range(t.dim())):
        core = t.cores[n]
        # 3D cores also swap their two rank axes; cores of any other rank are reused as they are
        cores.append(core.permute(2, 1, 0) if core.dim() == 3 else core)
        Us.append(None if t.Us[n] is None else t.Us[n].clone())
        try:
            idxs.append(t.idxs[n].clone())
        except Exception:
            # Best effort: drop the index annotation of this mode if it cannot be copied
            idxs.append(None)
    return tn.Tensor(cores, Us, idxs)


def meshgrid(*axes, batch=False):
    """
    See NumPy's or PyTorch's `meshgrid()`.

    Axes can be given as separate arguments (`meshgrid(x, y, z)`) or as a single sequence
    (`meshgrid([x, y, z])`). Each axis is either an int `I` (meaning `0, 1, ..., I-1`) or a vector of
    coordinates (torch tensor, NumPy array or list).

    :param axes: N ints or vectors, or a single sequence of them
    :param batch: Boolean, passed to each resulting :class:`Tensor`

    :return: a list of N :class:`Tensor`, of N dimensions each; the n-th one varies only along dimension n
    """

    device = None
    # Only a lone sequence argument is unpacked as the list of axes; with several arguments each one
    # is an axis by itself (so `meshgrid(x, y)` is not mistaken for `meshgrid(x)`).
    if len(axes) == 1 and hasattr(axes[0], '__len__'):
        axes = axes[0]
    if hasattr(axes[0], 'device'):
        device = axes[0].device
    dtype = torch.get_default_dtype()
    axes = list(axes)
    N = len(axes)
    for n in range(N):
        if not hasattr(axes[n], '__len__'):
            axes[n] = torch.arange(axes[n], dtype=dtype)

    tensors = []
    for n in range(N):
        # Rank-1 TT: all-ones cores except along dimension n, which holds the coordinates
        cores = [torch.ones(1, len(ax), 1).to(device) for ax in axes]
        if isinstance(axes[n], torch.Tensor):
            axis = axes[n].type(dtype)
        else:
            # Lists and NumPy arrays have no `.type()` method: convert them explicitly
            axis = torch.tensor(np.asarray(axes[n]), dtype=dtype)
        cores[n] = axis[None, :, None].to(device)
        tensors.append(tn.Tensor(cores, device=device, batch=batch))
    return tensors


def flip(t, dim):
    """
    Reverses the order of a tensor along one or several dimensions; see NumPy's or PyTorch's `flip()`.

    :param t: input :class:`Tensor`
    :param dim: an int or list of ints

    :return: another :class:`Tensor` of the same shape
    """

    dims = dim if hasattr(dim, '__len__') else [dim]
    sizes = t.shape
    flipped = t.clone()
    for d in dims:
        reverse = np.arange(sizes[d] - 1, -1, -1)  # I-1, I-2, ..., 0
        U = flipped.Us[d]
        if U is None:
            # No Tucker factor: the spatial index is the second-to-last axis of the core
            flipped.cores[d] = flipped.cores[d][..., reverse, :]
        else:
            # Tucker factor: its rows are the spatial index
            flipped.Us[d] = U[reverse, :]
    return flipped


def unbind(t, dim):
    """
    Splits a tensor into its slices along one dimension, like PyTorch's `unbind()`.

    :param t: input :class:`Tensor`
    :param dim: an int (negative values count from the end)

    :return: a list with `t.shape[dim]` :class:`Tensor` slices, in order
    """

    if dim < 0:
        dim += t.dim()
    before = [slice(None)] * dim
    after = [slice(None)] * (t.dim() - 1 - dim)
    slices = []
    for position in range(t.shape[dim]):
        slices.append(t[before + [position] + after])
    return slices


def unfolding(data, n, batch=False):
    """
    Computes the `n-th mode unfolding <https://epubs.siam.org/doi/pdf/10.1137/07070111X>`_ of a PyTorch tensor.

    Mode `n` becomes the rows and all remaining modes (in their original order) are flattened into the columns.

    :param data: a PyTorch tensor
    :param n: unfolding mode (not counting the batch dimension)
    :param batch: Boolean; if True, dimension 0 is a batch dimension that is kept in front

    :return: a PyTorch matrix (or a batch of matrices if `batch`)
    """
    offset = 1 if batch else 0
    mode = n + offset
    order = list(range(offset)) + [mode] + list(range(offset, mode)) + list(range(mode + 1, data.dim()))
    target = list(data.shape[:offset]) + [data.shape[mode], -1]
    return data.permute(order).reshape(target)


def right_unfolding(core, batch=False):
    """
    Computes the `right unfolding <https://epubs.siam.org/doi/pdf/10.1137/090752286>`_ of a 3D PyTorch tensor.

    The first dimension (plus the batch dimension, if any) is kept and everything after it is flattened.

    :param core: a PyTorch tensor of shape :math:`I_1 \\times I_2 \\times I_3`
    :param batch: Boolean; if True, dimension 0 is a batch dimension and is kept as well

    :return: a PyTorch matrix of shape :math:`I_1 \\times I_2 I_3`
    """
    kept = 2 if batch else 1
    return core.reshape(list(core.shape[:kept]) + [-1])


def left_unfolding(core, batch=False):
    """
    Computes the `left unfolding <https://epubs.siam.org/doi/pdf/10.1137/090752286>`_ of a 3D PyTorch tensor.

    The last dimension is kept and everything before it (except the batch dimension, if any) is flattened.

    :param core: a PyTorch tensor of shape :math:`I_1 \\times I_2 \\times I_3`
    :param batch: Boolean; if True, dimension 0 is a batch dimension and is kept in front

    :return: a PyTorch matrix of shape :math:`I_1 I_2 \\times I_3`
    """

    last = core.shape[-1]
    target = [core.shape[0], -1, last] if batch else [-1, last]
    return core.reshape(target)


"""
Multilinear algebra
"""


def ttm(t, U, dim=None, transpose=False):
    """
    `Tensor-times-matrix (TTM) <https://epubs.siam.org/doi/pdf/10.1137/07070111X>`_ along one or several dimensions.

    For each selected dimension `n`, the factor is contracted against the size-:math:`I_n` index of the
    core, or of the Tucker factor if `t` has one for that dimension (in which case the core is left as is).
    Outside batch mode, a vector factor is treated as a one-row matrix, so dimension `n` becomes of size 1.

    :param t: input :class:`Tensor`
    :param U: one or several factors (vectors or matrices); a single factor may be passed without a list
    :param dim: one or several dimensions, matched in order with `U`. Negative values count from the end. If None, the first len(U) dims are assumed
    :param transpose: if False (default), each factor is used as given: a :math:`J \\times I_n` matrix maps
     dimension `n` to size `J` (the contraction runs over the factor's second index). If True, the factor is
     transposed first (the contraction runs over its first index)

    :return: transformed :class:`Tensor`, carrying the same `idxs` and `batch` flag as `t`
    """

    if not isinstance(U, (list, tuple)):
        U = [U]
    if dim is None:
        dim = range(len(U))
    if not hasattr(dim, '__len__'):
        dim = [dim]
    dim = list(dim)
    # Normalize negative dims so that `n in dim` below matches them
    for i in range(len(dim)):
        if dim[i] < 0:
            dim[i] += t.dim()

    cores = []
    Us = []
    for n in range(t.dim()):
        if n in dim:
            if transpose:
                factor = U[dim.index(n)].t()
            else:
                factor = U[dim.index(n)]
            # Give the factor a row axis: a vector becomes 1 x I_n; in batch mode a 2D factor
            # (presumably batch x I_n -- TODO confirm) gets a singleton row axis after the batch one
            if factor.dim() == 1 and not t.batch:
                factor = factor[None, ...]
            if factor.dim() == 2 and t.batch:
                factor = factor[:,  None, ...]
            if t.Us[n] is None:
                # No Tucker factor: contract the core's spatial index ('a') with the factor's columns
                if t.batch:
                    if t.cores[n].dim() == 4:
                        cores.append(torch.einsum('biak,bja->bijk', (t.cores[n], factor)))
                    else:
                        # Batched cores with one rank index (batch x size x rank)
                        cores.append(torch.einsum('bai,bja->bji', (t.cores[n], factor)))
                else:
                    if t.cores[n].dim() == 3:
                        cores.append(torch.einsum('iak,ja->ijk', (t.cores[n], factor)))
                    else:
                        # 2D cores (size x rank)
                        cores.append(torch.einsum('ai,ja->ji', (t.cores[n], factor)))
                Us.append(None)
            else:
                # Tucker factor present: apply the factor to it and keep the core unchanged
                cores.append(t.cores[n].clone())
                Us.append(torch.matmul(factor, t.Us[n]))
        else:
            # Dimensions not in `dim` are copied unchanged
            cores.append(t.cores[n].clone())
            if t.Us[n] is None:
                Us.append(None)
            else:
                Us.append(t.Us[n].clone())
    return tn.Tensor(cores, Us=Us, idxs=t.idxs, batch=t.batch)


"""
Miscellaneous
"""


def mask(t, mask):
    """
    Masks a tensor. Basically an element-wise product, but this function makes sure slices are matched according to their "meaning" (as annotated by the tensor's `idxs` field, if available)

    Slice `i` of `t` along mode `n` is multiplied by slice `t.idxs[n][i]` of `mask`, clamped to the mask's
    last slice. Modes without an index annotation (no `idxs` at all, `idxs` set to None, or a None entry)
    use `0, 1, ..., I_n - 1`.

    :param t: input :class:`Tensor`
    :param mask: a mask :class:`Tensor`

    :return: masked :class:`Tensor`
    """
    device = t.cores[0].device
    idxs = getattr(t, 'idxs', None)
    if idxs is None:
        idxs = [None] * t.dim()
    cores = []
    Us = []
    for n in range(t.dim()):
        if idxs[n] is None:
            idx = np.arange(t.shape[n])
        elif torch.is_tensor(idxs[n]):
            # np.array copies, so the in-place clamp below never modifies t.idxs (also works for GPU tensors)
            idx = np.array(idxs[n].detach().cpu())
        else:
            idx = np.array(idxs[n])
        idx[idx >= mask.shape[n]] = mask.shape[n]-1  # Clamp
        if mask.Us[n] is None:
            cores.append(mask.cores[n][..., idx, :].to(device))
            Us.append(None)
        else:
            cores.append(mask.cores[n].to(device))
            Us.append(mask.Us[n][idx, :].to(device))
    masked = tn.Tensor(cores, Us, device=device)
    return t*masked


def sample(t, P=1, seed=None):
    """
    Generate P points (with replacement) from a joint PDF distribution represented by a tensor.

    The tensor does not have to sum 1 (will be handled in a normalized form). Coordinates are drawn one
    mode at a time, each conditioned on the previously drawn ones; the (possibly signed) conditional
    weights are replaced by their absolute values before drawing.

    :param t: a :class:`Tensor`
    :param P: how many samples to draw (default: 1)
    :param seed: seed for NumPy's `default_rng` (default: None, i.e. not reproducible)

    :return Xs: a :math:`P \\times N` PyTorch matrix whose entries are integer indices (stored as floats)

    :raises ValueError: if some conditional distribution has zero or non-finite total mass
    """

    rng = np.random.default_rng(seed=seed)

    def from_matrix(M):
        """
        Treat each row of a matrix M as a PMF and select a column per row according to it
        """

        if torch.is_tensor(M):
            M = M.detach().cpu().numpy()
        M = np.abs(M)
        totals = np.sum(M, axis=1, keepdims=True)
        if not np.all(np.isfinite(totals)) or np.any(totals == 0):
            raise ValueError('Cannot sample: some conditional distributions have zero or non-finite total mass')
        cdf = np.cumsum((M / totals).astype(np.float64), axis=1)  # Row-wise CDFs
        thresh = rng.random(M.shape[0])
        # The chosen column is the first one whose CDF exceeds the threshold, i.e. the number of CDF
        # values not exceeding it. Rounding can leave cdf[:, -1] slightly below 1 (and below the
        # threshold), so clamp to the last column instead of losing the row.
        chosen = np.sum(cdf <= thresh[:, np.newaxis], axis=1)
        return np.minimum(chosen, M.shape[1] - 1)

    N = t.dim()
    tsum = tn.sum(t, dim=np.arange(N), keepdim=True).decompress_tucker_factors()
    Xs = torch.zeros([P, N])
    # rights[mu]: cores mu..N-1 of `tsum` contracted together (spatial index summed out)
    rights = [torch.ones(1)]
    for core in tsum.cores[::-1]:
        rights.append(torch.matmul(torch.sum(core, dim=1), rights[-1]))
    rights = rights[::-1]
    # lefts: for each point, the contraction of the slices chosen so far
    lefts = torch.ones([P, 1])
    t = t.decompress_tucker_factors()
    for mu in range(t.dim()):
        fiber = torch.einsum('ijk,k->ij', (t.cores[mu], rights[mu + 1]))
        per_point = torch.einsum('ij,jk->ik', (lefts, fiber))  # unnormalized weights of index mu, per point
        rows = from_matrix(per_point)
        Xs[:, mu] = torch.tensor(rows)
        lefts = torch.einsum('ij,jik->ik', (lefts, t.cores[mu][:, rows, :]))

    return Xs


def hash(t):
    """
    Computes a number that depends on the tensor entries (not on its internal compressed representation).

    We obtain it as :math:`\\langle T, W \\rangle`, where :math:`W` is a rank-1 tensor of weights selected at random (always the same seed).

    Note: this function shadows the builtin `hash` inside this module.

    :return: the scalar returned by `t.dot(W)`; despite the name, this is generally not an integer
    """

    # The weights are always drawn on the CPU from a fixed seed, so they do not depend on the device;
    # they are then moved to the device of `t` so that the dot product can be computed there.
    device = t.cores[0].device
    gen = torch.Generator()
    gen.manual_seed(0)
    cores = [torch.ones(1, 1, 1, device=device) for n in range(t.dim())]
    Us = [torch.rand([sh, 1], generator=gen).to(device) for sh in t.shape]
    w = tn.Tensor(cores, Us, device=device)
    return t.dot(w)


def generate_basis(name, shape, orthonormal=False):
    """
    Generate a factor matrix whose columns are functions of a truncated basis.

    Polynomial bases ('legendre', 'chebyshev', 'hermite') are evaluated at `shape[0]` equispaced points in
    [-1, 1], column `k` holding the degree-`k` polynomial. 'dct' gives the first columns of the orthonormal
    DCT matrix of size `shape[0]`, and 'identity' the first `shape[1]` columns of the identity.

    :param name: 'dct', 'identity', 'legendre', 'chebyshev' or 'hermite'
    :param shape: two integers (number of points, number of basis functions)
    :param orthonormal: whether to rescale every column to unit Euclidean norm (all-zero columns are left
     as they are). Note that this normalizes the columns but does not orthogonalize them

    :return: a PyTorch matrix of `shape` (float64)

    :raises ValueError: if `name` is not a supported basis
    """

    if name == "dct":
        U = scipy.fftpack.dct(np.eye(shape[0]), norm="ortho")[:, :shape[1]]
    elif name == 'identity':
        U = np.eye(shape[0], shape[1])
    else:
        eval_points = np.linspace(-1, 1, shape[0])
        if name == "legendre":
            U = np.polynomial.legendre.legval(eval_points, np.eye(shape[0], shape[1])).T
        elif name == "chebyshev":
            U = np.polynomial.chebyshev.chebval(eval_points, np.eye(shape[0], shape[1])).T
        elif name == "hermite":
            U = np.polynomial.hermite.hermval(eval_points, np.eye(shape[0], shape[1])).T
        else:
            raise ValueError("Unsupported basis function")
    if orthonormal:
        norms = np.sqrt(np.sum(U*U, axis=0))
        norms[norms == 0] = 1  # avoid 0/0 for columns that are identically zero
        U = U / norms
    return torch.from_numpy(U)


def reduce(ts, function, eps=0, rmax=np.iinfo(np.int32).max, algorithm='svd', verbose=False, **kwargs):
    """
    Compute a tensor as a function to all tensors in a sequence.

    Calls are arranged as a binary tree, like a binary counter: slot `k` holds the combination of `2^k`
    consecutive elements, and two equal-level slots are merged as soon as they exist. Every intermediate
    result is rounded with `tn.round`. The left-to-right order of the elements is preserved in every call
    to `function`. A single element is returned without rounding.

    :Example 1 (addition):

    >>> import operator
    >>> tn.reduce([t1, t2], operator.add)

    :Example 2 (cat with bounded rank):

    >>> tn.reduce([t1, t2], tn.cat, rmax=10, dim=0)

    :param ts: A generator (or list) of :class:`Tensor`
    :param function: binary function, called as `function(a, b, **kwargs)`
    :param eps: intermediate tensors will be rounded at this error when climbing up the hierarchy
    :param rmax: no node should exceed this number of ranks
    :param algorithm: passed to :func:`round.round()`
    :param verbose: Boolean; if True, report progress every 100 elements
    :param kwargs: extra keyword arguments forwarded to `function`

    :return: the reduced result
    """

    def merge(left, right):
        # Combine two partial results and round the outcome
        return tn.round(function(left, right, **kwargs), eps=eps, rmax=rmax, algorithm=algorithm)

    slots = dict()  # level -> combination of 2**level consecutive elements
    start = time.time()
    for count, item in enumerate(ts):
        if verbose and count % 100 == 0:
            print("reduce: element {}, time={:g}".format(count, time.time()-start))
        level = 0
        while level in slots:
            item = merge(slots.pop(level), item)
            level += 1
        slots[level] = item
    # Remaining slots, oldest (leftmost elements) first
    levels = list(slots)
    result = slots[levels[0]]
    for level in levels[1:]:
        result = merge(result, slots[level])
    return result


def pad(t, shape, dim=None, fill_value=0):
    """
    Pad a tensor with a constant value.

    New entries are appended after the existing ones along each padded mode.

    :param t: N-dim input :class:`Tensor`
    :param shape: int or list of ints: the new size of each mode listed in `dim`
    :param dim: int or list of ints (default: all modes)
    :param fill_value: default is 0

    :return: a :class:`Tensor` of size `shape` along the indicated modes
    """

    if dim is None:
        dim = range(t.dim())
    if not hasattr(dim, '__len__'):
        dim = [dim]
    dim = list(dim)
    if not hasattr(shape, '__len__'):
        shape = [shape]*len(dim)

    N = t.dim()
    old_shape = list(t.shape)
    new_shape = list(old_shape)
    for i in range(len(dim)):
        new_shape[dim[i]] = shape[i]

    # Zero-pad every requested mode. The padding inherits dtype and device from the data it is appended
    # to, so GPU and non-default-dtype tensors are supported.
    t = t.clone()
    for i in range(len(dim)):
        d = dim[i]
        if t.Us[d] is None:
            core = t.cores[d]
            if core.dim() == 2:
                extra = core.new_zeros((shape[i] - core.shape[0], core.shape[1]))
                t.cores[d] = torch.cat([core, extra], dim=0)
            else:
                extra = core.new_zeros((core.shape[0], shape[i] - core.shape[1], core.shape[2]))
                t.cores[d] = torch.cat([core, extra], dim=1)
        else:
            U = t.Us[d]
            extra = U.new_zeros((shape[i] - U.shape[0], U.shape[1]))
            t.Us[d] = torch.cat([U, extra], dim=0)

    if fill_value == 0:
        return t

    # A nonzero constant cannot be written directly into the padding of one core: every entry is a
    # contraction over all cores, so the padded slice would get scaled by the other modes. Instead add
    # fill_value * (1 - B), where B is the rank-1 indicator of the original index box, written below
    # as an explicit rank-2 TT.
    ref = t.cores[0]
    ones = [ref.new_ones((new_shape[n],)) for n in range(N)]
    boxes = []
    for n in range(N):
        box = ref.new_zeros((new_shape[n],))
        box[:old_shape[n]] = 1
        boxes.append(box)
    if N == 1:
        cores = [(fill_value * (ones[0] - boxes[0]))[None, :, None]]
    else:
        first = torch.stack([fill_value * ones[0], -fill_value * boxes[0]], dim=-1)[None, :, :]
        middle = []
        for n in range(1, N - 1):
            core = ref.new_zeros((2, new_shape[n], 2))
            core[0, :, 0] = ones[n]
            core[1, :, 1] = boxes[n]
            middle.append(core)
        last = torch.stack([ones[-1], boxes[-1]], dim=0)[:, :, None]
        cores = [first] + middle + [last]
    return t + tn.Tensor(cores, device=ref.device)


def convolve(t1: tn.Tensor, t2: tn.Tensor, mode='full', **kwargs):
    """
    ND convolution of two compressed tensors. Note: this function uses cross-approximation to multiply both tensors in the Fourier frequency domain [1].

    [1] M. Rakhuba, I. Oseledets: "Fast multidimensional convolution in low-rank formats via cross approximation" (2014)

    :param t1: a `tn.Tensor`
    :param t2: a `tn.Tensor`
    :param mode: 'full' (default), 'same', or 'valid', with the output sizes and alignment of `np.convolve`:
     along each mode, 'full' has size :math:`I + J - 1`, 'same' has size :math:`\\max(I, J)` and 'valid' has
     size :math:`\\max(I, J) - \\min(I, J) + 1`
    :param kwargs: to be passed to the cross-approximation
    :return: a `tn.Tensor`
    :raises ValueError: if `mode` is not one of the above
    """

    if mode not in ('full', 'same', 'valid'):
        raise ValueError("mode must be 'full', 'same' or 'valid', got {!r}".format(mode))

    N = t1.dim()
    assert N == t2.dim()

    t1 = t1.decompress_tucker_factors()
    t2 = t2.decompress_tucker_factors()
    # Zero-padded FFT of each core along its spatial index, to the length of the full convolution
    t1f = tn.Tensor([torch.fft.fft(t1.cores[n], n=t1.shape[n]+t2.shape[n]-1, dim=1) for n in range(N)])
    t2f = tn.Tensor([torch.fft.fft(t2.cores[n], n=t1.shape[n]+t2.shape[n]-1, dim=1) for n in range(N)])

    def multr(x, y):
        # Real part of the complex product x*y
        a = torch.real(x)
        b = torch.imag(x)
        c = torch.real(y)
        d = torch.imag(y)
        return a*c - b*d

    def multi(x, y):
        # Imaginary part of the complex product x*y
        a = torch.real(x)
        b = torch.imag(x)
        c = torch.real(y)
        d = torch.imag(y)
        return b*c + a*d

    t12fr = tn.cross(tensors=[t1f, t2f], function=multr, **kwargs)
    t12fi = tn.cross(tensors=[t1f, t2f], function=multi, **kwargs)
    t12fi.cores[-1] = t12fi.cores[-1]*1j  # turn the imaginary part into i * (imaginary part)
    t12r = tn.Tensor([torch.fft.ifft(t12fr.cores[n], dim=1) for n in range(N)])
    t12i = tn.Tensor([torch.fft.ifft(t12fi.cores[n], dim=1) for n in range(N)])
    t12 = tn.cross(tensors=[t12r, t12i], function=lambda x, y: torch.real(x)+torch.real(y), **kwargs)

    # Crop as needed, with the same alignment as np.convolve
    if mode == 'same':
        for n in range(N):
            k = min(t1.shape[n], t2.shape[n])
            start = (k - 1) // 2
            t12.cores[n] = t12.cores[n][:, start:start+max(t1.shape[n], t2.shape[n]), :]
    elif mode == 'valid':
        for n in range(N):
            k = min(t1.shape[n], t2.shape[n])
            # Explicit stop index: the former `-(k-1)` became `-0` (an empty slice) when k == 1
            t12.cores[n] = t12.cores[n][:, k-1:max(t1.shape[n], t2.shape[n]), :]

    return t12


def shift_mode(t, n, shift, eps=1e-3):
    """
    Shift a mode back or forth within a TT. This is an *in-place* operation.

    The mode is moved one position at a time: two neighboring cores are contracted, their spatial indices
    swapped, and the result is split again with a truncated SVD.

    :param t: a `tn.Tensor`
    :param n: which mode to move
    :param shift: how many positions to move. If positive move right, if negative move left
    :param eps: prescribed relative error tolerance; each of the |shift| swaps is truncated at
     eps / sqrt(|shift|). If 'same', ranks will be kept no larger than the original. Default is 1e-3

    :return: the modified tensor (also returned when `shift` is 0)
    :raises ValueError: if `eps` is neither 'same' nor a non-negative number
    """

    N = t.dim()
    assert 0 <= n + shift < N

    if shift == 0:
        return t

    # Validate eps before touching t, so that a bad value does not leave it half-modified
    if isinstance(eps, str):
        if eps != 'same':
            raise ValueError("Relative error '{}' not recognized".format(eps))
    elif not eps >= 0:
        raise ValueError("Relative error '{}' not recognized".format(eps))

    if any([U is not None for U in t.Us]):
        t = t.decompress_tucker_factors(_clone=False)
    t.orthogonalize(n)
    cores = t.cores
    sign = np.sign(shift)
    for i in range(n, n + shift, sign):
        # Swap the mode at position i with its right (sign 1) or left (sign -1) neighbor
        if sign == 1:
            c1 = i
            c2 = i+1
            left_ortho = True
        else:
            c1 = i-1
            c2 = i
            left_ortho = False
        R1 = cores[c1].shape[0]
        R2 = cores[c1].shape[2]
        R3 = cores[c2].shape[2]
        I1 = cores[c1].shape[1]
        I2 = cores[c2].shape[1]
        # Contract both cores and swap their spatial indices: (R1, I2) x (I1, R3) matrix
        sc = torch.einsum('iaj,jbk->ibak', (cores[c1], cores[c2]))
        sc = sc.reshape(sc.shape[0]*sc.shape[1], sc.shape[2]*sc.shape[3])
        if eps == 'same':
            left, right = tn.truncated_svd(sc, eps=0, rmax=R2, left_ortho=left_ortho)
        else:
            left, right = tn.truncated_svd(sc, eps=eps/np.sqrt(np.abs(shift)), left_ortho=left_ortho)
        newR2 = left.shape[1]
        cores[c1] = left.reshape(R1, I2, newR2)
        cores[c2] = right.reshape(newR2, I1, R3)
    return t
back to top