swh:1:snp:39d5aa88ec19187d5f6c7d91d5a7e02d6ac7c2c2
Tip revision: 241bf7ad2b806f6677a5e23534247f35f3a70f10 authored by rballester on 19 February 2023, 19:35:27 UTC
Exact method for moments
Exact method for moments
Tip revision: 241bf7a
test_matrix.py
import tntorch as tn
import pytest
import numpy as np
import torch
# Make float64 the default dtype for tensors created in this module;
# presumably this keeps reconstruction error within torch.allclose's
# default tolerances -- TODO confirm.
torch.set_default_dtype(torch.float64)
# Seed torch's global RNG so the torch.rand inputs used by the tests are
# reproducible from run to run.
torch.manual_seed(1)
def test_construction():
    """Round-trip a dense matrix through TTMatrix and CPMatrix.

    A random (11 * 3) x (23 * 2) matrix is wrapped first as a ``tn.TTMatrix``
    and then as a ``tn.CPMatrix`` using the same input/output factorizations,
    and each wrapper's ``.torch()`` result is compared against the original
    with ``torch.allclose``.
    """
    dense = torch.rand(11 * 3, 23 * 2)
    in_dims, out_dims = [11, 3], [23, 2]
    # Presumably large enough for a lossless representation -- the assertions
    # below rely on the reconstruction matching the dense input.
    max_rank = 50

    # Tensor-train format takes a list of ranks.
    tt_matrix = tn.TTMatrix(
        dense,
        input_dims=in_dims,
        output_dims=out_dims,
        ranks=[max_rank],
    )
    assert torch.allclose(dense, tt_matrix.torch())

    # CP format takes a single scalar rank.
    cp_matrix = tn.CPMatrix(
        dense,
        input_dims=in_dims,
        output_dims=out_dims,
        rank=max_rank,
    )
    assert torch.allclose(dense, cp_matrix.torch())
def test_tt_multiply():
    """Compare TT/CP matrix-vector products against a dense matmul.

    A batch of 30 random row vectors is multiplied by a random
    (11 * 3) x (23 * 2) matrix in three ways: densely, via
    ``tn.tt_multiply`` on a ``tn.TTMatrix``, and via ``tn.cp_multiply`` on a
    ``tn.CPMatrix``. Both structured results must match the dense product
    under ``torch.allclose``.
    """
    dense = torch.rand(11 * 3, 23 * 2)
    batch = torch.rand(30, 11 * 3)  # 30 samples, each with 11 * 3 features
    in_dims, out_dims = [11, 3], [23, 2]
    max_rank = 50

    tt_matrix = tn.TTMatrix(
        dense,
        input_dims=in_dims,
        output_dims=out_dims,
        ranks=[max_rank],
    )
    tt_expected = torch.matmul(batch, dense)
    assert torch.allclose(tt_expected, tn.tt_multiply(tt_matrix, batch))

    cp_matrix = tn.CPMatrix(
        dense,
        input_dims=in_dims,
        output_dims=out_dims,
        rank=max_rank,
    )
    cp_expected = torch.matmul(batch, dense)
    assert torch.allclose(cp_expected, tn.cp_multiply(cp_matrix, batch))
def test_trace():
    """Check that ``TTMatrix.trace()`` agrees with ``torch.trace``.

    The same dimension list is used for both the input and the output
    factorization, so the dense matrix is square; its trace computed by
    ``torch.trace`` is compared with the value returned by the TT wrapper.
    """
    dims = [11, 3]  # shared by the input and output sides
    side = np.prod(dims)
    dense = torch.rand(side, side)

    tt_matrix = tn.TTMatrix(
        dense,
        input_dims=dims,
        output_dims=dims,
        ranks=[50],
    )
    assert torch.allclose(torch.trace(dense), tt_matrix.trace())