swh:1:snp:39d5aa88ec19187d5f6c7d91d5a7e02d6ac7c2c2
Raw File
Tip revision: 241bf7ad2b806f6677a5e23534247f35f3a70f10 authored by rballester on 19 February 2023, 19:35:27 UTC
Exact method for moments
Tip revision: 241bf7a
test_cross.py
import tntorch as tn
import torch
torch.set_default_dtype(torch.float64)
from util import random_format


def test_domain():

    def function(Xs):
        return 1. / torch.sum(Xs, dim=1)

    domain = [torch.linspace(1, 10, 10) for n in range(3)]
    t = tn.cross(function=function, domain=domain, ranks_tt=3, function_arg='matrix')
    gt = torch.meshgrid(domain, indexing='ij')
    gt = 1. / sum(gt)

    assert tn.relative_error(gt, t) < 5e-2


def test_tensors():

    for i in range(100):
        t = random_format([10] * 6)
        t2 = tn.cross(function=lambda x: x, tensors=t, ranks_tt=15, verbose=False)
        assert tn.relative_error(t, t2) < 1e-6

    t = tn.rand([10] * 6, ranks_tt=10)
    _, info = tn.cross(function=lambda x: x, tensors=[t], ranks_tt=15, verbose=False, return_info=True)
    t2 = tn.cross_forward(info, function=lambda x: x, tensors=t)
    assert tn.relative_error(t, t2) < 1e-6


def test_ops():

    x, y, z, w = tn.meshgrid([32]*4)
    t = x + y + z + w + 1
    assert tn.relative_error(1/t.torch(), 1/t) < 1e-4
    assert tn.relative_error(torch.cos(t.torch()), tn.cos(t)) < 1e-4
    assert tn.relative_error(torch.exp(t.torch()), tn.exp(t)) < 1e-4
back to top