swh:1:snp:39d5aa88ec19187d5f6c7d91d5a7e02d6ac7c2c2
Raw File
Tip revision: 241bf7ad2b806f6677a5e23534247f35f3a70f10 authored by rballester on 19 February 2023, 19:35:27 UTC
Exact method for moments
Tip revision: 241bf7a
test_automata.py
import tntorch as tn
import numpy as np
import torch
# Make float64 the default floating-point dtype. This is a process-wide setting,
# so it also applies to every other test run in the same interpreter. Tensors
# created without an explicit dtype (e.g. via the legacy `torch.Tensor(...)`
# constructor) are therefore double precision, and the `<= 1e-7` tolerance
# checks in these tests are evaluated in float64.
torch.set_default_dtype(torch.float64)


def test_weight_mask():
    """Compare ``tn.automata.weight_mask`` with a brute-force enumeration.

    For each length ``N`` in 1..4 and each weight ``k`` in 1..N-1, every entry
    of the mask is visited by its multi-index. After rounding, the mask must be
    1 exactly where the components of that multi-index sum to ``k``, and 0
    everywhere else.
    """
    cases = [(N, k) for N in range(1, 5) for k in range(1, N)]
    for N, k in cases:
        mask = tn.automata.weight_mask(N, k)

        # Decode every flat position (row-major order) into its multi-index;
        # `points` has one row per entry, built with the default float dtype.
        flat_positions = np.arange(mask.numel(), dtype=int)
        coords = np.unravel_index(flat_positions, list(mask.shape))
        points = torch.Tensor(np.stack(coords, axis=1))

        expected = (torch.sum(points, dim=1).round() == k).float()
        observed = mask[points].torch().round().float()
        assert torch.norm(expected - observed) <= 1e-7

def test_accepted_inputs():
    """Check ``tn.automata.accepted_inputs`` on random binary tensors.

    Each of 10 trials builds a random 0/1 tensor of shape (1, 2, 3, 4).
    Two conditions must hold:

    * the number of indices returned by ``accepted_inputs`` equals the number
      of ones in the tensor (its rounded sum);
    * the tensor evaluates to 1 at every returned index.
    """
    # Use a local, seeded generator so any failing draw is reproducible. It
    # does not consume or reseed the global torch RNG that other tests share.
    generator = torch.Generator().manual_seed(0)
    for _ in range(10):
        gt = tn.Tensor(torch.randint(0, 2, (1, 2, 3, 4), generator=generator))
        idx = tn.automata.accepted_inputs(gt)
        assert len(idx) == round(tn.sum(gt).item())
        assert torch.norm(gt[idx].torch().double() - 1).item() <= 1e-7
back to top