Raw File
from pytest import raises
import numpy as np
import tntorch as tn
import torch
torch.set_default_dtype(torch.float64)
from util import random_format


def check(x, t, idx):

    xidx = x[idx]
    tidx = t[idx].numpy()
    assert np.array_equal(xidx.shape, tidx.shape)
    assert np.linalg.norm(xidx - tidx) / np.linalg.norm(xidx) <= 1e-7


def test_squeeze():

    for i in range(100):
        x = np.random.randint(1, 3, np.random.randint(2, 10))
        t = tn.Tensor(x)
        x = np.squeeze(x)
        t = tn.squeeze(t)
        assert np.array_equal(x.shape, t.shape)


def test_slicing():

    t = tn.rand([1, 3, 1, 2, 1], ranks_tt=3, ranks_tucker=2)
    x = t.numpy()
    idx = slice(None)
    check(x, t, idx)
    idx = (slice(None), slice(1, None))
    check(x, t, idx)
    idx = (slice(None), slice(0, 2, None), slice(0, 1))
    check(x, t, idx)


def test_mixed():

    def check_one_tensor(t):

        x = t.numpy()

        idxs = []
        idxs.append(([0, 0, 0], None, None, 3))
        idxs.append(([0, 0, 0, 0, 0], slice(None), None, 0))
        idxs.append((0, [0]))
        idxs.append(([0], [0]))
        idxs.append(([0], None, None, None, 0, 1))
        idxs.append((slice(None), [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]))
        idxs.append(([0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]))
        idxs.append((slice(None), slice(None), slice(None), 0))
        idxs.append((slice(None), slice(None), [0, 1], 0))
        idxs.append((0, np.array([0]), None, 0))
        idxs.append((slice(None), slice(None), slice(None), slice(None), None))
        idxs.append((None, slice(None), slice(None), slice(None), slice(None), None))
        idxs.append((None, slice(None), slice(None), slice(None), slice(None)))

        for idx in idxs:
            check(x, t, idx)

    check_one_tensor(tn.rand([6, 7, 8, 9], ranks_tt=3, ranks_tucker=2))
    check_one_tensor(tn.rand([6, 7, 8, 9], ranks_tt=None, ranks_tucker=2, ranks_cp=3))
    check_one_tensor(tn.rand([6, 7, 8, 9], ranks_tt=[4, None, None], ranks_tucker=2, ranks_cp=[None, None, 3, 3]))
    check_one_tensor(tn.rand([6, 7, 8, 9], ranks_tt=[4, None, None], ranks_tucker=[2, None, 2, None], ranks_cp=[None, None, 3, 3]))
    check_one_tensor(tn.rand([6, 7, 8, 9], ranks_tt=[None, 4, 4], ranks_tucker=2, ranks_cp=[3, None, None, None]))

    for i in range(100):
        check_one_tensor(random_format([6, 7, 8, 9]))

    t = tn.rand([6, 7, 8, 9], ranks_cp=[3, 3, 3, 3])
    t.cores[-1] = t.cores[-1].permute(1, 0)[:, :, None]
    check_one_tensor(t)

    t = tn.rand([6, 7, 8, 9], ranks_tt=3, batch=True)
    check(t.numpy(), t, 0)
    check(t.numpy(), t, [0, 1])


def test_batch():
    def check_one_tensor(t):
        x = t.numpy()

        idxs = []
        idxs.append(([0, 0, 0], None, None, 3))
        idxs.append(([0, 0, 0, 0, 0], slice(None), None, 0))
        idxs.append((0, [0]))
        idxs.append(([0], None, None, None, 0, 1))
        idxs.append((slice(None), [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]))
        idxs.append((slice(None), slice(None), slice(None), 0))
        idxs.append((slice(None), slice(None), [0, 1], 0))
        idxs.append((0, np.array([0]), None, 0))
        idxs.append((slice(None), slice(None), slice(None), slice(None), None))

        for idx in idxs:
            check(x, t, idx)

    check_one_tensor(tn.rand([6, 7, 8, 9], ranks_tt=3, batch=True))
    check_one_tensor(tn.rand([6, 7, 8, 9], ranks_tucker=3, batch=True))
    check_one_tensor(tn.rand([6, 7, 8, 9], ranks_cp=3, batch=True))

    with raises(ValueError) as exc_info:
        tn.rand([6, 7, 8, 9], ranks_tt=3, batch=True)[None, ...]

    assert exc_info.type is ValueError

    with raises(ValueError) as exc_info2:
         tn.rand([6, 7, 8, 9], ranks_tt=3, batch=True)[[0], [0]]

    assert exc_info2.type is ValueError

back to top