"""Manipulation of matrices in the TT format"""
import tensorly as tl
from ._batched_tensordot import tensordot
def tt_matrix_to_tensor(tt_matrix):
"""Returns the full tensor whose TT-Matrix decomposition is given by 'factors'
Re-assembles 'factors', which represent a tensor in TT-Matrix format
into the corresponding full tensor
Parameters
----------
factors: list of 4D-arrays
TT-Matrix factors (known as core) of shape (rank_k, left_dim_k, right_dim_k, rank_{k+1})
Returns
-------
output_tensor: ndarray
tensor whose TT-Matrix decomposition was given by 'factors'
"""
# Each core is of shape (rank_left, size_in, size_out, rank_right)
_, in_shape, out_shape, _ = zip(*(tl.shape(f) for f in tt_matrix))
ndim = len(in_shape)
# Intertwine the dims
# full_shape = in_shape[0], out_shape[0], in_shape[1], ...
full_shape = sum(zip(*(in_shape, out_shape)), ())
order = list(range(0, ndim * 2, 2)) + list(range(1, ndim * 2, 2))
for i, factor in enumerate(tt_matrix):
if not i:
res = factor
else:
res = tensordot(res, factor, ([-1], [0]))
return tl.transpose(tl.reshape(res, full_shape), order)