Revision 6db9048c848a178872d1aa1a14ee0009de703750 authored by Vedanuj Goswami on 23 February 2021, 04:55:02 UTC, committed by Facebook GitHub Bot on 23 February 2021, 04:56:33 UTC
Summary:
Pull Request resolved: https://github.com/facebookresearch/mmf/pull/785
Some cleanup of file_io file
Reviewed By: apsdehal, BruceChaun
Differential Revision: D26599901
fbshipit-source-id: 1979248b54ec0d5b2566d158cae4a72028b1f116
1 parent b1fd2a9
test_sample.py
# Copyright (c) Facebook, Inc. and its affiliates.
import unittest
import tests.test_utils as test_utils
import torch
from mmf.common.sample import Sample, to_device
class TestSample(unittest.TestCase):
    """Exercise attribute-style and key-style access on ``Sample``."""

    def test_sample_working(self):
        """Values stored via attribute or key must read back both ways.

        Also covers ``update`` with a nested dict, whose inner mapping is
        expected to support attribute access.
        """
        sample = Sample()
        sample.x = 1
        sample["y"] = 2

        # Each field must read back identically through both access styles,
        # regardless of which style was used to store it.
        for field, expected in (("x", 1), ("y", 2)):
            self.assertEqual(getattr(sample, field), expected)
            self.assertEqual(sample[field], expected)

        sample.update({"a": 3, "b": {"c": 4}})

        self.assertEqual(sample.a, 3)
        self.assertEqual(sample["a"], 3)

        # The nested mapping is reached once per access style; in both cases
        # its own field is then read as an attribute.
        nested = sample.b
        self.assertEqual(nested.c, 4)
        nested = sample["b"]
        self.assertEqual(nested.c, 4)
class TestSampleList(unittest.TestCase):
    """Tests for the sample list built by ``test_utils.build_random_sample_list``."""

    @test_utils.skip_if_no_cuda
    def test_pin_memory(self):
        """After ``pin_memory()``, ``y``/``z.y`` are pinned and ``x``/``z.x`` are not."""
        sample_list = test_utils.build_random_sample_list()
        sample_list.pin_memory()

        expected_pinned = {"y": sample_list.y, "z.y": sample_list.z.y}
        expected_unpinned = {"x": sample_list.x, "z.x": sample_list.z.x}

        # One sub-test per field, so a failure names the offending field
        # instead of collapsing everything into a single boolean.
        for field, value in expected_pinned.items():
            with self.subTest(field=field):
                self.assertTrue(value.is_pinned())

        for field, value in expected_unpinned.items():
            with self.subTest(field=field):
                # A value without an ``is_pinned`` method counts as not pinned.
                self.assertFalse(
                    hasattr(value, "is_pinned") and value.is_pinned()
                )

    def test_to_dict(self):
        """``to_dict()`` yields a ``dict`` without attribute access that keeps nested keys."""
        sample_list = test_utils.build_random_sample_list()
        sample_dict = sample_list.to_dict()

        self.assertIsInstance(sample_dict, dict)
        # hasattr won't work anymore
        self.assertFalse(hasattr(sample_dict, "x"))

        # Dotted entries describe a path into nested mappings.
        for dotted_key in ("x", "y", "z", "z.x", "z.y"):
            with self.subTest(key=dotted_key):
                current = sample_dict
                for part in dotted_key.split("."):
                    # Assert membership before indexing so that a missing key
                    # is reported as a test failure rather than a KeyError.
                    self.assertIn(part, current)
                    current = current[part]
class TestFunctions(unittest.TestCase):
    """Tests for the ``to_device`` helper imported from ``mmf.common.sample``."""

    def test_to_device(self):
        """Check ``to_device`` with string and ``torch.device`` targets."""
        cpu = torch.device("cpu")
        batch = test_utils.build_random_sample_list()

        # A device name and a torch.device object must be treated alike.
        for target in ("cpu", cpu):
            moved = to_device(batch, target)
            self.assertEqual(moved.get_device(), cpu)

        # When CUDA is unavailable the batch is expected to remain on the CPU.
        moved = to_device(batch, "cuda")
        expected = torch.device("cuda:0") if torch.cuda.is_available() else cpu
        self.assertEqual(moved.get_device(), expected)

        # Moving to the device the batch already reports must hand back the
        # very same object.
        moved_again = to_device(moved, moved.get_device())
        self.assertTrue(moved_again is moved)

        # A plain list of dicts is expected to come back equal to the input.
        plain_batch = [{"a": 1}]
        self.assertEqual(to_device(plain_batch), plain_batch)
![swh spinner](/static/img/swh-spinner.gif)
Computing file changes ...