Skip to main content
  • Home
  • Development
  • Documentation
  • Donate
  • Operational login
  • Browse the archive

swh logo
SoftwareHeritage
Software
Heritage
Archive
Features
  • Search

  • Downloads

  • Save code now

  • Add forge now

  • Help

Revision e54a8ffb2a30952542234280d443c6dd6e2649d5 authored by TUNA Caglayan on 26 April 2021, 13:54:29 UTC, committed by TUNA Caglayan on 12 May 2021, 12:26:22 UTC
decimal and tensor to vec
1 parent 5412537
  • Files
  • Changes
  • 1633999
  • /
  • tensorly
  • /
  • tt_tensor.py
Raw File Download
Permalinks

To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
Select below a type of object currently browsed in order to display its associated SWHID and permalink.

  • revision
  • directory
  • content
revision badge
swh:1:rev:e54a8ffb2a30952542234280d443c6dd6e2649d5
directory badge Iframe embedding
swh:1:dir:7658cecc6713fc2354c396ee43be3402a7ecb355
content badge Iframe embedding
swh:1:cnt:ef4804878d374aabf35c175c8933d298751d23f1
Citations

This interface makes it possible to generate software citations, provided that the root directory of the browsed objects contains a citation.cff or codemeta.json file.
Select below a type of object currently browsed in order to generate citations for them.

  • revision
  • directory
  • content
Generate software citation in BibTex format (requires biblatex-software package)
Generating citation ...
Generate software citation in BibTex format (requires biblatex-software package)
Generating citation ...
Generate software citation in BibTex format (requires biblatex-software package)
Generating citation ...
tt_tensor.py
"""
Core operations on tensors in Tensor-Train (TT) format, also known as Matrix-Product-State (MPS)
"""

import tensorly as tl
from ._factorized_tensor import FactorizedTensor
from .utils import DefineDeprecated
import numpy as np
from scipy.optimize import brentq
import warnings

def _validate_tt_tensor(tt_tensor):
    """Validate a tensor given in TT format and return its shape and TT-rank.

    Parameters
    ----------
    tt_tensor : TTTensor, float, int or sequence of 3D-arrays
        * a ``TTTensor`` is trusted (it was validated at creation)
        * a Python scalar is treated as a 0-order tensor
        * otherwise, a sequence of third order factors (tt-cores)

    Returns
    -------
    (shape, rank) : (tuple, tuple)
        ``shape`` collects the second dimension of every factor and ``rank``
        collects the first dimension of every factor followed by the last
        dimension of the last factor.
        For a scalar input, ``(0, 0)`` is returned.

    Raises
    ------
    ValueError
        If no factor is provided, if a factor is not a third order tensor,
        if the ranks of consecutive factors do not match, or if the boundary
        ranks differ from 1.
    """
    # Scalars must be handled before anything calls len() on the input:
    # a float/int has no length, which made this branch unreachable before.
    if isinstance(tt_tensor, (float, int)):  # 0-order tensor
        return 0, 0

    if isinstance(tt_tensor, TTTensor):
        # it's already been validated at creation
        return tt_tensor.shape, tt_tensor.rank

    factors = tt_tensor
    n_factors = len(factors)
    if n_factors == 0:
        # Without this guard the final `rank.append(next_rank)` would fail
        # with an unrelated UnboundLocalError.
        raise ValueError('TT expresses a tensor as a non-empty list of third order '
                         'factors (tt-cores), but no factor was provided.')

    rank = []
    shape = []
    for index, factor in enumerate(factors):
        # Check that factors are third order tensors. This has to happen
        # before unpacking the shape, otherwise the unpacking itself fails
        # with a much less helpful message.
        if not tl.ndim(factor)==3:
            raise ValueError('TT expresses a tensor as third order factors (tt-cores).\n'
                             'However, tl.ndim(factors[{}]) = {}'.format(
                                 index, tl.ndim(factor)))

        current_rank, current_shape, next_rank = tl.shape(factor)

        # Consecutive factors should have matching ranks
        if index and tl.shape(factors[index - 1])[2] != current_rank:
            raise ValueError('Consecutive factors should have matching ranks\n'
                             ' -- e.g. tl.shape(factors[0])[2]) == tl.shape(factors[1])[0])\n'
                             'However, tl.shape(factor[{}])[2] == {} but'
                             ' tl.shape(factor[{}])[0] == {} '.format(
                              index - 1, tl.shape(factors[index - 1])[2], index, current_rank))
        # Check for boundary conditions
        if (index == 0) and current_rank != 1:
            raise ValueError('Boundary conditions dictate factor[0].shape[0] == 1.'
                             'However, got factor[0].shape[0] = {}.'.format(
                              current_rank))
        if (index == n_factors - 1) and next_rank != 1:
            # Report the index of the offending (last) factor, i.e. n_factors - 1
            raise ValueError('Boundary conditions dictate factor[-1].shape[2] == 1.'
                             'However, got factor[{}].shape[2] = {}.'.format(
                              index, next_rank))

        shape.append(current_shape)
        rank.append(current_rank)

    # Add last rank (boundary condition)
    rank.append(next_rank)

    return tuple(shape), tuple(rank)


def tt_to_tensor(factors):
    """Reconstruct the full tensor from its TT (Matrix-Product-State) factors.

    The cores are contracted from left to right: each step multiplies the
    running result by the next core, flattened into a matrix.

    Parameters
    ----------
    factors : list of 3D-arrays
        TT factors (TT-cores); a Python float or int is returned unchanged,
        as it is treated as a 0-order tensor

    Returns
    -------
    output_tensor : ndarray
        tensor whose TT/MPS decomposition was given by 'factors'
    """
    # 0-order tensor: nothing to contract
    if isinstance(factors, (float, int)):
        return factors

    # The middle dimension of each core is the size of the matching mode
    mode_sizes = [core.shape[1] for core in factors]

    # Start from the first core, flattened to (I_0, R_1)
    result = tl.reshape(factors[0], (mode_sizes[0], -1))

    for core in factors[1:]:
        left_rank, _, right_rank = core.shape
        # (..., R_k) x (R_k, I_k * R_{k+1}), then expose R_{k+1} as the last axis
        contracted = tl.dot(result, tl.reshape(core, (left_rank, -1)))
        result = tl.reshape(contracted, (-1, right_rank))

    return tl.reshape(result, mode_sizes)


def tt_to_unfolded(factors, mode):
    """Unfold, along `mode`, the tensor represented by TT (Tensor-Train) factors.

    The full tensor is first re-assembled from 'factors' and then unfolded.

    Parameters
    ----------
    factors : list of 3D-arrays
        TT factors
    mode : int
        mode along which the unfolding matrix is computed

    Returns
    -------
    2-D array
        unfolding matrix at mode given by 'mode'
    """
    full_tensor = tt_to_tensor(factors)
    return tl.unfold(full_tensor, mode)


def tt_to_vec(factors):
    """Vectorize the tensor represented by TT (Tensor-Train) factors.

    The full tensor is first re-assembled from 'factors' and then vectorized.

    Parameters
    ----------
    factors : list of 3D-arrays
        TT factors

    Returns
    -------
    1-D array
        vectorized format of the tensor defined by 'factors'
    """
    full_tensor = tt_to_tensor(factors)
    return tl.tensor_to_vec(full_tensor)

def _tt_n_param(tensor_shape, rank):
    """Count the parameters of a TT (MPS) decomposition.

    Parameters
    ----------
    tensor_shape : int tuple
        shape of the full tensor to decompose (or approximate)

    rank : tuple
        rank of the MPS decomposition; it is indexed at positions
        ``0 .. len(tensor_shape)``

    Returns
    -------
    n_params : int
        Sum over all modes ``k`` of ``rank[k] * tensor_shape[k] * rank[k+1]``,
        i.e. the total size of the TT factors
    """
    # Factor k has shape (rank[k], tensor_shape[k], rank[k+1])
    return np.sum([
        rank[mode] * size * rank[mode + 1]
        for mode, size in enumerate(tensor_shape)
    ])


def validate_tt_rank(tensor_shape, rank='same', constant_rank=False, rounding='round'):
    """Returns the rank of a TT Decomposition

    Parameters
    ----------
    tensor_shape : tuple
        shape of the tensor to decompose
    rank : {'same', float, tuple, int}, default is same
        way to determine the rank, by default 'same'
        if 'same': rank is computed to keep the number of parameters (at most) the same
        if float, computes a rank so as to keep rank percent of the original number of parameters
        if int or tuple, just returns rank (an int is broadcast to every
        internal rank, the two boundary ranks being set to 1)
    constant_rank : bool, default is False
        * if True, the *same* rank will be chosen for each modes
        * if False (default), the rank of each mode will be proportional to the corresponding tensor_shape

        *used only if rank == 'same' or 0 < rank <= 1*

    rounding : {'round', 'floor', 'ceil'}
        how the computed (real-valued) ranks are converted to integers

    Returns
    -------
    rank : list of int
        rank of the decomposition, of length ``len(tensor_shape) + 1``

    Raises
    ------
    ValueError
        If `rounding` is not one of the supported values, if a user-provided
        rank has the wrong length, or if its boundary values differ from 1.
    """
    if rounding == 'ceil':
        rounding_fun = np.ceil
    elif rounding == 'floor':
        rounding_fun = np.floor
    elif rounding == 'round':
        rounding_fun = np.round
    else:
        raise ValueError(f'Rounding should be round, floor or ceil, but got {rounding}')

    # Only compare against the keyword when `rank` is a string, so that
    # array-like ranks are never compared element-wise to 'same'.
    if isinstance(rank, str) and rank == 'same':
        rank = float(1)

    if isinstance(rank, float) and constant_rank:
        # Choose the *same* rank for each mode
        n_param_tensor = np.prod(tensor_shape)*rank
        order = len(tensor_shape)

        if order == 2:
            warnings.warn(f'Determining the tt-rank for the trivial case of a matrix (order 2 tensor) of shape {tensor_shape}, not a higher-order tensor.')
            # For a matrix there is no internal factor: the number of
            # parameters is R*(I_0 + I_1), which is linear in R. Solving the
            # quadratic below would divide by a == 0.
            solution = n_param_tensor / (tensor_shape[0] + tensor_shape[1])
        else:
            # R_k I_k R_{k+1} = R^2 I_k
            a = np.sum(tensor_shape[1:-1])

            # Border rank of 1, R_0 = R_N = 1
            # First and last factor of size I_0 R and I_N R
            b = tensor_shape[0] + tensor_shape[-1]

            # We want the number of params of decomp (=sum of params of factors)
            # To be equal to c = \prod_k I_k
            c = -n_param_tensor
            delta = np.sqrt(b**2 - 4*a*c)

            # We get the non-negative solution
            solution = (- b + delta)/(2*a)

        # A rank of 0 is not a valid TT-rank: clamp to 1, as done below for
        # the proportional case
        solution = max(int(rounding_fun(solution)), 1)
        rank = (1, ) + (solution, )*(order-1) + (1, )

    elif isinstance(rank, float):
        # Choose a rank proportional to the size of each mode
        # The method is similar to the above one for constant_rank == True
        order = len(tensor_shape)
        avg_dim = [(tensor_shape[i]+tensor_shape[i+1])/2 for i in range(order - 1)]
        if len(avg_dim) > 1:
            a = sum(avg_dim[i-1]*tensor_shape[i]*avg_dim[i] for i in range(1, order - 1))
        else:
            warnings.warn(f'Determining the tt-rank for the trivial case of a matrix (order 2 tensor) of shape {tensor_shape}, not a higher-order tensor.')
            a = avg_dim[0]**2*tensor_shape[0]
        b = tensor_shape[0]*avg_dim[0] + tensor_shape[-1]*avg_dim[-1]
        c = -np.prod(tensor_shape)*rank
        delta = np.sqrt(b**2 - 4*a*c)

        # We get the non-negative solution
        fraction_param = (- b + delta)/(2*a)
        rank = tuple([max(int(rounding_fun(d*fraction_param)), 1) for d in avg_dim])
        rank = (1, ) + rank + (1, )

    else:
        # Check user input for potential errors
        n_dim = len(tensor_shape)
        if isinstance(rank, (int, np.integer)):
            # NumPy integers are accepted too; they have no len()
            rank = [1] + [int(rank)] * (n_dim-1) + [1]
        elif n_dim+1 != len(rank):
            message = 'Provided incorrect number of ranks. Should verify len(rank) == tl.ndim(tensor)+1, but len(rank) = {} while tl.ndim(tensor) + 1  = {}'.format(
                len(rank), n_dim + 1)
            raise(ValueError(message))

        # Boundary conditions: the first and last ranks must be 1
        if rank[0] != 1:
            message = 'Provided rank[0] == {} but boundaring conditions dictatate rank[0] == rank[-1] == 1: setting rank[0] to 1.'.format(rank[0])
            raise ValueError(message)
        if rank[-1] != 1:
            # Report the offending value, rank[-1] (rank[0] is known to be 1 here)
            message = 'Provided rank[-1] == {} but boundaring conditions dictatate rank[0] == rank[-1] == 1: setting rank[-1] to 1.'.format(rank[-1])
            raise ValueError(message)

    return list(rank)


class TTTensor(FactorizedTensor):
    """Tensor in Tensor-Train (TT) format, stored as a list of factors (tt-cores).

    The instance behaves like a sequence of its factors (indexing, iteration,
    ``len``) and exposes the validated ``shape`` and ``rank`` as tuples.
    """

    def __init__(self, factors, inplace=False):
        # NOTE(review): `inplace` is accepted but not used in this constructor
        super().__init__()

        # Validation raises an error if `factors` is not a valid TT tensor
        validated_shape, validated_rank = _validate_tt_tensor(factors)

        self.shape, self.rank = tuple(validated_shape), tuple(validated_rank)
        self.factors = factors

    def __getitem__(self, index):
        """Return the factor(s) stored at `index`."""
        return self.factors[index]

    def __setitem__(self, index, value):
        """Replace the factor(s) stored at `index`; no re-validation is done here."""
        self.factors[index] = value

    def __iter__(self):
        """Yield the factors one by one, going through ``self[index]``."""
        yield from (self[position] for position in range(len(self)))

    def __len__(self):
        """Number of factors."""
        return len(self.factors)

    def __repr__(self):
        template = 'factors list : rank-{} matrix-product-state tensor of shape {} '
        return template.format(self.rank, self.shape)

    def to_tensor(self):
        """Return the full tensor (see ``tt_to_tensor``)."""
        return tt_to_tensor(self)

    def to_unfolding(self, mode):
        """Return the unfolding of the full tensor along `mode` (see ``tt_to_unfolded``)."""
        return tt_to_unfolded(self, mode)

    def to_vec(self):
        """Return the vectorized full tensor (see ``tt_to_vec``)."""
        return tt_to_vec(self)

# Aliases kept under the former 'mps' (Matrix-Product-State) names.
# NOTE(review): DefineDeprecated presumably redirects to `use_instead` while
# signalling the deprecation -- confirm against .utils.DefineDeprecated
mps_to_tensor = DefineDeprecated(
    deprecated_name='mps_to_tensor',
    use_instead=tt_to_tensor,
)
mps_to_unfolded = DefineDeprecated(
    deprecated_name='mps_to_unfolded',
    use_instead=tt_to_unfolded,
)
mps_to_vec = DefineDeprecated(
    deprecated_name='mps_to_vec',
    use_instead=tt_to_vec,
)
_validate_mps_tensor = DefineDeprecated(
    deprecated_name='_validate_mps_tensor',
    use_instead=_validate_tt_tensor,
)
The diff you're trying to view is too large. Only the first 1000 changed files have been loaded.
Showing with 0 additions and 0 deletions (0 / 0 diffs computed)
swh spinner

Computing file changes ...

Software Heritage — Copyright (C) 2015–2025, The Software Heritage developers. License: GNU AGPLv3+.
The source code of Software Heritage itself is available on our development forge.
The source code files archived by Software Heritage are available under their own copyright and licenses.
Terms of use: Archive access, API— Contact— JavaScript license information— Web API

back to top