swh:1:snp:39d5aa88ec19187d5f6c7d91d5a7e02d6ac7c2c2
Raw File
Tip revision: 241bf7ad2b806f6677a5e23534247f35f3a70f10 authored by rballester on 19 February 2023, 19:35:27 UTC
Exact method for moments
Tip revision: 241bf7a
test_gpu.py
import tntorch as tn
import torch
# Every tensor created in this module without an explicit dtype is float64
# (presumably so the 1e-5 comparisons below are not limited by float32
# rounding -- TODO confirm).
torch.set_default_dtype(torch.float64)

# Target device for the "device" side of each comparison. Without a CUDA GPU
# this falls back to 'cpu', so both sides of every test run on the CPU and
# the tests still pass on GPU-less machines.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


def test_tt():
    """Check that a TT tensor built on ``device`` matches one built without a device.

    The same random 16x16x16 ``X`` is compressed with ``ranks_tt=3`` twice:
    once with no ``device`` argument, and once with ``device=device``. The
    second result is moved back with ``.cpu()``. Both are converted with
    ``.torch()`` and must agree to within 1e-5 (max absolute error). When no
    CUDA GPU is available, ``device`` is ``'cpu'`` and both builds run on the
    CPU.
    """
    # Fixed seed so that a failing input can be reproduced exactly.
    torch.manual_seed(0)
    X = torch.randn(16, 16, 16)
    y1 = tn.Tensor(X, ranks_tt=3).torch()
    y2 = tn.Tensor(X, ranks_tt=3, device=device).torch().cpu()
    max_err = torch.abs(y1 - y2).max().item()
    assert max_err < 1e-5, (
        f"TT result differs between default and {device!r} builds "
        f"(max abs error {max_err:.3e})"
    )


def test_tucker():
    """Check that a Tucker tensor built on ``device`` matches one built without a device.

    The same random 16x16x16 ``X`` is compressed with ``ranks_tucker=3``
    twice: once with no ``device`` argument, and once with ``device=device``.
    The second result is moved back with ``.cpu()``. Both are converted with
    ``.torch()`` and must agree to within 1e-5 (max absolute error). When no
    CUDA GPU is available, ``device`` is ``'cpu'`` and both builds run on the
    CPU.
    """
    # Fixed seed so that a failing input can be reproduced exactly.
    torch.manual_seed(0)
    X = torch.randn(16, 16, 16)
    y1 = tn.Tensor(X, ranks_tucker=3).torch()
    y2 = tn.Tensor(X, ranks_tucker=3, device=device).torch().cpu()
    max_err = torch.abs(y1 - y2).max().item()
    assert max_err < 1e-5, (
        f"Tucker result differs between default and {device!r} builds "
        f"(max abs error {max_err:.3e})"
    )


def test_cp():
    """Check that a CP tensor built on ``device`` matches one built without a device.

    The same random 16x16x16 ``X`` is compressed with ``ranks_cp=3`` twice:
    once with no ``device`` argument, and once with ``device=device``. The
    second result is moved back with ``.cpu()``. Both are converted with
    ``.torch()`` and must agree to within 1e-5 (max absolute error). When no
    CUDA GPU is available, ``device`` is ``'cpu'`` and both builds run on the
    CPU.
    """
    # Fixed seed so that a failing input can be reproduced exactly.
    torch.manual_seed(0)
    X = torch.randn(16, 16, 16)
    y1 = tn.Tensor(X, ranks_cp=3).torch()
    y2 = tn.Tensor(X, ranks_cp=3, device=device).torch().cpu()
    max_err = torch.abs(y1 - y2).max().item()
    assert max_err < 1e-5, (
        f"CP result differs between default and {device!r} builds "
        f"(max abs error {max_err:.3e})"
    )
back to top