swh:1:snp:39d5aa88ec19187d5f6c7d91d5a7e02d6ac7c2c2
Raw File
Tip revision: 241bf7ad2b806f6677a5e23534247f35f3a70f10 authored by rballester on 19 February 2023, 19:35:27 UTC
Exact method for moments
Tip revision: 241bf7a
test_tools.py
import numpy as np
import tntorch as tn
import torch
torch.set_default_dtype(torch.float64)


def test_unfolding():
    """Compare ``tn.unfolding`` against explicit permute-and-reshape references.

    Unbatched case: mode 2 of a (30, 10, 20, 10) tensor is moved to the
    front and the remaining axes are flattened, giving a 20 x 3000 matrix.

    Batched case: axis 0 is kept as the batch axis and the mode index is
    counted after it, so mode 2 refers to axis 3 (size 10). Each batch
    slice becomes a 10 x 200 matrix, giving shape (30, 10, 200).
    """
    x = torch.rand((30, 10, 20, 10))

    plain_ref = x.permute(2, 0, 1, 3).reshape(20, -1)
    plain_out = tn.unfolding(x, 2, batch=False)
    assert torch.allclose(plain_out, plain_ref)

    batched_ref = x.permute(0, 3, 1, 2).reshape(30, 10, -1)
    batched_out = tn.unfolding(x, 2, batch=True)
    assert torch.allclose(batched_out, batched_ref)

def test_cat():
    """Check ``tn.cat`` against ``np.concatenate`` on random tensor pairs.

    Each of 100 trials does the following:
    - draws an order-N shape (N in 1..3);
    - picks one mode;
    - derives a second shape that differs from the first only along that mode.

    The first tensor is built with both ``ranks_tt`` and ``ranks_tucker``.
    The second is built with ``ranks_tt`` only, so the two operands are
    constructed differently. Their dense concatenation along the chosen mode
    must match ``tn.cat`` to within 1e-7 in Frobenius norm.
    """
    for _ in range(100):
        ndim = np.random.randint(1, 4)
        first_shape = np.random.randint(1, 10, ndim)
        axis = np.random.randint(ndim)
        second_shape = first_shape.copy()
        second_shape[axis] = np.random.randint(1, 10)

        left = tn.rand(first_shape, ranks_tt=2, ranks_tucker=2)
        right = tn.rand(second_shape, ranks_tt=2)

        expected = np.concatenate([left.numpy(), right.numpy()], axis)
        result = tn.cat([left, right], dim=axis).numpy()
        assert np.linalg.norm(expected - result) <= 1e-7


def test_cumsum():
    """Check ``tn.cumsum`` against NumPy on random tensors.

    Each of 100 trials draws an order-N tensor (N in 1..3) with a random
    shape and sums cumulatively along ``howmany`` distinct modes. The result
    must match the dense reference to within 1e-7 in Frobenius norm.

    The reference applies ``np.cumsum`` once per mode. ``np.cumsum`` takes a
    single ``axis``, and its next positional parameter is ``dtype``. The
    previous form ``np.cumsum(x, *modes)`` would therefore have silently
    passed a second mode as the dtype whenever ``howmany > 1``.
    """
    for _ in range(100):
        N = np.random.randint(1, 4)
        howmany = 1
        modes = np.random.choice(N, howmany, replace=False)
        shape = np.random.randint(1, 10, N)
        t = tn.rand(shape, ranks_tt=2, ranks_tucker=2)

        # Dense ground truth: one cumulative sum per requested mode.
        gt = t.numpy()
        for mode in modes:
            gt = np.cumsum(gt, axis=mode)

        assert np.linalg.norm(tn.cumsum(t, modes).numpy() - gt) <= 1e-7
back to top