# test_round.py -- from https://github.com/rballester/tntorch
# Tip revision: 3af563a42794ba169e7902198d1edd919617a958, authored by Rafael
# Ballester on 16 March 2023, 15:48:54 UTC
# ("Updated doc (ranks_cp actually must be an integer, not a list)").
import numpy as np
import tntorch as tn
import torch
# Run every test in double precision: several checks below require relative
# errors of 1e-7 or less, which is about float32 machine epsilon (~1.2e-7).
torch.set_default_dtype(torch.float64)


def test_orthogonalization():
    """Orthogonalization calls must not change the tensor a TT represents.

    For 100 random TT tensors (2 to 5 modes, each of size 1 to 7), call
    ``left_orthogonalize(0)``, ``right_orthogonalize`` on the last core and
    ``orthogonalize`` around a random core. After the clone and after every
    call, check that the result still matches the original within 1e-7.
    """
    # Seed NumPy (shapes, core index) and torch (presumably used by tn.rand for
    # the core values) so that a failing case can be reproduced exactly.
    np.random.seed(0)
    torch.manual_seed(0)

    for _ in range(100):
        shape = np.random.randint(1, 8, np.random.randint(2, 6))
        gt = tn.rand(shape)
        t = gt.clone()
        assert tn.relative_error(gt, t) <= 1e-7
        t.left_orthogonalize(0)
        assert tn.relative_error(gt, t) <= 1e-7
        t.right_orthogonalize(t.dim() - 1)
        assert tn.relative_error(gt, t) <= 1e-7
        t.orthogonalize(np.random.randint(t.dim()))
        assert tn.relative_error(gt, t) <= 1e-7


def test_truncated_svd():
    """Batched ``truncated_svd`` must agree with per-matrix (unbatched) calls.

    Factorizes a stack of two random 32x32 matrices in one batched call and
    compares each slice of the resulting factors against an independent,
    unbatched call on the corresponding matrix.
    """
    torch.manual_seed(0)  # reproducible input matrices
    gt = torch.rand((2, 32, 32))
    u, v = tn.truncated_svd(gt, batch=True)

    # Guard against zip() silently stopping early if a factor is short.
    assert len(u) == len(v) == len(gt)

    # Iterating a tensor walks its leading (batch) dimension, like gt[i].
    for matrix, u_batched, v_batched in zip(gt, u, v):
        u_single, v_single = tn.truncated_svd(matrix, batch=False)
        assert torch.allclose(u_single, u_batched)
        assert torch.allclose(v_single, v_batched)


def test_truncated_svd_eig():
    """Batched ``truncated_svd(algorithm='eig')`` must match unbatched calls.

    Same check as the default-algorithm test, but with ``algorithm='eig'``
    passed to both the batched call and each per-matrix call.
    """
    torch.manual_seed(0)  # reproducible input matrices
    gt = torch.rand((2, 32, 32))
    u, v = tn.truncated_svd(gt, batch=True, algorithm='eig')

    # Guard against zip() silently stopping early if a factor is short.
    assert len(u) == len(v) == len(gt)

    # Iterating a tensor walks its leading (batch) dimension, like gt[i].
    for matrix, u_batched, v_batched in zip(gt, u, v):
        u_single, v_single = tn.truncated_svd(matrix, batch=False, algorithm='eig')
        assert torch.allclose(u_single, u_batched)
        assert torch.allclose(v_single, v_batched)


def test_round_tt_svd():
    """SVD-based ``round_tt`` of ``gt + gt`` must recover ``2 * gt``.

    Each random TT (8 or 9 modes of size 1 to 7, TT rank 1 to 9) is first
    rounded at eps=1e-8. Rounding the sum ``gt + gt`` at the same eps must
    reproduce ``2 * gt`` within 1e-4 relative error, and must end up with the
    same maximum TT rank as ``gt``.
    """
    # Seed NumPy (shapes, ranks) and torch (presumably used by tn.rand for the
    # core values) so that a failing case can be reproduced exactly.
    np.random.seed(0)
    torch.manual_seed(0)

    for _ in range(100):
        shape = np.random.randint(1, 8, np.random.randint(8, 10))
        gt = tn.rand(shape, ranks_tt=np.random.randint(1, 10))
        gt.round_tt(1e-8, algorithm='svd')
        t = gt + gt
        t.round_tt(1e-8, algorithm='svd')
        assert tn.relative_error(gt, t / 2) <= 1e-4
        assert max(gt.ranks_tt) == max(t.ranks_tt)


def test_round_tt_eig():
    """Eigendecomposition-based ``round_tt`` of ``gt + gt`` must recover ``2 * gt``.

    Each random TT (8 or 9 modes of size 1 to 7, TT rank 1 to 9) is first
    rounded at eps=1e-8 with ``algorithm='eig'``. Rounding the sum ``gt + gt``
    the same way must reproduce ``2 * gt`` within 1e-7 relative error.
    Unlike the SVD variant, ranks are not compared here.
    """
    # Seed NumPy (shapes, ranks) and torch (presumably used by tn.rand for the
    # core values) so that a failing case can be reproduced exactly.
    np.random.seed(0)
    torch.manual_seed(0)

    for _ in range(100):
        shape = np.random.randint(1, 8, np.random.randint(8, 10))
        gt = tn.rand(shape, ranks_tt=np.random.randint(1, 10))
        gt.round_tt(1e-8, algorithm='eig')
        t = gt + gt
        t.round_tt(1e-8, algorithm='eig')
        assert tn.relative_error(gt, t / 2) <= 1e-7


def test_round_tucker():
    """``round_tucker(eps=eps)`` must keep the relative error within ``eps``.

    For 100 random 32x32x32x32 tensors with TT rank 8 and Tucker rank 8,
    round a clone at a random tolerance and check that it stays within that
    tolerance of the original.
    """
    # Seed NumPy (eps) and torch (presumably used by tn.rand for the core
    # values) so that a failing case can be reproduced exactly.
    np.random.seed(0)
    torch.manual_seed(0)

    for _ in range(100):
        # Squaring a uniform [0, 1) draw skews eps toward small values, so
        # tight tolerances are exercised more often than loose ones.
        eps = np.random.rand() ** 2
        gt = tn.rand([32] * 4, ranks_tt=8, ranks_tucker=8)
        t = gt.clone()
        t.round_tucker(eps=eps)
        assert tn.relative_error(gt, t) <= eps