Revision 241bf7ad2b806f6677a5e23534247f35f3a70f10 authored by rballester on 19 February 2023, 19:35:27 UTC, committed by rballester on 19 February 2023, 19:35:27 UTC
1 parent be80cb2
test_round.py
``````import numpy as np
import tntorch as tn
import torch
# Make float64 the default torch dtype for every tensor created in this module.
# The tests below assert relative errors as small as 1e-7; presumably those
# tolerances were chosen with double precision in mind.
torch.set_default_dtype(torch.float64)

def test_orthogonalization():
    """Orthogonalizing a TT tensor in place must not change the tensor it represents.

    Runs 100 trials on random TT tensors with 2-5 dimensions of size 1-7 each.
    After cloning and after each of ``left_orthogonalize``,
    ``right_orthogonalize`` and ``orthogonalize`` (at a random core), the
    relative error against the untouched original must stay <= 1e-7.
    """
    for _ in range(100):
        # Keep the random shape in a name so a failing trial can be reported.
        shape = np.random.randint(1, 8, np.random.randint(2, 6))
        gt = tn.rand(shape)
        t = gt.clone()

        def check(step):
            err = tn.relative_error(gt, t)
            assert err <= 1e-7, (
                f"after {step}: relative error {err} for shape {shape.tolist()}"
            )

        check('clone')
        t.left_orthogonalize(0)
        check('left_orthogonalize(0)')
        t.right_orthogonalize(t.dim() - 1)
        check(f'right_orthogonalize({t.dim() - 1})')
        mu = np.random.randint(t.dim())
        t.orthogonalize(mu)
        check(f'orthogonalize({mu})')

def test_truncated_svd():
    """Batched ``tn.truncated_svd`` must agree with per-matrix (unbatched) calls.

    Decomposes a batch of two random 32x32 matrices in one call. Each returned
    factor slice is then compared with the factors from decomposing that single
    matrix with ``batch=False``.
    """
    gt = torch.rand((2, 32, 32))
    u, v = tn.truncated_svd(gt, batch=True)
    # zip() below stops at the shortest input, so check batch sizes explicitly.
    assert len(u) == len(gt), f"expected {len(gt)} U factors, got {len(u)}"
    assert len(v) == len(gt), f"expected {len(gt)} V factors, got {len(v)}"

    for k, (matrix, u_batched, v_batched) in enumerate(zip(gt, u, v)):
        u1, v1 = tn.truncated_svd(matrix, batch=False)
        assert torch.allclose(u1, u_batched), f"U factor differs for batch item {k}"
        assert torch.allclose(v1, v_batched), f"V factor differs for batch item {k}"

def test_truncated_svd_eig():
    """Batched ``tn.truncated_svd(..., algorithm='eig')`` must agree with per-matrix calls.

    Same check as the default-algorithm variant: decompose a batch of two
    random 32x32 matrices at once, then compare each factor slice with the
    factors from decomposing that matrix alone with ``batch=False``.
    """
    gt = torch.rand((2, 32, 32))
    u, v = tn.truncated_svd(gt, batch=True, algorithm='eig')
    # zip() below stops at the shortest input, so check batch sizes explicitly.
    assert len(u) == len(gt), f"expected {len(gt)} U factors, got {len(u)}"
    assert len(v) == len(gt), f"expected {len(gt)} V factors, got {len(v)}"

    for k, (matrix, u_batched, v_batched) in enumerate(zip(gt, u, v)):
        u1, v1 = tn.truncated_svd(matrix, batch=False, algorithm='eig')
        assert torch.allclose(u1, u_batched), f"U factor differs for batch item {k}"
        assert torch.allclose(v1, v_batched), f"V factor differs for batch item {k}"

def test_round_tt_svd():
    """SVD-based TT rounding of ``gt + gt`` must reproduce ``2 * gt``.

    Runs 100 trials on random TT tensors with 8-9 dimensions of size 1-7 and a
    random TT rank of 1-9. Each trial checks two things:

    * the relative error of ``(rounded sum) / 2`` against ``gt`` is <= 1e-4;
    * the largest TT rank of the rounded sum equals that of the rounded ``gt``.
    """
    for _ in range(100):
        # Keep the random shape in a name so a failing trial can be reported.
        shape = np.random.randint(1, 8, np.random.randint(8, 10))
        gt = tn.rand(shape, ranks_tt=np.random.randint(1, 10))
        # Round gt itself first so that the rank comparison below is between
        # two tensors rounded with the same tolerance and algorithm.
        gt.round_tt(1e-8, algorithm='svd')
        t = gt + gt
        t.round_tt(1e-8, algorithm='svd')

        err = tn.relative_error(gt, t / 2)
        assert err <= 1e-4, (
            f"relative error {err} for shape {shape.tolist()}, ranks {gt.ranks_tt}"
        )
        assert max(gt.ranks_tt) == max(t.ranks_tt), (
            f"rank mismatch for shape {shape.tolist()}: "
            f"gt {gt.ranks_tt} vs rounded sum {t.ranks_tt}"
        )

def test_round_tt_eig():
    """TT rounding with ``algorithm='eig'`` of ``gt + gt`` must reproduce ``2 * gt``.

    Runs 100 trials on random TT tensors with 8-9 dimensions of size 1-7 and a
    random TT rank of 1-9. Each trial requires the relative error of
    ``(rounded sum) / 2`` against ``gt`` to be <= 1e-7.
    """
    for _ in range(100):
        # Keep the random shape in a name so a failing trial can be reported.
        shape = np.random.randint(1, 8, np.random.randint(8, 10))
        gt = tn.rand(shape, ranks_tt=np.random.randint(1, 10))
        # Round gt itself first so both operands went through the same rounding.
        gt.round_tt(1e-8, algorithm='eig')
        t = gt + gt
        t.round_tt(1e-8, algorithm='eig')

        err = tn.relative_error(gt, t / 2)
        assert err <= 1e-7, (
            f"relative error {err} for shape {shape.tolist()}, ranks {gt.ranks_tt}"
        )

def test_round_tucker():
    """Tucker rounding with tolerance ``eps`` must keep the relative error <= ``eps``.

    Runs 100 trials on random 32x32x32x32 tensors built with ``ranks_tt=8`` and
    ``ranks_tucker=8``. Each trial uses a random tolerance ``eps`` in
    [1e-12, 1).
    """
    for _ in range(100):
        # Squaring a uniform draw biases eps toward small tolerances. Floor it
        # so a near-zero draw cannot demand accuracy below float64 round-off;
        # that would make the test fail at random.
        eps = max(np.random.rand() ** 2, 1e-12)
        gt = tn.rand([32] * 4, ranks_tt=8, ranks_tucker=8)
        t = gt.clone()
        t.round_tucker(eps=eps)

        err = tn.relative_error(gt, t)
        assert err <= eps, f"relative error {err} exceeds eps={eps}"
``````

Computing file changes ...