Revision 85fbc1e20d309b4d1d31353bc8ed0f74bfd37050 authored by Meraj Hashemizadeh on 30 August 2021, 20:41:50 UTC, committed by Meraj Hashemizadeh on 30 August 2021, 20:41:50 UTC
1 parent 5a6992a
Raw File
test_tt_matrix.py
import tensorly as tl
from tensorly import random
from ..tt_matrix import tt_matrix_to_matrix, tt_matrix_to_tensor, tt_matrix_to_vec

def test_tt_matrix_manipulation():
    """Test for tt_matrix manipulation.

    Builds a random TT-matrix in factorized form (``full=False``). It then
    checks the shapes produced when it is converted to a full tensor, a
    matrix and a vector.

    ``shape`` is the concatenation of the row dimensions and the column
    dimensions. The expected matrix and vector sizes are computed from those
    two halves rather than hard-coded, so changing the dimensions below keeps
    the assertions consistent.
    """
    from functools import reduce
    from operator import mul

    row_dims = (2, 2, 2)
    col_dims = (3, 3, 3)
    shape = row_dims + col_dims
    # Matrix form flattens the row dims into one axis and the column dims
    # into the other: (prod(row_dims), prod(col_dims)).
    n_rows = reduce(mul, row_dims, 1)
    n_cols = reduce(mul, col_dims, 1)

    tt_matrix = random.random_tt_matrix(shape, rank=2, full=False)

    rec = tt_matrix_to_tensor(tt_matrix)
    assert tl.shape(rec) == shape

    mat = tt_matrix_to_matrix(tt_matrix)
    assert tl.shape(mat) == (n_rows, n_cols)

    vec = tt_matrix_to_vec(tt_matrix)
    assert tl.shape(vec) == (n_rows * n_cols,)
back to top