# https://github.com/rballester/tntorch
# Raw File
# Tip revision: 3af563a42794ba169e7902198d1edd919617a958 authored by Rafael Ballester on 16 March 2023, 15:48:54 UTC
# Updated doc (ranks_cp actually must be an integer, not a list)
# Tip revision: 3af563a
# test_tools.py
import numpy as np
import tntorch as tn
import torch
# Switch torch's default floating-point dtype to float64 for everything that
# runs after this import-time statement.
# NOTE(review): presumably needed so the 1e-7 tolerances used by the tests in
# this file are achievable — confirm.
torch.set_default_dtype(torch.float64)


def test_unfolding():
    """Compare ``tn.unfolding`` along mode 2 with permute/reshape references.

    A random dense tensor of shape (30, 10, 20, 10) is unfolded once with
    ``batch=False`` and once with ``batch=True``; each result must match a
    reference built directly with ``permute`` and ``reshape``.
    """
    dense = torch.rand((30, 10, 20, 10))

    # Non-batched reference: axis 2 (size 20) is moved to the front and the
    # remaining axes are flattened into the second dimension.
    plain = tn.unfolding(dense, 2, batch=False)
    plain_ref = dense.permute(2, 0, 1, 3).reshape(20, -1)
    assert torch.allclose(plain, plain_ref)

    # Batched reference: axis 0 (size 30) stays in front, axis 3 follows it,
    # and axes 1 and 2 are flattened into the last dimension.
    batched = tn.unfolding(dense, 2, batch=True)
    batched_ref = dense.permute(0, 3, 1, 2).reshape(30, 10, -1)
    assert torch.allclose(batched, batched_ref)

def test_cat():
    """Check ``tn.cat`` against ``np.concatenate`` on 100 random tensor pairs.

    Each trial draws 1 to 3 dimensions with sizes in 1..9, picks one mode to
    concatenate along, and gives the second tensor a freshly drawn size on
    that mode only. The norm of the difference between the NumPy result and
    the decompressed ``tn.cat`` result must not exceed 1e-7.
    """
    for _ in range(100):
        ndim = np.random.randint(1, 4)
        first_shape = np.random.randint(1, 10, ndim)
        axis = np.random.randint(ndim)

        # The second shape may differ from the first only along `axis`.
        second_shape = first_shape.copy()
        second_shape[axis] = np.random.randint(1, 10)

        # The first operand is created with both TT and Tucker ranks, the
        # second one with TT ranks only.
        first = tn.rand(first_shape, ranks_tt=2, ranks_tucker=2)
        second = tn.rand(second_shape, ranks_tt=2)

        expected = np.concatenate([first.numpy(), second.numpy()], axis)
        joined = tn.cat([first, second], dim=axis).numpy()
        error = np.linalg.norm(expected - joined)
        assert error <= 1e-7


def test_cumsum():
    """Check ``tn.cumsum`` against ``np.cumsum`` on 100 random tensors.

    Each trial draws 1 to 3 dimensions with sizes in 1..9 and a single mode
    to accumulate along. The norm of the difference between the decompressed
    ``tn.cumsum`` result and the NumPy reference must not exceed 1e-7.
    """
    for _ in range(100):
        ndim = np.random.randint(1, 4)
        n_modes = 1
        axes = np.random.choice(ndim, n_modes, replace=False)
        dims = np.random.randint(1, 10, ndim)
        tensor = tn.rand(dims, ranks_tt=2, ranks_tucker=2)

        actual = tn.cumsum(tensor, axes).numpy()
        # `axes` holds exactly one entry, so unpacking it supplies the axis
        # argument of np.cumsum.
        reference = np.cumsum(tensor.numpy(), *axes)
        assert np.linalg.norm(actual - reference) <= 1e-7
# back to top