Revision 6db9048c848a178872d1aa1a14ee0009de703750 authored by Vedanuj Goswami on 23 February 2021, 04:55:02 UTC, committed by Facebook GitHub Bot on 23 February 2021, 04:56:33 UTC
Summary:
Pull Request resolved: https://github.com/facebookresearch/mmf/pull/785

Some cleanup of the file_io file

Reviewed By: apsdehal, BruceChaun

Differential Revision: D26599901

fbshipit-source-id: 1979248b54ec0d5b2566d158cae4a72028b1f116
1 parent b1fd2a9
Raw File
test_utils.py
# Copyright (c) Facebook, Inc. and its affiliates.

import argparse
import itertools
import os
import platform
import random
import socket
import tempfile
import unittest

import pytorch_lightning as pl
import torch
from mmf.common.sample import Sample, SampleList
from mmf.utils.general import get_current_device


def compare_tensors(a, b):
    """Return True when tensors ``a`` and ``b`` have the same size and elements.

    Thin wrapper over ``torch.equal`` so call sites read uniformly.
    """
    identical = torch.equal(a, b)
    return identical


def dummy_args(model="cnn_lstm", dataset="clevr"):
    """Build a minimal ``argparse.Namespace`` resembling parsed CLI arguments.

    Args:
        model: value placed in the ``model=...`` override.
        dataset: value placed in the ``dataset=...`` override.

    Returns:
        Namespace with ``opts`` (the two override strings, model first) and
        ``config_override`` set to None.
    """
    overrides = [f"model={model}", f"dataset={dataset}"]
    return argparse.Namespace(opts=overrides, config_override=None)


def is_network_reachable(host="one.one.one.one", port=80, timeout=2):
    """Return True if a TCP connection to ``host``:``port`` can be opened.

    Args:
        host: host name (or dotted IP) to probe; resolved via DNS first.
        port: TCP port to connect to.
        timeout: connection timeout in seconds.

    Returns:
        True when the name resolves and the connection succeeds, otherwise
        False. Any ``OSError`` (DNS failure via ``socket.gaierror``, timeout,
        refused or unreachable network) is treated as "not reachable".
    """
    try:
        # Check that the host name can be resolved.
        address = socket.gethostbyname(host)
        # Check that the host actually accepts a connection; the context
        # manager guarantees the probe socket is closed.
        with socket.create_connection((address, port), timeout):
            pass
    except OSError:
        return False
    return True


# Connectivity and CUDA availability are probed once, at import time, and
# cached in these module-level flags.
NETWORK_AVAILABLE = is_network_reachable()
CUDA_AVAILABLE = torch.cuda.is_available()
# Original (misspelled) name, kept so existing importers keep working.
CUDA_AVAILBLE = CUDA_AVAILABLE


def skip_if_no_network(testfn, reason="Network is not available"):
    """Skip ``testfn`` unless the import-time connectivity probe succeeded.

    Args:
        testfn: test function or TestCase class to decorate.
        reason: message reported by unittest when the test is skipped.

    Returns:
        ``testfn`` itself when ``NETWORK_AVAILABLE`` is truthy, otherwise
        ``testfn`` marked as skipped.
    """
    skip_unless_online = unittest.skipUnless(NETWORK_AVAILABLE, reason)
    return skip_unless_online(testfn)


def skip_if_no_cuda(testfn, reason="Cuda is not available"):
    """Skip ``testfn`` unless ``torch.cuda.is_available()`` was True at import.

    Args:
        testfn: test function or TestCase class to decorate.
        reason: message reported by unittest when the test is skipped.

    Returns:
        ``testfn`` itself when CUDA is available, otherwise ``testfn``
        marked as skipped.
    """
    skip_unless_cuda = unittest.skipUnless(CUDA_AVAILBLE, reason)
    return skip_unless_cuda(testfn)


def skip_if_windows(testfn, reason="Doesn't run on Windows"):
    """Skip ``testfn`` when ``platform.system()`` reports Windows.

    Args:
        testfn: test function or TestCase class to decorate.
        reason: message reported by unittest when the test is skipped.
    """
    running_on_windows = "Windows" in platform.system()
    skip_decorator = unittest.skipIf(running_on_windows, reason)
    return skip_decorator(testfn)


def skip_if_macos(testfn, reason="Doesn't run on MacOS"):
    """Skip ``testfn`` when ``platform.system()`` reports macOS ("Darwin").

    Args:
        testfn: test function or TestCase class to decorate.
        reason: message reported by unittest when the test is skipped.
    """
    running_on_macos = "Darwin" in platform.system()
    skip_decorator = unittest.skipIf(running_on_macos, reason)
    return skip_decorator(testfn)


def compare_state_dicts(a, b):
    """Return True when two state dicts have the same keys and equal values.

    Keys must appear in the same order in both mappings. Values are compared
    pairwise in that order: tensors with ``compare_tensors`` (same size and
    elements), everything else with ``==``. A tensor paired with a
    non-tensor counts as a mismatch rather than raising.

    Args:
        a: first state dict (any ordered mapping).
        b: second state dict.

    Returns:
        bool: True if keys and all values match, False at the first mismatch.
    """
    if list(a.keys()) != list(b.keys()):
        return False

    for val1, val2 in zip(a.values(), b.values()):
        val1_is_tensor = isinstance(val1, torch.Tensor)
        val2_is_tensor = isinstance(val2, torch.Tensor)
        # Previously a tensor vs non-tensor pair reached torch.equal and
        # raised TypeError; treat it as a plain mismatch instead.
        if val1_is_tensor != val2_is_tensor:
            return False
        if val1_is_tensor:
            if not compare_tensors(val1, val2):
                return False
        elif not val1 == val2:
            return False

    return True


def build_random_sample_list():
    """Return a ``SampleList`` made of two randomly filled ``Sample`` objects.

    Each sample holds an int ``x`` drawn from [0, 100], a (5, 4) ``torch.rand``
    tensor ``y``, and a nested ``Sample`` ``z`` with its own int ``x`` and a
    (6, 4) ``torch.rand`` tensor ``y``. Random draws happen in the order
    x, y, z.x, z.y for the first sample, then the same for the second.
    """

    def make_sample():
        # Draw order matters for reproducibility under a fixed seed.
        sample = Sample()
        sample.x = random.randint(0, 100)
        sample.y = torch.rand((5, 4))
        nested = Sample()
        nested.x = random.randint(0, 100)
        nested.y = torch.rand((6, 4))
        sample.z = nested
        return sample

    return SampleList([make_sample() for _ in range(2)])


# Default dictionary key under which the dummy datasets/models in this module
# store each item's tensor.
DATA_ITEM_KEY = "test"


class NumbersDataset(torch.utils.data.Dataset):
    """Map-style dataset whose i-th item is ``{data_item_key: tensor([float(i)])}``.

    Args:
        num_examples: number of items; valid indices are 0 .. num_examples - 1.
        data_item_key: dictionary key under which each item's tensor is stored.
    """

    def __init__(self, num_examples, data_item_key=DATA_ITEM_KEY):
        self.num_examples = num_examples
        self.data_item_key = data_item_key

    def __getitem__(self, idx):
        """Return ``{data_item_key: float32 tensor of shape (1,) holding idx}``.

        Raises:
            IndexError: if ``idx`` is outside ``[0, num_examples)``.
        """
        # torch Dataset defines no __iter__, so ``for item in dataset`` falls
        # back to the sequence protocol, which stops only on IndexError;
        # without this check such iteration never terminates.
        if not 0 <= idx < self.num_examples:
            raise IndexError(
                f"index {idx} out of range for NumbersDataset "
                f"of size {self.num_examples}"
            )
        return {
            self.data_item_key: torch.tensor(idx, dtype=torch.float32).unsqueeze(-1)
        }

    def __len__(self):
        return self.num_examples


class SimpleModel(torch.nn.Module):
    """Single ``Linear(size, 1)`` layer returning an MMF-style output dict.

    ``forward`` reads the tensor stored under ``DATA_ITEM_KEY`` and returns:
    ``losses.loss`` (MSE between the negated prediction and the input),
    ``logits`` (the raw linear output) and ``input_batch`` (the incoming
    batch wrapped in a ``SampleList``).
    """

    def __init__(self, size):
        super().__init__()
        self.linear = torch.nn.Linear(size, 1)

    def forward(self, prepared_batch):
        wrapped_batch = SampleList(prepared_batch)
        features = prepared_batch[DATA_ITEM_KEY]
        logits = self.linear(features)
        # The target is the input itself and the prediction is negated;
        # presumably this only needs to be some differentiable scalar.
        criterion = torch.nn.MSELoss()
        mse = criterion(-1 * logits, features)
        return {
            "losses": {"loss": mse},
            "logits": logits,
            "input_batch": wrapped_batch,
        }


class SimpleLightningModel(pl.LightningModule):
    """``LightningModule`` wrapper around ``SimpleModel`` for trainer tests.

    Args:
        size: input feature size passed to ``SimpleModel``.
        config: optional config; when given, optimizers are built with
            ``build_lightning_optimizers``, otherwise Adam(lr=0.01) is used.
    """

    def __init__(self, size, config=None):
        super().__init__()
        self.model = SimpleModel(size)
        self.config = config

    def forward(self, prepared_batch):
        return self.model(prepared_batch)

    def training_step(self, batch, batch_idx, *args, **kwargs):
        result = self(batch)
        # Expose the loss under the top-level "loss" key, which is where
        # Lightning's training_step contract looks for the value to optimize.
        result["loss"] = result["losses"]["loss"]
        return result

    def configure_optimizers(self):
        if self.config is not None:
            # Imported lazily, only when a config is supplied.
            from mmf.utils.build import build_lightning_optimizers

            return build_lightning_optimizers(self, self.config)
        return torch.optim.Adam(self.parameters(), lr=0.01)


def assertModulesEqual(mod1, mod2):
    """Return True when two modules have pairwise-equal parameters.

    Parameters are compared in ``parameters()`` order with ``Tensor.equal``
    (same size and elements). Modules with a different number of parameters
    are reported as unequal; two parameter-less modules are equal.

    Args:
        mod1: first ``torch.nn.Module`` (or scripted module).
        mod2: second module.

    Returns:
        bool: True if every parameter pair matches, False otherwise.
    """
    for p1, p2 in itertools.zip_longest(mod1.parameters(), mod2.parameters()):
        # zip_longest pads the shorter iterator with None -> count mismatch.
        if p1 is None or p2 is None:
            return False
        if not p1.equal(p2):
            return False
    return True


def setup_proxy():
    """Route HTTP(S) traffic through fwdproxy when offline in an FB dev env.

    Does nothing if ``is_network_reachable()`` succeeds. Otherwise, if the
    environment looks like Sandcastle (env vars) or a ``dev*``/``fbinfra``
    host, sets ``HTTPS_PROXY`` and ``HTTP_PROXY`` to ``http://fwdproxy:8080``.
    """
    # Enable proxy in FB dev env
    if is_network_reachable():
        return

    looks_like_fb_env = (
        os.getenv("SANDCASTLE") == "1"
        or os.getenv("TW_JOB_USER") == "sandcastle"
        or socket.gethostname().startswith("dev")
        or "fbinfra" in socket.gethostname()
    )
    if not looks_like_fb_env:
        return

    proxy_url = "http://fwdproxy:8080"
    for env_var in ("HTTPS_PROXY", "HTTP_PROXY"):
        os.environ[env_var] = proxy_url


def compare_torchscript_transformer_models(model, vocab_size):
    """Check that ``model`` and ``torch.jit.script(model)`` give identical scores.

    A single synthetic sample (128 random token ids, an all-ones mask, zero
    segment ids, a (1, 100, 2048) ``image_feature_0`` and a (3, 300, 300)
    ``image``) is run through the eager and the scripted model, and the
    ``"scores"`` outputs are compared with ``torch.equal``.

    Args:
        model: module taking a ``SampleList`` and returning a dict with a
            ``"scores"`` tensor. It is moved to the current device and put in
            eval mode (both mutate the module in place).
        vocab_size: exclusive upper bound for the random token ids.

    Returns:
        bool: True if both ``"scores"`` tensors are exactly equal.
    """
    test_sample = Sample()
    test_sample.input_ids = torch.randint(low=0, high=vocab_size, size=(128,)).long()
    test_sample.input_mask = torch.ones(128).long()
    test_sample.segment_ids = torch.zeros(128).long()
    test_sample.image_feature_0 = torch.rand((1, 100, 2048)).float()
    test_sample.image = torch.rand((3, 300, 300)).float()
    test_sample_list = SampleList([test_sample])

    device = get_current_device()
    model = model.to(device)
    # Dropout and similar train-only behaviour would make two forward passes
    # differ; compare in eval mode, as verify_torchscript_models also does.
    model.eval()
    test_sample_list = test_sample_list.to(device)

    with torch.no_grad():
        model_output = model(test_sample_list)

    script_model = torch.jit.script(model)
    with torch.no_grad():
        script_output = script_model(test_sample_list)

    return torch.equal(model_output["scores"], script_output["scores"])


def verify_torchscript_models(model):
    """Script ``model``, round-trip it through disk, and compare parameters.

    The model is put in eval mode, scripted with ``torch.jit.script``, saved
    with ``torch.jit.save`` and reloaded with ``torch.jit.load``. The result
    of ``assertModulesEqual(script_model, loaded_model)`` is returned.

    Args:
        model: scriptable ``torch.nn.Module`` (switched to eval mode in place).
    """
    model.eval()
    script_model = torch.jit.script(model)
    # A temporary directory is removed on exit (the old delete=False file
    # leaked). Saving by path also guarantees the file is fully written and
    # closed before it is reloaded, which reopening a still-open
    # NamedTemporaryFile did not (and which fails outright on Windows).
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "scripted_model.pt")
        torch.jit.save(script_model, model_path)
        loaded_model = torch.jit.load(model_path)
    return assertModulesEqual(script_model, loaded_model)
back to top