# swh:1:snp:39d5aa88ec19187d5f6c7d91d5a7e02d6ac7c2c2
# Raw File
# Tip revision: 241bf7ad2b806f6677a5e23534247f35f3a70f10 authored by rballester on 19 February 2023, 19:35:27 UTC
# Exact method for moments
# Tip revision: 241bf7a
# test_matrix.py
import tntorch as tn
import pytest
import numpy as np
import torch
# Make float64 the default floating-point dtype for tensors created in this module.
torch.set_default_dtype(torch.float64)


# Seed torch's global RNG with a fixed value so the torch.rand inputs used by
# the tests below are reproducible from run to run.
torch.manual_seed(1)


def test_construction():
    """Build TT and CP matrices from a dense matrix and compare reconstructions.

    A random (11 * 3) x (23 * 2) matrix is wrapped first in ``tn.TTMatrix`` and
    then in ``tn.CPMatrix``; each wrapper's ``.torch()`` output must be
    ``torch.allclose`` to the original dense matrix.
    """
    rank = 50
    in_shape, out_shape = [11, 3], [23, 2]

    dense = torch.rand(11 * 3, 23 * 2)

    # Tensor-train form: ranks are passed as a list.
    tt_matrix = tn.TTMatrix(
        dense, input_dims=in_shape, output_dims=out_shape, ranks=[rank]
    )
    assert torch.allclose(dense, tt_matrix.torch())

    # CP form: a single scalar rank is passed.
    cp_matrix = tn.CPMatrix(
        dense, input_dims=in_shape, output_dims=out_shape, rank=rank
    )
    assert torch.allclose(dense, cp_matrix.torch())


def test_tt_multiply():
    """Compare TT/CP matrix-vector products against the dense product.

    A batch of 30 row vectors with 11 * 3 features is multiplied by a random
    (11 * 3) x (23 * 2) matrix. The results of ``tn.tt_multiply`` and
    ``tn.cp_multiply`` must each be ``torch.allclose`` to ``batch @ dense``.
    """
    rank = 50
    in_shape, out_shape = [11, 3], [23, 2]

    # The matrix is drawn before the batch so the seeded random stream is
    # consumed in the same order as before.
    dense = torch.rand(11 * 3, 23 * 2)
    batch = torch.rand(30, 11 * 3)  # batch = 30, 11 * 3 features

    tt_matrix = tn.TTMatrix(
        dense, input_dims=in_shape, output_dims=out_shape, ranks=[rank]
    )
    tt_product = tn.tt_multiply(tt_matrix, batch)
    assert torch.allclose(batch @ dense, tt_product)

    cp_matrix = tn.CPMatrix(
        dense, input_dims=in_shape, output_dims=out_shape, rank=rank
    )
    cp_product = tn.cp_multiply(cp_matrix, batch)
    assert torch.allclose(batch @ dense, cp_product)


def test_trace():
    """Check that ``TTMatrix.trace()`` agrees with ``torch.trace`` on a square matrix.

    Input and output dimensions are the same list, so the dense matrix is
    square with side ``prod(dims)``.
    """
    dims = [11, 3]
    side = np.prod(dims)
    dense = torch.rand(side, side)

    # The same list object serves as both input_dims and output_dims.
    tt_matrix = tn.TTMatrix(dense, input_dims=dims, output_dims=dims, ranks=[50])

    expected = torch.trace(dense)
    assert torch.allclose(expected, tt_matrix.trace())
# back to top