Revision 241bf7ad2b806f6677a5e23534247f35f3a70f10 authored by rballester on 19 February 2023, 19:35:27 UTC, committed by rballester on 19 February 2023, 19:35:27 UTC
1 parent be80cb2
test_ops.py
``````
import numpy as np
import tntorch as tn
import torch
torch.set_default_dtype(torch.float64)
from util import random_format

def check(t1, t2):
    """Assert that arithmetic on `t1` and `t2` agrees with dense arithmetic.

    Both operands are converted with ``.torch()``; then ``t1 + t2``,
    ``t1 - t2``, ``t1 * t2`` and ``-t1 + t2`` are each compared, via
    ``tn.relative_error``, against the same expression evaluated on the
    converted tensors. Every comparison must stay within 1e-7.

    Args:
        t1: First operand; must expose ``.torch()`` and support ``+``, ``-``,
            ``*`` and unary ``-`` (the callers in this file pass tntorch
            tensors).
        t2: Second operand, same requirements as `t1`.

    Raises:
        AssertionError: If any of the four relative errors exceeds 1e-7.
    """
    tolerance = 1e-7
    dense1, dense2 = t1.torch(), t2.torch()

    # Each callable is applied once to the original pair and once to the
    # dense pair, so both sides of a comparison use the same expression.
    expressions = (
        lambda a, b: a + b,
        lambda a, b: a - b,
        lambda a, b: a * b,
        lambda a, b: -a + b,
    )
    for expression in expressions:
        error = tn.relative_error(expression(t1, t2), expression(dense1, dense2))
        assert error <= tolerance

def test_ops():
    """Run `check` on tensors of random shapes, mixed formats and broadcastable shapes."""

    # 1) Random shapes: 1 to 5 dimensions, each of size 1 to 7
    #    (np.random.randint excludes the upper bound).
    for _ in range(100):
        n_dims = np.random.randint(1, 6)
        sizes = np.random.randint(1, 8, n_dims)
        t1 = tn.rand(sizes, ranks_tt=3, ranks_tucker=2)
        t2 = tn.rand(t1.shape)
        check(t1, t2)

    shape = [8] * 4

    # 2) Fixed shape with per-dimension TT / CP rank specifications.
    t1 = tn.rand(
        shape,
        ranks_tt=[3, None, None],
        ranks_cp=[None, None, 2, 2],
        ranks_tucker=5,
    )
    t2 = tn.rand(
        shape,
        ranks_tt=[None, 2, None],
        ranks_cp=[4, None, None, 3],
    )
    check(t1, t2)

    # 3) A tensor against a scaled copy of itself.
    t2 = t1 * 2
    check(t1, t2)

    # 4) Pairs produced by the project's random_format helper.
    for _ in range(100):
        t1 = random_format(shape)
        t2 = random_format(shape)
        check(t1, t2)

    # 5) Broadcasting: a randomly chosen subset of the second shape's
    #    dimensions (possibly none, possibly all) is collapsed to size 1.
    for _ in range(10):
        shape1 = np.random.randint(1, 10, 4)
        shape2 = shape1.copy()
        n_collapsed = np.random.randint(0, len(shape1) + 1)
        collapsed = np.random.choice(len(shape1), n_collapsed)
        shape2[collapsed] = 1
        t1 = random_format(shape1)
        t2 = random_format(shape2)
        check(t1, t2)

def test_dot():
    """Compare ``tn.dot`` with ``torch.dot`` on the flattened dense tensors."""

    def assert_dot_matches(a, b):
        """Assert tn.dot(a, b) is within 1e-7 relative error of the dense dot product."""
        expected = torch.dot(a.torch().flatten(), b.torch().flatten())
        assert tn.relative_error(tn.dot(a, b), expected) <= 1e-7

    # Random shapes (1 to 5 dimensions, each of size 1 to 7) under every
    # combination of "Tucker rank given / not given" for the two operands.
    tucker_rank_pairs = ((None, None), (4, None), (None, 4), (3, 4))
    for tucker1, tucker2 in tucker_rank_pairs:
        n_dims = np.random.randint(1, 6)
        sizes = np.random.randint(1, 8, n_dims)
        t1 = tn.rand(sizes, ranks_tt=2, ranks_tucker=tucker1)
        t2 = tn.rand(t1.shape, ranks_tt=3, ranks_tucker=tucker2)
        assert_dot_matches(t1, t2)

    # Larger fixed shape with per-dimension TT / CP rank specifications.
    t1 = tn.rand(
        [32] * 4,
        ranks_tt=[3, None, None],
        ranks_cp=[None, None, 10, 10],
        ranks_tucker=5,
    )
    t2 = tn.rand(
        [32] * 4,
        ranks_tt=[None, 2, None],
        ranks_cp=[4, None, None, 5],
    )
    assert_dot_matches(t1, t2)

    # Pairs produced by the project's random_format helper.
    shape = [8] * 4
    for _ in range(100):
        t1 = random_format(shape)
        t2 = random_format(shape)
        assert_dot_matches(t1, t2)

def test_stats():
    """Compare ``tn.mean``, ``tn.var`` and ``tn.norm`` with their torch counterparts.

    For 100 tensors from ``random_format`` with shape ``[8, 8, 8, 8]``, each
    tntorch statistic must be within 1e-3 relative error of the torch
    function applied to the dense tensor.
    """
    # NOTE(review): the 1e-3 bound is far looser than the 1e-7 used by the
    # other tests in this file; the reason is not visible here (possibly a
    # different variance normalisation in tn.var vs. torch.var) - confirm.
    tolerance = 1e-3
    statistic_pairs = (
        (tn.mean, torch.mean),
        (tn.var, torch.var),
        (tn.norm, torch.norm),
    )

    shape = [8] * 4
    for _ in range(100):
        t = random_format(shape)
        x = t.torch()
        for tn_statistic, torch_statistic in statistic_pairs:
            error = tn.relative_error(tn_statistic(t), torch_statistic(x))
            assert error <= tolerance

``````

Computing file changes ...