# Snapshot of tip revision 0fa8ab5b81e54410ee9b34a1ac5e45eebda7e387,
# authored by Jean Kossaifi on 06 August 2019, 17:16:06 UTC.
"""Core operations on tensors in Matrix Product State (MPS) format, also known as Tensor-Train (TT)."""

import tensorly as tl

def _validate_mps_tensor(mps_tensor):
    """Check that 'mps_tensor' is a well-formed MPS (Tensor-Train) decomposition

    Every factor must be a third order tensor, consecutive factors must have
    matching ranks, and the first and last ranks must both be 1.

    Parameters
    ----------
    mps_tensor : sequence of 3D-arrays
        MPS factors (tt-cores); the k-th factor is unpacked as
        (rank_k, dim_k, rank_{k+1})

    Returns
    -------
    shape : tuple of int
        size of the middle mode of each factor
    rank : tuple of int
        size of the first mode of each factor, followed by the size of the
        last mode of the final factor (so ``len(rank) == len(shape) + 1``)

    Raises
    ------
    ValueError
        if fewer than two factors are given, if a factor is not third order,
        if two consecutive factors have mismatched ranks, or if the boundary
        ranks are not 1
    """
    factors = mps_tensor
    n_factors = len(factors)
    if n_factors < 2:
        raise ValueError('A Matrix-Product-State (ttrain) tensor should be composed of at least two factors and a core.'
                         'However, {} factor was given.'.format(n_factors))

    rank = []
    shape = []
    for index, factor in enumerate(factors):
        # Check that factors are third order tensors. This has to happen
        # before the three-way unpacking below, otherwise a factor of the
        # wrong order fails with an unrelated "values to unpack" error.
        if not tl.ndim(factor) == 3:
            raise ValueError('MPS expresses a tensor as third order factors (tt-cores).\n'
                             'However, tl.ndim(factors[{}]) = {}'.format(
                                 index, tl.ndim(factor)))

        current_rank, current_shape, next_rank = tl.shape(factor)

        # Consecutive factors should have matching ranks
        if index and tl.shape(factors[index - 1])[2] != current_rank:
            raise ValueError('Consecutive factors should have matchine ranks\n'
                             ' -- e.g. tl.shape(factors[0])[2]) == tl.shape(factors[1])[0])\n'
                             'However, tl.shape(factor[{}])[2] == {} but'
                             ' tl.shape(factor[{}])[0] == {} '.format(
                              index - 1, tl.shape(factors[index - 1])[2], index, current_rank))
        # Check for boundary conditions
        if (index == 0) and current_rank != 1:
            raise ValueError('Boundary conditions dictate factor[0].shape[0] == 1.'
                             'However, got factor[0].shape[0] = {}.'.format(
                              current_rank))
        if (index == n_factors - 1) and next_rank != 1:
            # Report the index of the last factor, i.e. n_factors - 1
            raise ValueError('Boundary conditions dictate factor[-1].shape[2] == 1.'
                             'However, got factor[{}].shape[2] = {}.'.format(
                              n_factors - 1, next_rank))

        shape.append(current_shape)
        rank.append(current_rank)

    # Add last rank (boundary condition)
    rank.append(next_rank)

    return tuple(shape), tuple(rank)

def mps_to_tensor(factors):
    """Returns the full tensor whose MPS decomposition is given by 'factors'

        Re-assembles 'factors', which represent a tensor in MPS/TT format
        into the corresponding full tensor

    Parameters
    ----------
    factors: list of 3D-arrays
              MPS factors (known as core in TT terminology)

    Returns
    -------
    output_tensor: ndarray
                   tensor whose MPS/TT decomposition was given by 'factors'
    """
    full_shape = [f.shape[1] for f in factors]
    # Flatten the first factor into a matrix with full_shape[0] rows
    # (assumes its leading rank is 1, as the MPS boundary condition requires)
    full_tensor = tl.reshape(factors[0], (full_shape[0], -1))

    for factor in factors[1:]:
        rank_prev, _, rank_next = factor.shape
        # Unfold the factor so its first rank can be contracted away
        factor = tl.reshape(factor, (rank_prev, -1))
        # Contract the rank shared by the accumulated tensor and this factor
        full_tensor = tl.dot(full_tensor, factor)
        # Keep the trailing rank as the last axis for the next contraction
        full_tensor = tl.reshape(full_tensor, (-1, rank_next))

    return tl.reshape(full_tensor, full_shape)

def mps_to_unfolded(factors, mode):
    """Returns the unfolding matrix of a tensor given in MPS (or Tensor-Train) format

    Reassembles a full tensor from 'factors' and returns its unfolding matrix
    with mode given by 'mode'

    Parameters
    ----------
    factors: list of 3D-arrays
              MPS factors
    mode: int
          unfolding matrix to be computed along this mode

    Returns
    -------
    2-D array
    unfolding matrix at mode given by 'mode'
    """
    return tl.unfold(mps_to_tensor(factors), mode)

def mps_to_vec(factors):
    """Returns the tensor defined by its MPS format ('factors') into
       its vectorized format

    Parameters
    ----------
    factors: list of 3D-arrays
              MPS factors

    Returns
    -------
    1-D array
    vectorized format of tensor defined by 'factors'
    """
    return tl.tensor_to_vec(mps_to_tensor(factors))

# End of module.