https://github.com/rballester/tntorch
Raw File
Tip revision: 3af563a42794ba169e7902198d1edd919617a958 authored by Rafael Ballester on 16 March 2023, 15:48:54 UTC
Updated doc (ranks_cp actually must be an integer, not a list)
Tip revision: 3af563a
test_interpolation.py
import numpy as np
import tntorch as tn
import torch
# Make double precision the default floating dtype for every tensor created
# after this point in the module (e.g. the torch.ones(...) targets in the tests).
# NOTE(review): presumably needed so the tight relative-error tolerances in the
# tests are attainable; confirm against tntorch's numerical behavior.
torch.set_default_dtype(torch.double)


def test_als_completion():
    """Completion must reproduce the observed samples it was fitted on.

    The observed coordinates are the 8 diagonal positions (i, i) of a 2-D
    index grid, each observed with value 1.  A model is fitted with
    ``tn.als_completion`` using ``ranks_tt=3``, evaluated back at those same
    coordinates, and must match the observations to a relative error below
    1e-5.
    """
    n_points = 8
    diagonal = torch.arange(n_points)
    # One row per sample, one column per tensor mode: row k is (k, k).
    coords = torch.stack([diagonal, diagonal], dim=1)
    observed = torch.ones(n_points)

    model = tn.als_completion(coords, observed, ranks_tt=3)
    reconstruction_error = tn.relative_error(observed, model[coords])
    assert reconstruction_error < 1e-5
back to top