# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import itertools
import os
import platform
import random
import socket
import tempfile
import unittest
import pytorch_lightning as pl
import torch
from mmf.common.sample import Sample, SampleList
from mmf.utils.general import get_current_device
def compare_tensors(a, b):
    """Return True iff tensors *a* and *b* have the same size and identical elements."""
    same = torch.equal(a, b)
    return same
def dummy_args(model="cnn_lstm", dataset="clevr"):
    """Build a minimal argparse.Namespace mimicking MMF CLI arguments for tests."""
    overrides = [f"model={model}", f"dataset={dataset}"]
    return argparse.Namespace(opts=overrides, config_override=None)
def is_network_reachable():
    """Return True iff one.one.one.one resolves and accepts a TCP connection on port 80.

    Used at import time to decide whether network-dependent tests should run.
    """
    try:
        # Resolving the name checks DNS; connecting checks actual routing.
        host = socket.gethostbyname("one.one.one.one")
        # Context manager guarantees the socket is closed even on later errors
        # (the original leaked the socket if close() was never reached).
        with socket.create_connection((host, 80), timeout=2):
            pass
        return True
    except OSError:
        # Covers DNS failure, timeouts, and "network unreachable" (errno 101)
        # alike; the original `if e.errno == 101: pass` branch was a no-op.
        return False
# Probed once at import time so the skip decorators below use plain booleans.
NETWORK_AVAILABLE = is_network_reachable()
# NOTE(review): "CUDA_AVAILBLE" is a historical typo kept for backward
# compatibility with existing importers; prefer the correctly spelled alias.
CUDA_AVAILBLE = torch.cuda.is_available()
CUDA_AVAILABLE = CUDA_AVAILBLE
def skip_if_no_network(testfn, reason="Network is not available"):
    """Decorator: skip *testfn* unless the import-time network probe succeeded."""
    decorator = unittest.skipUnless(NETWORK_AVAILABLE, reason)
    return decorator(testfn)
def skip_if_no_cuda(testfn, reason="Cuda is not available"):
    """Decorator: skip *testfn* unless CUDA was available at import time."""
    decorator = unittest.skipUnless(CUDA_AVAILBLE, reason)
    return decorator(testfn)
def skip_if_windows(testfn, reason="Doesn't run on Windows"):
    """Decorator: skip *testfn* when the current host is Windows."""
    on_windows = "Windows" in platform.system()
    return unittest.skipIf(on_windows, reason)(testfn)
def skip_if_macos(testfn, reason="Doesn't run on MacOS"):
    """Decorator: skip *testfn* when the current host is macOS (Darwin)."""
    on_macos = "Darwin" in platform.system()
    return unittest.skipIf(on_macos, reason)(testfn)
def compare_state_dicts(a, b):
    """Return True iff state dicts *a* and *b* have identical keys (in order) and values.

    Tensor values are compared with torch.equal; everything else with ==.
    Fixes a crash in the original: when the first value was a Tensor and the
    second was not, torch.equal raised TypeError instead of returning False.
    """
    if list(a.keys()) != list(b.keys()):
        return False
    for val1, val2 in zip(a.values(), b.values()):
        is_tensor1 = isinstance(val1, torch.Tensor)
        is_tensor2 = isinstance(val2, torch.Tensor)
        if is_tensor1 != is_tensor2:
            # Type mismatch in either direction means the dicts differ.
            return False
        if is_tensor1:
            if not torch.equal(val1, val2):
                return False
        elif val1 != val2:
            return False
    return True
def build_random_sample_list():
    """Build a SampleList of two random Samples, each carrying a nested Sample."""

    def _random_sample():
        sample = Sample()
        sample.x = random.randint(0, 100)
        sample.y = torch.rand((5, 4))
        sample.z = Sample()
        sample.z.x = random.randint(0, 100)
        sample.z.y = torch.rand((6, 4))
        return sample

    return SampleList([_random_sample(), _random_sample()])
# Key under which NumbersDataset stores its tensor and SimpleModel reads it back.
DATA_ITEM_KEY = "test"
class NumbersDataset(torch.utils.data.Dataset):
    """Toy dataset that yields its own index as a shape-(1,) float32 tensor.

    Each item is a dict mapping ``data_item_key`` to the tensorized index.
    """

    def __init__(self, num_examples, data_item_key=DATA_ITEM_KEY):
        self.num_examples = num_examples
        self.data_item_key = data_item_key

    def __getitem__(self, idx):
        value = torch.tensor(idx, dtype=torch.float32).unsqueeze(-1)
        return {self.data_item_key: value}

    def __len__(self):
        return self.num_examples
class SimpleModel(torch.nn.Module):
    """Minimal single-linear-layer model used by trainer tests.

    forward() returns a dict with the MSE loss (against the input batch),
    the raw logits, and the input wrapped in a SampleList.
    """

    def __init__(self, size):
        super().__init__()
        self.linear = torch.nn.Linear(size, 1)

    def forward(self, prepared_batch):
        input_sample = SampleList(prepared_batch)
        batch = prepared_batch[DATA_ITEM_KEY]
        output = self.linear(batch)
        mse = torch.nn.MSELoss()
        loss = mse(-output, batch)
        return {
            "losses": {"loss": loss},
            "logits": output,
            "input_batch": input_sample,
        }
class SimpleLightningModel(pl.LightningModule):
    """Lightning wrapper around SimpleModel for lightning-trainer tests.

    When *config* is given, optimizers are built through MMF's
    build_lightning_optimizers; otherwise a plain Adam is used.
    """

    def __init__(self, size, config=None):
        super().__init__()
        self.model = SimpleModel(size)
        self.config = config

    def forward(self, prepared_batch):
        return self.model(prepared_batch)

    def training_step(self, batch, batch_idx, *args, **kwargs):
        report = self(batch)
        # Lightning expects the loss under the "loss" key of the step output.
        report["loss"] = report["losses"]["loss"]
        return report

    def configure_optimizers(self):
        if self.config is not None:
            from mmf.utils.build import build_lightning_optimizers

            return build_lightning_optimizers(self, self.config)
        return torch.optim.Adam(self.parameters(), lr=0.01)
def assertModulesEqual(mod1, mod2):
    """Return True iff *mod1* and *mod2* have the same number of parameters with equal values.

    Fixes the original's early return inside the loop, which compared only the
    FIRST parameter pair; also returns False (instead of raising) when the two
    modules have a different number of parameters (zip_longest pads with None).
    """
    for p1, p2 in itertools.zip_longest(mod1.parameters(), mod2.parameters()):
        if p1 is None or p2 is None:
            # Differing parameter counts.
            return False
        if not p1.equal(p2):
            return False
    return True
def setup_proxy():
    """Set HTTP(S)_PROXY to fwdproxy when inside the FB dev env without direct network access."""
    if is_network_reachable():
        return
    hostname = socket.gethostname()
    in_fb_dev_env = (
        os.getenv("SANDCASTLE") == "1"
        or os.getenv("TW_JOB_USER") == "sandcastle"
        or hostname.startswith("dev")
        or "fbinfra" in hostname
    )
    if in_fb_dev_env:
        proxy = "http://fwdproxy:8080"
        os.environ["HTTPS_PROXY"] = proxy
        os.environ["HTTP_PROXY"] = proxy
def compare_torchscript_transformer_models(model, vocab_size):
    """Script *model* and return True iff eager and scripted "scores" match on a dummy sample.

    Builds one random multimodal Sample (token ids/mask/segments plus image
    features and a raw image) and runs both model variants without gradients.
    """
    sample = Sample()
    sample.input_ids = torch.randint(low=0, high=vocab_size, size=(128,)).long()
    sample.input_mask = torch.ones(128).long()
    sample.segment_ids = torch.zeros(128).long()
    sample.image_feature_0 = torch.rand((1, 100, 2048)).float()
    sample.image = torch.rand((3, 300, 300)).float()

    sample_list = SampleList([sample])
    device = get_current_device()
    model = model.to(device)
    sample_list = sample_list.to(device)

    with torch.no_grad():
        eager_output = model(sample_list)

    scripted_model = torch.jit.script(model)
    with torch.no_grad():
        scripted_output = scripted_model(sample_list)

    return torch.equal(eager_output["scores"], scripted_output["scores"])
def verify_torchscript_models(model):
    """Script *model*, round-trip it through torch.jit.save/load, and compare parameters.

    Returns the result of assertModulesEqual on the scripted and reloaded modules.
    Fixes a temp-file leak in the original: NamedTemporaryFile(delete=False) was
    never removed. delete=False is kept so the file can be reopened by name
    (required on Windows); the file is now removed explicitly afterwards.
    """
    model.eval()
    script_model = torch.jit.script(model)
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        path = tmp.name
        torch.jit.save(script_model, tmp)
    try:
        # Load after the handle is closed so this also works on Windows.
        loaded_model = torch.jit.load(path)
        return assertModulesEqual(script_model, loaded_model)
    finally:
        os.remove(path)