Revision 241bf7ad2b806f6677a5e23534247f35f3a70f10 authored by rballester on 19 February 2023, 19:35:27 UTC, committed by rballester on 19 February 2023, 19:35:27 UTC
1 parent be80cb2
test_indexing.py
``````from pytest import raises
import numpy as np
import tntorch as tn
import torch
# Make float64 the default floating-point dtype for torch tensors created below.
# NOTE(review): presumably needed so tntorch results match NumPy doubles within
# the 1e-7 tolerance used by check() -- confirm.
torch.set_default_dtype(torch.float64)
from util import random_format

def check(x, t, idx):
    """Assert that indexing ``t`` gives the same result as indexing ``x``.

    Args:
        x: Dense NumPy reference array.
        t: Object supporting ``t[idx]`` whose result exposes ``.numpy()``.
        idx: Index expression applied to both ``x`` and ``t``.

    Raises:
        AssertionError: If the two results differ in shape, or if their
            relative error (Frobenius norm) exceeds 1e-7.
    """
    expected = x[idx]
    actual = t[idx].numpy()
    assert np.array_equal(expected.shape, actual.shape)

    error = np.linalg.norm(expected - actual)
    reference = np.linalg.norm(expected)
    # Scale the tolerance by the reference norm instead of dividing by it:
    # an all-zero reference would otherwise give 0/0 = NaN, and NaN <= 1e-7
    # is False, so even an exact match would be reported as a failure.
    assert error <= 1e-7 * reference

def test_squeeze():
    """Compare the shape produced by ``tn.squeeze`` with ``np.squeeze``.

    Runs 100 random trials; each builds an integer array with values in
    {1, 2}, wraps it in a ``tn.Tensor`` and squeezes both representations.
    """
    for _ in range(100):
        length = np.random.randint(2, 10)
        dense = np.random.randint(1, 3, length)
        compressed = tn.Tensor(dense)

        # NOTE(review): ``dense`` is 1-D with 2 to 9 entries, so np.squeeze has
        # no length-1 axis to remove here -- confirm whether the random
        # integers were meant to be used as the array's *shape* instead.
        dense_squeezed = np.squeeze(dense)
        compressed_squeezed = tn.squeeze(compressed)

        assert np.array_equal(dense_squeezed.shape, compressed_squeezed.shape)

def test_slicing():
    """Check pure-slice indexing on a tensor that has several size-1 modes."""
    t = tn.rand([1, 3, 1, 2, 1], ranks_tt=3, ranks_tucker=2)
    dense = t.numpy()

    everything = slice(None)
    slice_cases = (
        everything,
        (everything, slice(1, None)),
        (everything, slice(0, 2, None), slice(0, 1)),
    )
    for idx in slice_cases:
        check(dense, t, idx)

def test_mixed():
    """Check indexing that mixes integers, lists, arrays, slices and ``None``.

    The same set of index expressions is applied to tensors built with
    several rank configurations, to 100 tensors from ``random_format``, to a
    tensor whose last core is reshaped by hand, and finally two simple
    indices are applied to a tensor created with ``batch=True``.
    """

    def check_one_tensor(t):
        """Run every mixed index expression against ``t`` and its dense copy."""
        dense = t.numpy()
        full = slice(None)
        six = [0, 1, 2, 3, 4, 5]

        idxs = [
            ([0, 0, 0], None, None, 3),
            ([0, 0, 0, 0, 0], full, None, 0),
            (0, [0]),
            ([0], [0]),
            ([0], None, None, None, 0, 1),
            (full, list(six), list(six)),
            (list(six), list(six)),
            (full, full, full, 0),
            (full, full, [0, 1], 0),
            (0, np.array([0]), None, 0),
            (full, full, full, full, None),
            (None, full, full, full, full, None),
            (None, full, full, full, full),
        ]

        for idx in idxs:
            check(dense, t, idx)

    shape = (6, 7, 8, 9)
    rank_configs = (
        dict(ranks_tt=3, ranks_tucker=2),
        dict(ranks_tt=None, ranks_tucker=2, ranks_cp=3),
        dict(ranks_tt=[4, None, None], ranks_tucker=2, ranks_cp=[None, None, 3, 3]),
        dict(ranks_tt=[4, None, None], ranks_tucker=[2, None, 2, None], ranks_cp=[None, None, 3, 3]),
        dict(ranks_tt=[None, 4, 4], ranks_tucker=2, ranks_cp=[3, None, None, None]),
    )
    # Build and check one tensor at a time, in the listed order.
    for ranks in rank_configs:
        check_one_tensor(tn.rand(list(shape), **ranks))

    for _ in range(100):
        check_one_tensor(random_format(list(shape)))

    # Swap the two axes of the last core and append a trailing singleton axis.
    # NOTE(review): presumably this rewrites the last CP core in a 3-D,
    # TT-style layout -- confirm against tntorch's core conventions.
    t = tn.rand(list(shape), ranks_cp=[3, 3, 3, 3])
    last_core = t.cores[-1]
    t.cores[-1] = last_core.permute(1, 0)[:, :, None]
    check_one_tensor(t)

    t = tn.rand(list(shape), ranks_tt=3, batch=True)
    for idx in (0, [0, 1]):
        check(t.numpy(), t, idx)

def test_batch():
    """Check mixed indexing on tensors created with ``batch=True``.

    Also verifies that two index expressions raise ``ValueError`` on such a
    tensor: one starting with ``None`` followed by an ellipsis, and one made
    of two list indices.
    """

    def check_one_tensor(t):
        """Run every index expression against ``t`` and its dense copy."""
        dense = t.numpy()
        full = slice(None)

        idxs = [
            ([0, 0, 0], None, None, 3),
            ([0, 0, 0, 0, 0], full, None, 0),
            (0, [0]),
            ([0], None, None, None, 0, 1),
            (full, [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]),
            (full, full, full, 0),
            (full, full, [0, 1], 0),
            (0, np.array([0]), None, 0),
            (full, full, full, full, None),
        ]

        for idx in idxs:
            check(dense, t, idx)

    shape = (6, 7, 8, 9)
    for ranks in (dict(ranks_tt=3), dict(ranks_tucker=3), dict(ranks_cp=3)):
        check_one_tensor(tn.rand(list(shape), batch=True, **ranks))

    # Both of these index expressions are expected to be rejected.
    rejected = (
        (None, Ellipsis),
        ([0], [0]),
    )
    for bad_idx in rejected:
        with raises(ValueError) as exc_info:
            tn.rand(list(shape), ranks_tt=3, batch=True)[bad_idx]

        assert exc_info.type is ValueError

``````

Computing file changes ...