test_indexing.py
from pytest import raises
import numpy as np
import tntorch as tn
import torch
# Run every test in double precision (presumably so that the tight 1e-7
# relative tolerance used by check() is attainable -- confirm).
torch.set_default_dtype(torch.float64)
from util import random_format
def check(x, t, idx, rtol=1e-7):
    """Assert that indexing a dense array and a tensor object give the same result.

    Args:
        x: dense NumPy array holding the reference values.
        t: object indexed with the same ``idx``; the result of ``t[idx]`` must
            expose ``.numpy()`` (in this file, a ``tn.Tensor``).
        idx: index expression applied to both ``x`` and ``t``.
        rtol: maximum allowed relative (Frobenius-norm) error; defaults to 1e-7.

    Raises:
        AssertionError: if the two results differ in shape or the relative
            error exceeds ``rtol``.
    """
    xidx = x[idx]
    tidx = t[idx].numpy()
    assert np.shape(xidx) == np.shape(tidx), (np.shape(xidx), np.shape(tidx))
    # Check ``err <= rtol * ref`` instead of ``err / ref <= rtol``: the
    # division gives NaN (a spurious failure) when the reference slice is
    # all zeros or empty.
    err = np.linalg.norm(xidx - tidx)
    ref = np.linalg.norm(xidx)
    assert err <= rtol * ref, f"error {err} exceeds {rtol} * {ref}"
def test_squeeze():
    """Check that ``tn.squeeze`` removes singleton modes exactly like ``np.squeeze``.

    Each trial draws a random number of modes (2 to 9), each of size 1 or 2,
    builds a random dense array with that shape and compares the squeezed
    shapes of the dense array and its ``tn.Tensor`` counterpart.
    """
    for _ in range(100):
        # ``shape`` holds mode sizes, not tensor data, so singleton modes are
        # actually present and have something to be squeezed.
        shape = np.random.randint(1, 3, np.random.randint(2, 10))
        if np.all(shape == 1):
            # Keep one non-singleton mode so the squeezed result is still a
            # tensor; the fully-scalar case is not exercised here.
            shape[0] = 2
        x = np.random.rand(*shape)
        t = tn.Tensor(x)
        x = np.squeeze(x)
        t = tn.squeeze(t)
        assert np.array_equal(x.shape, t.shape)
def test_slicing():
    """Compare pure-slice indexing of a small random TT-Tucker tensor with NumPy."""
    t = tn.rand([1, 3, 1, 2, 1], ranks_tt=3, ranks_tucker=2)
    x = t.numpy()
    slice_patterns = (
        slice(None),
        (slice(None), slice(1, None)),
        (slice(None), slice(0, 2, None), slice(0, 1)),
    )
    for pattern in slice_patterns:
        check(x, t, pattern)
def test_mixed():
    """Compare mixed indexing (ints, lists, arrays, slices, ``None``) against NumPy.

    The same battery of index expressions is run on tensors with several
    combinations of TT / Tucker / CP ranks, on 100 tensors produced by
    ``random_format``, on a CP tensor whose last core is rearranged, and on a
    tensor created with ``batch=True``.
    """

    def check_one_tensor(t):
        dense = t.numpy()
        index_battery = [
            ([0, 0, 0], None, None, 3),
            ([0, 0, 0, 0, 0], slice(None), None, 0),
            (0, [0]),
            ([0], [0]),
            ([0], None, None, None, 0, 1),
            (slice(None), [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]),
            ([0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]),
            (slice(None), slice(None), slice(None), 0),
            (slice(None), slice(None), [0, 1], 0),
            (0, np.array([0]), None, 0),
            (slice(None), slice(None), slice(None), slice(None), None),
            (None, slice(None), slice(None), slice(None), slice(None), None),
            (None, slice(None), slice(None), slice(None), slice(None)),
        ]
        for index in index_battery:
            check(dense, t, index)

    rank_configs = [
        dict(ranks_tt=3, ranks_tucker=2),
        dict(ranks_tt=None, ranks_tucker=2, ranks_cp=3),
        dict(ranks_tt=[4, None, None], ranks_tucker=2, ranks_cp=[None, None, 3, 3]),
        dict(ranks_tt=[4, None, None], ranks_tucker=[2, None, 2, None], ranks_cp=[None, None, 3, 3]),
        dict(ranks_tt=[None, 4, 4], ranks_tucker=2, ranks_cp=[3, None, None, None]),
    ]
    for kwargs in rank_configs:
        check_one_tensor(tn.rand([6, 7, 8, 9], **kwargs))

    for _ in range(100):
        check_one_tensor(random_format([6, 7, 8, 9]))

    # The 2-D last CP core is transposed and given a trailing singleton axis,
    # making it 3-D (presumably to exercise a mixed CP/TT core layout).
    cp_tensor = tn.rand([6, 7, 8, 9], ranks_cp=[3, 3, 3, 3])
    cp_tensor.cores[-1] = cp_tensor.cores[-1].permute(1, 0)[:, :, None]
    check_one_tensor(cp_tensor)

    batched = tn.rand([6, 7, 8, 9], ranks_tt=3, batch=True)
    for index in (0, [0, 1]):
        check(batched.numpy(), batched, index)
def test_batch():
    """Compare indexing of tensors created with ``batch=True`` against NumPy.

    Runs an index battery on batched TT, Tucker and CP tensors, then checks
    that two index patterns (``[None, ...]`` and ``[[0], [0]]``) raise
    ``ValueError`` on a batched tensor.
    """

    def check_one_tensor(t):
        x = t.numpy()
        idxs = [
            ([0, 0, 0], None, None, 3),
            ([0, 0, 0, 0, 0], slice(None), None, 0),
            (0, [0]),
            ([0], None, None, None, 0, 1),
            (slice(None), [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]),
            (slice(None), slice(None), slice(None), 0),
            (slice(None), slice(None), [0, 1], 0),
            (0, np.array([0]), None, 0),
            (slice(None), slice(None), slice(None), slice(None), None),
        ]
        for idx in idxs:
            check(x, t, idx)

    check_one_tensor(tn.rand([6, 7, 8, 9], ranks_tt=3, batch=True))
    check_one_tensor(tn.rand([6, 7, 8, 9], ranks_tucker=3, batch=True))
    check_one_tensor(tn.rand([6, 7, 8, 9], ranks_cp=3, batch=True))

    # Build the tensor outside ``raises`` so that only the indexing is allowed
    # to raise; otherwise a ValueError from ``tn.rand`` would make the checks
    # below pass vacuously.
    t = tn.rand([6, 7, 8, 9], ranks_tt=3, batch=True)

    # Prepending a new axis with ``None`` is expected to raise.
    with raises(ValueError) as exc_info:
        t[None, ...]
    # Require exactly ValueError, not merely a subclass of it.
    assert exc_info.type is ValueError

    # Indexing with ``[[0], [0]]`` is expected to raise.
    with raises(ValueError) as exc_info2:
        t[[0], [0]]
    assert exc_info2.type is ValueError