swh:1:snp:39d5aa88ec19187d5f6c7d91d5a7e02d6ac7c2c2
Tip revision: 241bf7ad2b806f6677a5e23534247f35f3a70f10 authored by rballester on 19 February 2023, 19:35:27 UTC
Exact method for moments
Exact method for moments
Tip revision: 241bf7a
test_tools.py
import numpy as np
import tntorch as tn
import torch
# Run every test in this module in double precision. The comparisons below use
# tight absolute tolerances (1e-7), which presumably rely on float64 accuracy.
torch.set_default_dtype(torch.float64)
def test_unfolding():
    """Compare tn.unfolding with an explicit permute + reshape.

    The test expects mode 2 to be moved to the front and the remaining modes
    flattened in their original order. With ``batch=True`` the leading axis is
    kept as a batch axis, so mode 2 is expected to address the last axis
    (size 10) of the 4-D tensor.
    """
    x = torch.rand((30, 10, 20, 10))

    unfolded = tn.unfolding(x, 2, batch=False)
    expected = x.permute(2, 0, 1, 3).reshape(20, -1)
    assert torch.allclose(unfolded, expected)

    unfolded_batched = tn.unfolding(x, 2, batch=True)
    expected_batched = x.permute(0, 3, 1, 2).reshape(30, 10, -1)
    assert torch.allclose(unfolded_batched, expected_batched)
def test_cat():
    """Check tn.cat against np.concatenate on random tensors.

    Each of 100 trials builds two random tensors whose shapes agree except
    along one randomly chosen mode. The first is created with
    ``ranks_tucker=2`` and the second without. The test concatenates them with
    tn.cat along that mode and compares the result with np.concatenate applied
    to their dense NumPy arrays.
    """
    # Local seeded generator: failures are reproducible and the global NumPy
    # RNG state is left untouched for other tests.
    rng = np.random.default_rng(0)
    # tn.rand presumably draws from torch's global RNG; seed it inside a forked
    # RNG scope so the seed does not leak out of this test.
    with torch.random.fork_rng():
        torch.manual_seed(0)
        for _ in range(100):
            N = int(rng.integers(1, 4))
            shape1 = rng.integers(1, 10, N)
            mode = int(rng.integers(N))
            shape2 = shape1.copy()
            shape2[mode] = rng.integers(1, 10)
            t1 = tn.rand(shape1, ranks_tt=2, ranks_tucker=2)
            t2 = tn.rand(shape2, ranks_tt=2)
            gt = np.concatenate([t1.numpy(), t2.numpy()], mode)
            err = np.linalg.norm(gt - tn.cat([t1, t2], dim=mode).numpy())
            assert err <= 1e-7, (
                f"shape1={shape1}, shape2={shape2}, mode={mode}, err={err}"
            )
def test_cumsum():
    """Check tn.cumsum against np.cumsum on random low-rank tensors.

    Each of 100 trials picks ``howmany`` distinct modes of a random tensor
    (created with ``ranks_tt=2, ranks_tucker=2``). It compares tn.cumsum over
    those modes with a dense NumPy reference that applies np.cumsum along each
    selected mode in turn.
    """
    # Local seeded generator: failures are reproducible and the global NumPy
    # RNG state is left untouched for other tests.
    rng = np.random.default_rng(0)
    # tn.rand presumably draws from torch's global RNG; seed it inside a forked
    # RNG scope so the seed does not leak out of this test.
    with torch.random.fork_rng():
        torch.manual_seed(0)
        for _ in range(100):
            N = int(rng.integers(1, 4))
            howmany = 1
            modes = rng.choice(N, howmany, replace=False)
            shape = rng.integers(1, 10, N)
            t = tn.rand(shape, ranks_tt=2, ranks_tucker=2)
            # np.cumsum accepts a single axis; unpacking ``*modes`` into it
            # would pass any second mode as its ``dtype`` argument. Apply it
            # per mode instead, so the reference stays valid if ``howmany``
            # is raised.
            gt = t.numpy()
            for mode in modes:
                gt = np.cumsum(gt, axis=mode)
            err = np.linalg.norm(tn.cumsum(t, modes).numpy() - gt)
            assert err <= 1e-7, f"shape={shape}, modes={modes}, err={err}"