Revision ac62d292d7de9b49855b7ca53f2645a10c63246a authored by rballester on 16 September 2022, 15:17:15 UTC, committed by rballester on 16 September 2022, 15:17:15 UTC
1 parent 363d461
test_matrix.py
import tntorch as tn
import pytest
import numpy as np
import torch
# Make every tensor created without an explicit dtype (e.g. the torch.rand
# calls in the tests below) double precision. Presumably chosen so that
# decomposition round-off stays within torch.allclose's default tolerances —
# TODO confirm.
torch.set_default_dtype(torch.float64)
# Seed the global RNG once at import so the random test matrices are
# reproducible between runs.
# NOTE(review): this is process-wide global state. The tests below draw from
# the global RNG without reseeding, so the values each test sees depend on
# which tests ran before it in the same process.
torch.manual_seed(1)
def test_construction():
    """Round-trip a dense matrix through the TT-matrix and CP-matrix formats.

    A random (11 * 3) x (23 * 2) matrix is compressed with ``tn.TTMatrix``
    (ranks ``[50]``) and with ``tn.CPMatrix`` (rank 50). Both use rows
    factored as ``[11, 3]`` and columns factored as ``[23, 2]``. Each
    decompressed result (``.torch()``) must have the same shape as the
    original and match it under ``torch.allclose``.
    """
    m = torch.rand(11 * 3, 23 * 2)
    input_dims = [11, 3]
    output_dims = [23, 2]
    ranks = [50]

    ttm = tn.TTMatrix(m, input_dims=input_dims, output_dims=output_dims, ranks=ranks)
    full = ttm.torch()
    # torch.allclose broadcasts its inputs, so on its own it would accept a
    # result whose shape merely broadcasts against ``m`` (e.g. an extra
    # leading dimension). Pin the shape explicitly.
    assert full.shape == m.shape
    assert torch.allclose(m, full)

    cpm = tn.CPMatrix(m, input_dims=input_dims, output_dims=output_dims, rank=ranks[0])
    full = cpm.torch()
    assert full.shape == m.shape
    assert torch.allclose(m, full)
def test_tt_multiply():
    """Check TT-matrix and CP-matrix products against a dense ``v @ m``.

    ``m`` is a random (11 * 3) x (23 * 2) matrix, stored in both compressed
    formats with rows factored as ``[11, 3]`` and columns as ``[23, 2]``.
    ``v`` holds a batch of 30 rows with 11 * 3 features each.
    ``tn.tt_multiply`` and ``tn.cp_multiply`` must each reproduce the dense
    product, both in shape and in value (``torch.allclose``).
    """
    m = torch.rand(11 * 3, 23 * 2)
    v = torch.rand(30, 11 * 3)  # Note: batch = 30, 11 * 3 features
    input_dims = [11, 3]
    output_dims = [23, 2]
    ranks = [50]
    # Dense reference product, computed once and reused for both formats.
    expected = v @ m

    ttm = tn.TTMatrix(m, input_dims=input_dims, output_dims=output_dims, ranks=ranks)
    result = tn.tt_multiply(ttm, v)
    # torch.allclose broadcasts, so check the shape separately to catch
    # results with missing or extra dimensions.
    assert result.shape == expected.shape
    assert torch.allclose(expected, result)

    cpm = tn.CPMatrix(m, input_dims=input_dims, output_dims=output_dims, rank=ranks[0])
    result = tn.cp_multiply(cpm, v)
    assert result.shape == expected.shape
    assert torch.allclose(expected, result)
def test_trace():
    """The trace of a TT-matrix must equal the trace of its dense source.

    The source matrix is square: both its rows and its columns are factored
    as ``[11, 3]``, so the same dimension list serves as input and output
    factorization.
    """
    dims = [11, 3]
    side = np.prod(dims)
    dense = torch.rand(side, side)
    tt = tn.TTMatrix(dense, input_dims=dims, output_dims=dims, ranks=[50])
    assert torch.allclose(torch.trace(dense), tt.trace())
Computing file changes ...