swh:1:snp:4e3e7077647a709f15b8c1b32ce7100175d0580b
Tip revision: fa272515cfb398d4b1ca644ba50f4f82dff63d20 authored by Jean Kossaifi on 03 September 2021, 17:55:21 UTC
Merge pull request #293 from IsabellLehmann/cmtf_als
Tip revision: fa27251
test_tt_matrix.py
import tensorly as tl
from tensorly import random
from ..tt_matrix import tt_matrix_to_matrix, tt_matrix_to_tensor, tt_matrix_to_vec
def test_tt_matrix_manipulation():
    """Check the shapes produced by the TT-matrix conversion helpers.

    A random TT-matrix of the given ``shape`` is requested with
    ``full=False`` (presumably its factorized form rather than a dense
    tensor). It is then converted to a full tensor, a matrix and a vector,
    and each result is checked for the expected shape.
    """
    shape = (2, 2, 2, 3, 3, 3)
    # The expected matrix shape pairs the first half of ``shape``
    # (2 * 2 * 2 = 8 rows) with the second half (3 * 3 * 3 = 27 columns).
    # Kept as literals so the test does not recompute its own expectations.
    n_rows, n_cols = 8, 27
    tt_matrix = random.random_tt_matrix(shape, rank=2, full=False)

    rec = tt_matrix_to_tensor(tt_matrix)
    assert tl.shape(rec) == shape

    mat = tt_matrix_to_matrix(tt_matrix)
    assert tl.shape(mat) == (n_rows, n_cols)

    # The vectorized form holds every entry of the matrix in a single axis.
    vec = tt_matrix_to_vec(tt_matrix)
    assert tl.shape(vec) == (n_rows * n_cols,)