Revision ac62d292d7de9b49855b7ca53f2645a10c63246a authored by rballester on 16 September 2022, 15:17:15 UTC, committed by rballester on 16 September 2022, 15:17:15 UTC
1 parent 363d461
test_tools.py
import numpy as np
import tntorch as tn
import torch
torch.set_default_dtype(torch.float64)
def test_unfolding():
t = torch.rand((30, 10, 20, 10))
assert torch.allclose(tn.unfolding(t, 2, batch=False), t.permute(2, 0, 1, 3).reshape(20, -1))
assert torch.allclose(tn.unfolding(t, 2, batch=True), t.permute(0, 3, 1, 2).reshape(30, 10, -1))
def test_cat():
for i in range(100):
N = np.random.randint(1, 4)
shape1 = np.random.randint(1, 10, N)
mode = np.random.randint(N)
shape2 = shape1.copy()
shape2[mode] = np.random.randint(1, 10)
t1 = tn.rand(shape1, ranks_tt=2, ranks_tucker=2)
t2 = tn.rand(shape2, ranks_tt=2)
gt = np.concatenate([t1.numpy(), t2.numpy()], mode)
assert np.linalg.norm(gt - tn.cat([t1, t2], dim=mode).numpy()) <= 1e-7
def test_cumsum():
for i in range(100):
N = np.random.randint(1, 4)
howmany = 1
modes = np.random.choice(N, howmany, replace=False)
shape = np.random.randint(1, 10, N)
t = tn.rand(shape, ranks_tt=2, ranks_tucker=2)
assert np.linalg.norm(tn.cumsum(t, modes).numpy() - np.cumsum(t.numpy(), *modes)) <= 1e-7
Computing file changes ...