##### https://github.com/rballester/tntorch
Tip revision: 3af563a
tools.py
import tntorch as tn
import torch
import numpy as np
import time
import scipy.fftpack

"""
Array-like manipulations
"""

def squeeze(t, dim=None):
    """
    Removes singleton dimensions.

    :param t: input :class:`Tensor`
    :param dim: which dim(s) to delete: an int or a sequence of ints. By default, all that have size 1

    :return: another :class:`Tensor`, without dummy (singleton) indices
    """

    if dim is None:
        # np.where returns a tuple holding one index array per axis of its
        # input; take the only array so that we iterate over integers below
        dim = np.where([s == 1 for s in t.shape])[0]
    if not hasattr(dim, '__len__'):
        dim = [dim]
    dim = list(dim)

    assert all(t.shape[d] == 1 for d in dim), 'Only dimensions of size 1 can be squeezed'

    # Index every squeezed dimension at position 0 and keep all others whole
    idx = [slice(None)] * len(t.shape)
    for d in dim:
        idx[d] = 0
    return t[tuple(idx)]

def unsqueeze(t, dim):
    """
    Inserts singleton dimensions at specified positions.

    :param t: input :class:`Tensor`
    :param dim: int or list of int; positions refer to the dimensions of the *result*

    :return: a :class:`Tensor` with dummy (singleton) dimensions inserted at the positions given by `dim`
    """

    positions = dim if hasattr(dim, '__len__') else [dim]

    # One selector per output dimension: `None` creates a new axis, a full
    # slice keeps an existing one
    selector = [slice(None)] * (t.dim() + len(positions))
    for pos in positions:
        selector[pos] = None
    return t[tuple(selector)]

def cat(*ts, dim):
    """
    Concatenate two or more tensors along a given dim, similarly to PyTorch's `cat()`.

    Both `cat(t1, t2, ..., dim=d)` and `cat([t1, t2, ...], dim=d)` are accepted.

    :param ts: a list of :class:`Tensor`
    :param dim: an int

    :return: a :class:`Tensor` of the same shape as all tensors in the list, except along `dim` where it has the sum of shapes

    :raises ValueError: if no tensor is given, or if the shapes differ along any dimension other than `dim`
    """

    if len(ts) == 0:
        raise ValueError('To concatenate tensors, at least one must be given')
    if hasattr(ts[0], '__len__'):  # A single sequence of tensors was passed
        ts = ts[0]
    first = ts[0]
    if len(ts) == 1:
        return first.clone()

    other_dims = np.delete(np.arange(first.dim()), dim)
    if any(t.shape[n] != first.shape[n] for t in ts[1:] for n in other_dims):
        raise ValueError('To concatenate tensors, all must have the same shape along all but the given dim')

    # offsets[i]:offsets[i + 1] is the range that the i-th tensor occupies along `dim`
    offsets = np.concatenate([[0], np.cumsum([t.shape[dim] for t in ts])])
    total = int(offsets[-1])

    # Each input is embedded into a zero tensor of the final size (at its own
    # offset); the sum of all embedded tensors is the concatenation
    result = None
    for i, source in enumerate(ts):
        t = source.clone()
        lo, hi = int(offsets[i]), int(offsets[i + 1])
        if t.Us[dim] is None:
            core = source.cores[dim]
            if core.dim() == 2:
                padded = core.new_zeros(total, core.shape[-1])
            else:
                padded = core.new_zeros(core.shape[0], total, core.shape[-1])
            padded[..., lo:hi, :] += core
            t.cores[dim] = padded
        else:
            factor = source.Us[dim]
            padded = factor.new_zeros(total, factor.shape[-1])
            padded[lo:hi, :] += factor
            t.Us[dim] = padded
        if result is None:
            result = t
        else:
            result += t
    return result

def transpose(t):
    """
    Inverts the dimension order of a tensor, e.g. :math:`I_1 \\times I_2 \\times I_3` becomes :math:`I_3 \\times I_2 \\times I_1`.

    :param t: input tensor

    :return: another :class:`Tensor`, indexed by dimensions in inverse order
    """

    cores, Us, idxs = [], [], []
    for n in reversed(range(t.dim())):
        core = t.cores[n]
        # 3D cores swap their first and last (rank) axes; any other core is reused as is
        cores.append(core.permute(2, 1, 0) if core.dim() == 3 else core)

        factor = t.Us[n]
        Us.append(None if factor is None else factor.clone())

        # Index annotations are optional: anything that cannot be cloned
        # (including a missing `idxs` attribute) is recorded as None
        try:
            idxs.append(t.idxs[n].clone())
        except Exception:
            idxs.append(None)
    return tn.Tensor(cores, Us, idxs)

def meshgrid(*axes, batch=False):
    """
    See NumPy's or PyTorch's `meshgrid()`.

    The axes may be passed either as separate arguments or as one single sequence.

    :param axes: a list of N ints or torch vectors. An int `k` stands for the axis `0, ..., k-1`
    :param batch: Boolean, forwarded to the :class:`Tensor` constructor

    :return: a list of N :class:`Tensor`, of N dimensions each
    """

    # A single non-tensor sequence is interpreted as the list of axes itself
    if len(axes) == 1 and hasattr(axes[0], '__len__') and not isinstance(axes[0], torch.Tensor):
        axes = axes[0]
    axes = list(axes)

    # All cores are created on the device of the first axis (if it has one)
    device = None
    if axes and hasattr(axes[0], 'device'):
        device = axes[0].device

    dtype = torch.get_default_dtype()
    vectors = []
    for ax in axes:
        if not hasattr(ax, '__len__'):  # An integer size
            ax = torch.arange(ax, dtype=dtype)
        # as_tensor also accepts lists and NumPy arrays, which have no .type()
        vectors.append(torch.as_tensor(ax, dtype=dtype, device=device))

    tensors = []
    for n in range(len(vectors)):
        # Rank-1 tensor: constant (all ones) along every mode except the n-th,
        # which carries the axis values
        cores = [torch.ones(1, len(v), 1, device=device) for v in vectors]
        cores[n] = vectors[n][None, :, None]
        tensors.append(tn.Tensor(cores, device=device, batch=batch))
    return tensors

def flip(t, dim):
    """
    Reverses the order of a tensor along one or several dimensions; see NumPy's or PyTorch's `flip()`.

    :param t: input :class:`Tensor`
    :param dim: an int or list of ints

    :return: another :class:`Tensor` of the same shape
    """

    dims = dim if hasattr(dim, '__len__') else [dim]

    sizes = t.shape
    flipped = t.clone()
    for d in dims:
        backwards = np.arange(sizes[d]-1, -1, -1)
        factor = flipped.Us[d]
        if factor is None:
            # No Tucker factor: the spatial axis is the second-to-last one of the core
            flipped.cores[d] = flipped.cores[d][..., backwards, :]
        else:
            # With a Tucker factor it is enough to reverse the factor's rows
            flipped.Us[d] = factor[backwards, :]
    return flipped

def unbind(t, dim):
    """
    Slices a tensor along a dimension and returns the slices as a sequence, like PyTorch's `unbind()`.

    :param t: input :class:`Tensor`
    :param dim: an int (negative values count from the last dimension)

    :return: a list of :class:`Tensor`, as many as `t.shape[dim]`
    """

    if dim < 0:
        dim += t.dim()

    # Full slices for every dimension before and after the one being unbound
    leading = [slice(None)]*dim
    trailing = [slice(None)]*(t.dim()-1-dim)
    return [t[leading + [position] + trailing] for position in range(t.shape[dim])]

def unfolding(data, n, batch=False):
    """
    Computes the `n-th mode unfolding <https://epubs.siam.org/doi/pdf/10.1137/07070111X>`_ of a PyTorch tensor.

    :param data: a PyTorch tensor
    :param n: unfolding mode (negative values count from the last mode). If `batch` is True, modes are counted
        without the leading batch dimension
    :param batch: Boolean; if True, dimension 0 of `data` is a batch dimension that is kept in front

    :return: a PyTorch matrix of shape `data.shape[n]` x (product of the other sizes), or a stack of such
        matrices (one per batch element) if `batch` is True
    """

    if batch:
        if n < 0:
            n += data.dim() - 1
        # Keep the batch axis first, move mode n right after it, preserve the order of the rest
        order = [0, n + 1] + list(range(1, n + 1)) + list(range(n + 2, data.dim()))
        return data.permute(order).reshape([data.shape[0], data.shape[n + 1], -1])

    if n < 0:
        n += data.dim()
    # Move mode n to the front, preserve the order of the rest
    order = [n] + list(range(n)) + list(range(n + 1, data.dim()))
    return data.permute(order).reshape([data.shape[n], -1])

def right_unfolding(core, batch=False):
    """
    Computes the `right unfolding <https://epubs.siam.org/doi/pdf/10.1137/090752286>`_ of a 3D PyTorch tensor.

    :param core: a PyTorch tensor of shape :math:`I_1 \\times I_2 \\times I_3` (with an extra leading batch
        dimension if `batch` is True)
    :param batch: Boolean

    :return: a PyTorch matrix of shape :math:`I_1 \\times I_2 I_3` (one per batch element if `batch` is True)
    """

    if batch:
        # Keep the batch axis and the first core axis; merge everything else
        return core.reshape([core.shape[0], core.shape[1], -1])
    return core.reshape([core.shape[0], -1])

def left_unfolding(core, batch=False):
    """
    Computes the `left unfolding <https://epubs.siam.org/doi/pdf/10.1137/090752286>`_ of a 3D PyTorch tensor.

    :param core: a PyTorch tensor of shape :math:`I_1 \\times I_2 \\times I_3` (with an extra leading batch
        dimension if `batch` is True)
    :param batch: Boolean

    :return: a PyTorch matrix of shape :math:`I_1 I_2 \\times I_3` (one per batch element if `batch` is True)
    """

    if batch:
        # Keep the batch axis and the last core axis; merge everything in between
        return core.reshape([core.shape[0], -1, core.shape[-1]])
    return core.reshape([-1, core.shape[-1]])

"""
Multilinear algebra
"""

def ttm(t, U, dim=None, transpose=False):
    """
    `Tensor-times-matrix (TTM) <https://epubs.siam.org/doi/pdf/10.1137/07070111X>`_ along one or several dimensions.

    :param t: input :class:`Tensor`
    :param U: one or several factors (may be vectors or matrices)
    :param dim: one or several dimensions. If None, the first len(U) dims are assumed
    :param transpose: if False (default) the contraction is performed
     along U's rows, else along its columns

    :return: transformed :class:`Tensor`
    """

    factors = U if isinstance(U, (list, tuple)) else [U]
    if dim is None:
        dim = range(len(factors))
    if not hasattr(dim, '__len__'):
        dim = [dim]
    N = t.dim()
    dims = [d + N if d < 0 else d for d in dim]  # Resolve negative dimensions

    cores = []
    Us = []
    for n in range(N):
        core = t.cores[n]
        tucker = t.Us[n]

        # Untouched dimension: copy core and Tucker factor over
        if n not in dims:
            cores.append(core.clone())
            Us.append(None if tucker is None else tucker.clone())
            continue

        factor = factors[dims.index(n)]
        if transpose:
            factor = factor.t()
        # Promote the factor so that it has the axes the contractions below expect
        if factor.dim() == 1 and not t.batch:
            factor = factor[None, ...]
        if factor.dim() == 2 and t.batch:
            factor = factor[:,  None, ...]

        # With a Tucker factor, only the factor needs to be multiplied
        if tucker is not None:
            cores.append(core.clone())
            Us.append(torch.matmul(factor, tucker))
            continue

        # Otherwise contract the factor with the core's spatial axis; the
        # equation depends on batch mode and on the number of core axes
        if t.batch:
            equation = 'biak,bja->bijk' if core.dim() == 4 else 'bai,bja->bji'
        else:
            equation = 'iak,ja->ijk' if core.dim() == 3 else 'ai,ja->ji'
        cores.append(torch.einsum(equation, (core, factor)))
        Us.append(None)
    return tn.Tensor(cores, Us=Us, idxs=t.idxs, batch=t.batch)

"""
Miscellaneous
"""

"""
Masks a tensor. Basically an element-wise product, but this function makes sure slices are matched according to their "meaning" (as annotated by the tensor's idx field, if available)

:param t: input :class:Tensor
:param mask: a mask :class:Tensor

:return: masked :class:Tensor
"""
device = t.cores.device
if not hasattr(t, 'idxs'):
idxs = [np.arange(sh) for sh in t.shape]
else:
idxs = t.idxs
cores = []
Us = []
for n in range(t.dim()):
idx = np.array(idxs[n])
Us.append(None)
else:

def sample(t, P=1, seed=None):
    """
    Generate P points (with replacement) from a joint PDF distribution represented by a tensor.

    The tensor does not have to sum 1 (will be handled in a normalized form).

    :param t: a :class:`Tensor`
    :param P: how many samples to draw (default: 1)
    :param seed: seed for NumPy's random generator (default: None, i.e. not reproducible)

    :return Xs: a matrix of size :math:`P \\times N` holding the sampled (integer) indices, one point per row.
        It is created with `torch.zeros`, i.e. stored in PyTorch's default dtype
    """

    rng = np.random.default_rng(seed=seed)

    def from_matrix(M):
        """
        Treat each row of a matrix M as a PMF and select a column per row according to it
        """

        M = np.abs(M.detach().cpu().numpy()).astype(np.float64)
        M = M / np.sum(M, axis=1, keepdims=True)  # Normalize row-wise
        cdf = np.cumsum(M, axis=1)
        thresh = rng.random(M.shape[0])
        # Inverse-CDF sampling: the selected column j is the one satisfying
        # cdf[j-1] <= thresh < cdf[j], i.e. the number of CDF entries <= thresh
        picks = np.sum(cdf <= thresh[:, np.newaxis], axis=1)
        # Round-off may leave the last CDF entry slightly below 1; clamp so
        # that every row always yields exactly one valid column
        return np.minimum(picks, M.shape[1] - 1)

    N = t.dim()
    tsum = tn.sum(t, dim=np.arange(N), keepdim=True).decompress_tucker_factors()
    Xs = torch.zeros([P, N])

    # rights[mu] is the product of the summed-out cores mu, ..., N-1, i.e. the
    # marginalization of all the modes that have not been sampled yet
    rights = [torch.ones(1)]
    for core in tsum.cores[::-1]:
        rights.append(torch.matmul(torch.sum(core, dim=1), rights[-1]))
    rights = rights[::-1]

    # lefts[p] is the product of the core slices already selected for point p
    lefts = torch.ones([P, 1])
    t = t.decompress_tucker_factors()
    for mu in range(N):
        # Unnormalized conditional distribution of mode mu for each point, given its previous indices
        fiber = torch.einsum('ijk,k->ij', (t.cores[mu], rights[mu + 1]))
        per_point = torch.einsum('ij,jk->ik', (lefts, fiber))
        rows = torch.as_tensor(from_matrix(per_point), dtype=torch.long)
        Xs[:, mu] = rows
        lefts = torch.einsum('ij,jik->ik', (lefts, t.cores[mu][:, rows, :]))

    return Xs

def hash(t):
    """
    Computes a number that depends on the tensor entries (not on its internal compressed representation).

    We obtain it as :math:`\\langle T, W \\rangle`, where :math:`W` is a rank-1 tensor of weights selected at random (always the same seed).

    :param t: input :class:`Tensor`

    :return: the result of `t.dot(W)`
    """

    # A dedicated generator with a fixed seed makes the weights reproducible
    # without touching PyTorch's global random state
    generator = torch.Generator()
    generator.manual_seed(0)

    # W is rank-1: trivial 1 x 1 x 1 cores, with one random column vector per mode as factors
    unit_cores = [torch.ones(1, 1, 1) for _ in range(t.dim())]
    weights = [torch.rand([size, 1], generator=generator) for size in t.shape]
    return t.dot(tn.Tensor(unit_cores, weights))

def generate_basis(name, shape, orthonormal=False):
    """
    Generate a factor matrix whose columns are functions of a truncated basis.

    :param name: 'dct', 'identity', 'legendre', 'chebyshev' or 'hermite'
    :param shape: two integers: number of evaluation points (rows) and number of basis functions (columns)
    :param orthonormal: if True, every column is rescaled to unit Euclidean norm (columns are not re-orthogonalized)

    :return: a PyTorch matrix of shape `shape`

    :raises ValueError: if `name` is not one of the supported bases
    """

    rows, cols = shape[0], shape[1]
    if name == "dct":
        U = scipy.fftpack.dct(np.eye(rows), norm="ortho")[:, :cols]
    elif name == 'identity':
        U = np.eye(rows, cols)
    else:
        evaluators = {
            "legendre": np.polynomial.legendre.legval,
            "chebyshev": np.polynomial.chebyshev.chebval,
            "hermite": np.polynomial.hermite.hermval,
        }
        if name not in evaluators:
            raise ValueError("Unsupported basis function")
        eval_points = np.linspace(-1, 1, rows)
        # Column j of the coefficient matrix selects the j-th polynomial of the
        # family; evaluation yields one row per polynomial, hence the transpose
        U = evaluators[name](eval_points, np.eye(rows, cols)).T
    if orthonormal:
        norms = np.sqrt(np.sum(U*U, axis=0))
        norms[norms == 0] = 1  # Leave all-zero columns untouched instead of producing NaNs
        U = U / norms
    return torch.from_numpy(U)

def reduce(ts, function, eps=0, rmax=np.iinfo(np.int32).max, algorithm='svd', verbose=False, **kwargs):
    """
    Compute a tensor as a function to all tensors in a sequence.

    The elements are merged pairwise in a binary-tree fashion, and every intermediate result is rounded; the
    order of the elements is preserved, so `function` does not need to be commutative.

    :Example 1 (addition):

    >>> import operator
    >>> tn.reduce([t1, t2], operator.add)

    :Example 2 (cat with bounded rank):

    >>> tn.reduce([t1, t2], tn.cat, dim=0, rmax=10)

    :param ts: A generator (or list) of :class:`Tensor`
    :param function: a callable taking two tensors (plus `kwargs`) and returning one
    :param eps: intermediate tensors will be rounded at this error when climbing up the hierarchy
    :param rmax: no node should exceed this number of ranks
    :param algorithm: passed to :func:`round.round()`
    :param verbose: Boolean
    :param kwargs: forwarded to `function` on every call

    :return: the reduced result

    :raises ValueError: if `ts` is empty
    """

    def combine(left, right):
        return tn.round(function(left, right, **kwargs), eps=eps, rmax=rmax, algorithm=algorithm)

    # pending[level] holds the merge of 2**level consecutive elements that is
    # still waiting for a partner of the same size (like the bits of a binary counter)
    pending = dict()
    start = time.time()
    for i, elem in enumerate(ts):
        if verbose and i % 100 == 0:
            print("reduce: element {}, time={:g}".format(i, time.time()-start))
        climb = 0  # For going up the tree
        while climb in pending:
            elem = combine(pending.pop(climb), elem)
            climb += 1
        pending[climb] = elem

    if not pending:
        raise ValueError("reduce() needs at least one tensor")

    # Higher levels always contain earlier elements, so merging from the
    # highest level down keeps the original left-to-right order
    levels = sorted(pending, reverse=True)
    result = pending[levels[0]]
    for level in levels[1:]:
        result = combine(result, pending[level])
    return result

"""
Pad a tensor with a constant value.

:param t: N-dim input :class:Tensor
:param shape: int or list of ints
:param dim: int or list of ints (default: all modes)
:param fill_value: default is 0

:return: a :class:Tensor of size shape along the indicated modes
"""

if dim is None:
dim = range(t.dim())
if not hasattr(dim, '__len__'):
dim = [dim]
if not hasattr(shape, '__len__'):
shape = [shape]*len(dim)

t = t.clone()
for i in range(len(dim)):
mult = 0
if i == 0:
mult = fill_value
if t.Us[dim[i]] is None:
if t.cores[dim[i]].dim() == 2:
t.cores[dim[i]] = torch.cat([t.cores[dim[i]],
mult * torch.ones(shape[i] - t.cores[dim[i]].shape,
t.cores[dim[i]].shape)], dim=0)
else:
t.cores[dim[i]] = torch.cat([t.cores[dim[i]],
mult * torch.ones(t.cores[dim[i]].shape,
shape[i] - t.cores[dim[i]].shape,
t.cores[dim[i]].shape)], dim=1)
else:
t.Us[dim[i]] = torch.cat([t.Us[dim[i]],
mult * torch.ones(shape[i] - t.Us[dim[i]].shape,
t.Us[dim[i]].shape)], dim=0)
return t

def convolve(t1: tn.Tensor, t2: tn.Tensor, mode='full', **kwargs):
    """
    ND convolution of two compressed tensors. Note: this function uses cross-approximation to multiply both tensors in the Fourier frequency domain [1].

    [1] M. Rakhuba, I. Oseledets: "Fast multidimensional convolution in low-rank formats via cross approximation" (2014)

    :param t1: a `tn.Tensor`
    :param t2: a `tn.Tensor`
    :param mode: 'full' (default), 'same', or 'valid'. See `np.convolve`
    :param kwargs: to be passed to the cross-approximation
    :return: a `tn.Tensor`

    :raises ValueError: if `mode` is not one of the three supported values
    """

    if mode not in ('full', 'same', 'valid'):
        raise ValueError("mode must be 'full', 'same' or 'valid', got '{}'".format(mode))

    N = t1.dim()
    assert N == t2.dim()

    # Transform every core along its spatial axis, zero-padded to the length of the full convolution
    t1 = t1.decompress_tucker_factors()
    t2 = t2.decompress_tucker_factors()
    t1f = tn.Tensor([torch.fft.fft(t1.cores[n], n=t1.shape[n]+t2.shape[n]-1, dim=1) for n in range(N)])
    t2f = tn.Tensor([torch.fft.fft(t2.cores[n], n=t1.shape[n]+t2.shape[n]-1, dim=1) for n in range(N)])

    # Real and imaginary parts of the element-wise complex product x*y
    def multr(x, y):
        a = torch.real(x)
        b = torch.imag(x)
        c = torch.real(y)
        d = torch.imag(y)
        return a*c - b*d

    def multi(x, y):
        a = torch.real(x)
        b = torch.imag(x)
        c = torch.real(y)
        d = torch.imag(y)
        return b*c + a*d

    t12fr = tn.cross(tensors=[t1f, t2f], function=multr, **kwargs)
    t12fi = tn.cross(tensors=[t1f, t2f], function=multi, **kwargs)
    t12fi.cores[-1] = t12fi.cores[-1]*1j  # Turn the imaginary part into an imaginary tensor
    t12r = tn.Tensor([torch.fft.ifft(t12fr.cores[n], dim=1) for n in range(N)])
    t12i = tn.Tensor([torch.fft.ifft(t12fi.cores[n], dim=1) for n in range(N)])
    t12 = tn.cross(tensors=[t12r, t12i], function=lambda x, y: torch.real(x)+torch.real(y), **kwargs)

    # Crop as needed
    if mode == 'same':
        for n in range(N):
            k = min(t1.shape[n], t2.shape[n])
            # np.convolve drops (k-1)//2 leading samples of the full result
            # (this differs from k//2 whenever k is even)
            first = (k-1)//2
            t12.cores[n] = t12.cores[n][:, first:first+max(t1.shape[n], t2.shape[n]), :]
    elif mode == 'valid':
        for n in range(N):
            k = min(t1.shape[n], t2.shape[n])
            # Drop the k-1 partially overlapping samples at each end. The stop
            # index is explicit because a slice ending at -(k-1) would be
            # empty for k == 1 (as -0 == 0)
            last = t12.cores[n].shape[1] - (k-1)
            t12.cores[n] = t12.cores[n][:, k-1:last, :]

    return t12

def shift_mode(t, n, shift, eps=1e-3):
"""
Shift a mode back or forth within a TT. This is an *in-place* operation.

:param t: a tn.Tensor
:param n: which mode to move
:param shift: how many positions to move. If positive move right, if negative move left
:param eps: prescribed relative error tolerance. If 'same', ranks will be kept no larger than the original. Default is 1e-3
"""

N = t.dim()
assert 0 <= n + shift < N

if shift == 0:
return t

if any([U is not None for U in t.Us]):
t = t.decompress_tucker_factors(_clone=False)
t.orthogonalize(n)
cores = t.cores
sign = np.sign(shift)
for i in range(n, n + shift, sign):
if sign == 1:
c1 = i
c2 = i+1
left_ortho = True
else:
c1 = i-1
c2 = i
left_ortho = False
R1 = cores[c1].shape
R2 = cores[c1].shape
R3 = cores[c2].shape
I1 = cores[c1].shape
I2 = cores[c2].shape
sc = torch.einsum('iaj,jbk->ibak', (cores[c1], cores[c2]))
sc = sc.reshape(sc.shape*sc.shape, sc.shape*sc.shape)
if eps == 'same':
left, right = tn.truncated_svd(sc, eps=0, rmax=R2, left_ortho=left_ortho)
elif eps >= 0:
left, right = tn.truncated_svd(sc, eps=eps/np.sqrt(np.abs(shift)), left_ortho=left_ortho)
else:
raise ValueError("Relative error '{}' not recognized".format(eps))
newR2 = left.shape
cores[c1] = left.reshape(R1, I2, newR2)
cores[c2] = right.reshape(newR2, I1, R3)