# Source: https://github.com/rballester/tntorch
# Tip revision: 3af563a42794ba169e7902198d1edd919617a958, authored by Rafael Ballester
# on 16 March 2023, 15:48:54 UTC
# ("Updated doc (ranks_cp actually must be an integer, not a list)").
# File: test_matrix.py
import tntorch as tn
import pytest
import numpy as np
import torch
# Every tensor created by these tests (e.g. via torch.rand) defaults to
# float64. Presumably this keeps the torch.allclose checks below within their
# default tolerances after decomposition/reconstruction — confirm.
torch.set_default_dtype(torch.float64)


# Seed the global torch RNG once, at import time.
# NOTE(review): because the seed is applied only once, the random matrices a
# given test receives depend on which tests ran before it (e.g. `pytest -k`).
torch.random.manual_seed(1)


def test_construction():
    """Round-trip a dense matrix through the TT and CP matrix formats.

    A 33 x 46 dense matrix is compressed with rows factored as 11 * 3 and
    columns as 23 * 2; decompressing with ``.torch()`` must give back the
    original (within ``torch.allclose`` tolerance) for both formats.
    """
    dense = torch.rand(11 * 3, 23 * 2)

    row_factors, col_factors = [11, 3], [23, 2]
    # presumably large enough that no truncation occurs: the assertions
    # below demand a (near-)exact reconstruction
    tt_ranks = [50]

    tt_matrix = tn.TTMatrix(
        dense,
        input_dims=row_factors,
        output_dims=col_factors,
        ranks=tt_ranks,
    )
    assert torch.allclose(dense, tt_matrix.torch())

    # CPMatrix takes a single integer rank rather than a list.
    cp_matrix = tn.CPMatrix(
        dense,
        input_dims=row_factors,
        output_dims=col_factors,
        rank=tt_ranks[0],
    )
    assert torch.allclose(dense, cp_matrix.torch())


def test_tt_multiply():
    """Compressed products must match the dense product ``v @ m``.

    Despite its name, this test checks both ``tn.tt_multiply`` (TT format)
    and ``tn.cp_multiply`` (CP format) against the same dense reference.
    """
    dense = torch.rand(11 * 3, 23 * 2)
    batch = torch.rand(30, 11 * 3)  # 30 rows, each with 11 * 3 = 33 features

    row_factors, col_factors = [11, 3], [23, 2]
    tt_ranks = [50]

    tt_matrix = tn.TTMatrix(
        dense,
        input_dims=row_factors,
        output_dims=col_factors,
        ranks=tt_ranks,
    )
    tt_product = tn.tt_multiply(tt_matrix, batch)
    assert torch.allclose(batch @ dense, tt_product)

    cp_matrix = tn.CPMatrix(
        dense,
        input_dims=row_factors,
        output_dims=col_factors,
        rank=tt_ranks[0],
    )
    cp_product = tn.cp_multiply(cp_matrix, batch)
    assert torch.allclose(batch @ dense, cp_product)


def test_trace():
    """``TTMatrix.trace()`` must equal ``torch.trace`` of the dense matrix.

    Rows and columns use the same factorisation, so the matrix is square
    (33 x 33).
    """
    input_dims = [11, 3]
    output_dims = input_dims  # same list object on purpose: square matrix
    n_rows = np.prod(input_dims)
    n_cols = np.prod(output_dims)
    square = torch.rand(n_rows, n_cols)

    tt_ranks = [50]
    tt_matrix = tn.TTMatrix(
        square,
        input_dims=input_dims,
        output_dims=output_dims,
        ranks=tt_ranks,
    )
    expected_trace = torch.trace(square)
    assert torch.allclose(expected_trace, tt_matrix.trace())