Raw File
import tntorch as tn
import pytest
import numpy as np
import torch
# Module-level configuration shared by every test below.
# Double precision by default -- presumably so that the decompositions
# reconstruct the inputs within torch.allclose's default tolerances.
torch.set_default_dtype(torch.float64)


# Seed the global RNG once, at import time, so the torch.rand inputs are
# reproducible from run to run.
torch.manual_seed(seed=1)


def test_construction():
    """Check that TT and CP matrix formats reconstruct a random dense matrix.

    ``m`` has shape ``(11 * 3, 23 * 2)`` and is factored with
    ``input_dims=[11, 3]``, ``output_dims=[23, 2]`` and a rank budget of 50.
    The test asserts that densifying each decomposition with ``.torch()``
    gives back ``m`` up to ``torch.allclose`` tolerances.
    """
    # Reseed locally so this test's data does not depend on which other tests
    # consumed random numbers before it (test selection, reordering, or
    # parallel workers would otherwise change the inputs).
    torch.manual_seed(1)
    m = torch.rand(11 * 3, 23 * 2)

    input_dims = [11, 3]
    output_dims = [23, 2]
    ranks = [50]

    ttm = tn.TTMatrix(m, input_dims=input_dims, output_dims=output_dims, ranks=ranks)
    assert torch.allclose(m, ttm.torch())

    # CPMatrix takes a single scalar rank rather than a list of ranks.
    cpm = tn.CPMatrix(m, input_dims=input_dims, output_dims=output_dims, rank=ranks[0])
    assert torch.allclose(m, cpm.torch())


def test_tt_multiply():
    """Check ``tn.tt_multiply`` / ``tn.cp_multiply`` against dense ``v @ m``.

    ``v`` holds a batch of 30 rows with ``11 * 3`` features each, so it
    multiplies the ``(11 * 3, 23 * 2)`` matrix ``m`` from the left.
    """
    # Reseed locally so the inputs are the same no matter which tests ran
    # before this one.
    torch.manual_seed(1)
    m = torch.rand(11 * 3, 23 * 2)
    v = torch.rand(30, 11 * 3)  # Note: batch = 30, 11 * 3 features

    input_dims = [11, 3]
    output_dims = [23, 2]
    ranks = [50]

    # Dense reference product, computed once and shared by both checks.
    expected = v @ m

    ttm = tn.TTMatrix(m, input_dims=input_dims, output_dims=output_dims, ranks=ranks)
    assert torch.allclose(expected, tn.tt_multiply(ttm, v))

    # CPMatrix takes a single scalar rank rather than a list of ranks.
    cpm = tn.CPMatrix(m, input_dims=input_dims, output_dims=output_dims, rank=ranks[0])
    assert torch.allclose(expected, tn.cp_multiply(cpm, v))


def test_trace():
    """Check that ``TTMatrix.trace`` matches ``torch.trace`` on a square matrix.

    The input and output mode sizes are identical (``[11, 3]``), so ``m`` is
    a square ``33 x 33`` matrix.
    """
    # Reseed locally so the inputs are the same no matter which tests ran
    # before this one.
    torch.manual_seed(1)
    input_dims = [11, 3]
    output_dims = input_dims  # square: same mode sizes on both sides
    m = torch.rand(np.prod(input_dims), np.prod(output_dims))
    ranks = [50]

    ttm = tn.TTMatrix(m, input_dims=input_dims, output_dims=output_dims, ranks=ranks)
    assert torch.allclose(torch.trace(m), ttm.trace())
back to top