https://github.com/tensorly/tensorly
Tip revision: 0fa8ab5b81e54410ee9b34a1ac5e45eebda7e387 authored by Jean Kossaifi on 06 August 2019, 17:16:06 UTC
Upload only once
Upload only once
Tip revision: 0fa8ab5
mps_tensor.py
"""
Core operations on tensors in Matrix Product State (MPS) format, also known as Tensor-Train (TT)
"""
import tensorly as tl
def _validate_mps_tensor(mps_tensor):
factors = mps_tensor
n_factors = len(factors)
if n_factors < 2:
raise ValueError('A Matrix-Product-State (ttrain) tensor should be composed of at least two factors and a core.'
'However, {} factor was given.'.format(n_factors))
rank = []
shape = []
for index, factor in enumerate(factors):
current_rank, current_shape, next_rank = tl.shape(factor)
# Check that factors are third order tensors
if not tl.ndim(factor)==3:
raise ValueError('MPS expresses a tensor as third order factors (tt-cores).\n'
'However, tl.ndim(factors[{}]) = {}'.format(
index, tl.ndim(factor)))
# Consecutive factors should have matching ranks
if index and tl.shape(factors[index - 1])[2] != current_rank:
raise ValueError('Consecutive factors should have matchine ranks\n'
' -- e.g. tl.shape(factors[0])[2]) == tl.shape(factors[1])[0])\n'
'However, tl.shape(factor[{}])[2] == {} but'
' tl.shape(factor[{}])[0] == {} '.format(
index - 1, tl.shape(factors[index - 1])[2], index, current_rank))
# Check for boundary conditions
if (index == 0) and current_rank != 1:
raise ValueError('Boundary conditions dictate factor[0].shape[0] == 1.'
'However, got factor[0].shape[0] = {}.'.format(
current_rank))
if (index == n_factors - 1) and next_rank != 1:
raise ValueError('Boundary conditions dictate factor[-1].shape[2] == 1.'
'However, got factor[{}].shape[2] = {}.'.format(
n_factors, next_rank))
shape.append(current_shape)
rank.append(current_rank)
# Add last rank (boundary condition)
rank.append(next_rank)
return tuple(shape), tuple(rank)
def mps_to_tensor(factors):
    """Returns the full tensor whose MPS decomposition is given by 'factors'

    Re-assembles 'factors', which represent a tensor in MPS/TT format,
    into the corresponding full tensor by contracting the factors from
    left to right.

    Parameters
    ----------
    factors: list of 3D-arrays
        MPS factors (known as core in TT terminology); each factor is read
        as ``(rank_prev, dim, rank_next)``

    Returns
    -------
    output_tensor: ndarray
        tensor whose MPS/TT decomposition was given by 'factors';
        its shape is the list of the factors' middle dimensions
    """
    # Query shapes through the backend (as _validate_mps_tensor does) rather than
    # through the `.shape` attribute, so the sizes go through the same
    # backend-agnostic accessor before being handed to tl.reshape.
    full_shape = [tl.shape(factor)[1] for factor in factors]

    # Running contraction, kept as a matrix: rows enumerate the modes merged so
    # far, columns enumerate the rank still to be contracted.
    full_tensor = tl.reshape(factors[0], (full_shape[0], -1))

    for factor in factors[1:]:
        rank_prev, _, rank_next = tl.shape(factor)
        # Contract the open rank against the next factor's leading rank ...
        factor = tl.reshape(factor, (rank_prev, -1))
        full_tensor = tl.dot(full_tensor, factor)
        # ... then fold the new mode into the rows, leaving rank_next as columns
        full_tensor = tl.reshape(full_tensor, (-1, rank_next))

    return tl.reshape(full_tensor, full_shape)
def mps_to_unfolded(factors, mode):
    """Unfolding matrix of a tensor given in MPS (Tensor-Train) format

    The full tensor is first rebuilt from 'factors' and then unfolded
    along 'mode'.

    Parameters
    ----------
    factors: list of 3D-arrays
        MPS factors
    mode: int
        mode along which the rebuilt tensor is unfolded

    Returns
    -------
    2-D array
        unfolding matrix at mode given by 'mode'
    """
    dense = mps_to_tensor(factors)
    return tl.unfold(dense, mode)
def mps_to_vec(factors):
    """Vectorized form of a tensor given in MPS (Tensor-Train) format

    The full tensor is first rebuilt from 'factors' and then vectorized.

    Parameters
    ----------
    factors: list of 3D-arrays
        MPS factors

    Returns
    -------
    1-D array
        vectorized format of tensor defined by 'factors'
    """
    dense = mps_to_tensor(factors)
    return tl.tensor_to_vec(dense)