Revision 800a086c0e62d3535acf7bcd09f34853e4948492 authored by Madian Khabsa on 24 March 2021, 03:21:27 UTC, committed by Facebook GitHub Bot on 24 March 2021, 03:23:02 UTC
Summary:
Pull Request resolved: https://github.com/fairinternal/mmf-internal/pull/150

- Modifies test reporter to be fetched/registered in registry
- Add hive test reporter for internal support

Reviewed By: apsdehal

Differential Revision: D26954690

fbshipit-source-id: b8ad088d829a5353c713cb493c38f27099b43fcd
1 parent e5edc48
Raw File
test_utils.py
# Copyright (c) Facebook, Inc. and its affiliates.

import argparse
import contextlib
import itertools
import json
import os
import platform
import random
import socket
import tempfile
import unittest
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

import torch
from mmf.common.sample import Sample, SampleList
from mmf.models.base_model import BaseModel
from mmf.utils.general import get_current_device
from omegaconf import OmegaConf
from torch import Tensor


def compare_tensors(a, b):
    return torch.equal(a, b)


def dummy_args(model="cnn_lstm", dataset="clevr"):
    args = argparse.Namespace()
    args.opts = [f"model={model}", f"dataset={dataset}"]
    args.config_override = None
    return args


def is_network_reachable():
    try:
        # check if host name can be resolved
        host = socket.gethostbyname("one.one.one.one")
        # check if host is actually reachable
        s = socket.create_connection((host, 80), 2)
        s.close()
        return True
    except OSError as e:
        if e.errno == 101:
            pass
    return False


NETWORK_AVAILABLE = is_network_reachable()
CUDA_AVAILBLE = torch.cuda.is_available()


def is_fb():
    return (
        os.getenv("SANDCASTLE") == "1"
        or os.getenv("TW_JOB_USER") == "sandcastle"
        or (
            socket.gethostname().startswith("dev")
            and not socket.gethostname().startswith("devfair")
        )
        or "fbinfra" in socket.gethostname()
    )


def skip_if_no_network(testfn, reason="Network is not available"):
    return unittest.skipUnless(NETWORK_AVAILABLE, reason)(testfn)


def skip_if_no_cuda(testfn, reason="Cuda is not available"):
    return unittest.skipUnless(CUDA_AVAILBLE, reason)(testfn)


def skip_if_windows(testfn, reason="Doesn't run on Windows"):
    return unittest.skipIf("Windows" in platform.system(), reason)(testfn)


def skip_if_macos(testfn, reason="Doesn't run on MacOS"):
    return unittest.skipIf("Darwin" in platform.system(), reason)(testfn)


def skip_if_non_fb(testfn, reason="Doesn't run on non FB infra"):
    return unittest.skipUnless(is_fb(), reason)(testfn)


def compare_state_dicts(a, b):
    same = True
    same = same and (list(a.keys()) == list(b.keys()))
    if not same:
        return same

    for val1, val2 in zip(a.values(), b.values()):
        if isinstance(val1, torch.Tensor):
            same = same and compare_tensors(val1, val2)
        elif not isinstance(val2, torch.Tensor):
            same = same and val1 == val2
        else:
            same = False
        if not same:
            return same

    return same


@contextlib.contextmanager
def make_temp_dir():
    temp_dir = tempfile.TemporaryDirectory()
    try:
        yield temp_dir.name
    finally:
        # Don't clean up on Windows, as it always results in an error
        if "Windows" not in platform.system():
            temp_dir.cleanup()


def build_random_sample_list():
    first = Sample()
    first.x = random.randint(0, 100)
    first.y = torch.rand((5, 4))
    first.z = Sample()
    first.z.x = random.randint(0, 100)
    first.z.y = torch.rand((6, 4))

    second = Sample()
    second.x = random.randint(0, 100)
    second.y = torch.rand((5, 4))
    second.z = Sample()
    second.z.x = random.randint(0, 100)
    second.z.y = torch.rand((6, 4))

    return SampleList([first, second])


DATA_ITEM_KEY = "test"


class NumbersDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        num_examples: int,
        data_item_key: str = DATA_ITEM_KEY,
        always_one: bool = False,
    ):
        self.num_examples = num_examples
        self.data_item_key = data_item_key
        self.always_one = always_one

    def __getitem__(self, idx: int) -> Sample:
        sample = Sample()
        sample[self.data_item_key] = torch.tensor(idx, dtype=torch.float32).unsqueeze(
            -1
        )
        if self.always_one:
            sample.targets = torch.tensor(0, dtype=torch.long)
        return sample

    def __len__(self) -> int:
        return self.num_examples


class SimpleModel(BaseModel):
    @dataclass
    class Config(BaseModel.Config):
        in_dim: int = 1
        out_dim: int = 1
        data_item_key: str = DATA_ITEM_KEY

    def __init__(self, config: Config, *args, **kwargs):
        config = OmegaConf.merge(OmegaConf.structured(self.Config), config)
        super().__init__(config)
        self.data_item_key = config.data_item_key

    def build(self):
        self.classifier = torch.nn.Linear(self.config.in_dim, self.config.out_dim)

    def forward(self, prepared_batch: Dict[str, Tensor]):
        input_sample = SampleList(prepared_batch)
        batch = prepared_batch[self.data_item_key]
        output = self.classifier(batch)
        loss = torch.nn.MSELoss()(-1 * output, batch)
        return {"losses": {"loss": loss}, "logits": output, "input_batch": input_sample}


class SimpleLightningModel(SimpleModel):
    def __init__(self, config: SimpleModel.Config, trainer_config=None):
        super().__init__(config)
        self.trainer_config = trainer_config

    def training_step(self, batch, batch_idx, *args, **kwargs):
        output = self(batch)
        output["loss"] = output["losses"]["loss"]
        return output

    def configure_optimizers(self):
        if self.config is None:
            return torch.optim.Adam(self.parameters(), lr=0.01)
        else:
            from mmf.utils.build import build_lightning_optimizers

            return build_lightning_optimizers(self, self.trainer_config)


def assertModulesEqual(mod1, mod2):
    for p1, p2 in itertools.zip_longest(mod1.parameters(), mod2.parameters()):
        return p1.equal(p2)


def setup_proxy():
    # Enable proxy in FB dev env
    if not is_network_reachable() and (
        os.getenv("SANDCASTLE") == "1"
        or os.getenv("TW_JOB_USER") == "sandcastle"
        or socket.gethostname().startswith("dev")
        or "fbinfra" in socket.gethostname()
    ):
        os.environ["HTTPS_PROXY"] = "http://fwdproxy:8080"
        os.environ["HTTP_PROXY"] = "http://fwdproxy:8080"


def compare_torchscript_transformer_models(model, vocab_size):
    test_sample = Sample()
    test_sample.input_ids = torch.randint(low=0, high=vocab_size, size=(128,)).long()
    test_sample.input_mask = torch.ones(128).long()
    test_sample.segment_ids = torch.zeros(128).long()
    test_sample.image_feature_0 = torch.rand((1, 100, 2048)).float()
    test_sample.image = torch.rand((3, 300, 300)).float()
    test_sample_list = SampleList([test_sample])

    model = model.to(get_current_device())
    test_sample_list = test_sample_list.to(get_current_device())

    with torch.no_grad():
        model_output = model(test_sample_list)

    script_model = torch.jit.script(model)
    with torch.no_grad():
        script_output = script_model(test_sample_list)

    return torch.equal(model_output["scores"], script_output["scores"])


def verify_torchscript_models(model):
    model.eval()
    script_model = torch.jit.script(model)
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        torch.jit.save(script_model, tmp)
        loaded_model = torch.jit.load(tmp.name)
    return assertModulesEqual(script_model, loaded_model)


def search_log(log_file: str, search_condition: Optional[List[Callable]] = None):
    """Searches a log file for a particular search conditions which can be list
    of functions and returns it back

    Args:
        log_file (str): Log file in which search needs to be performed
        search_condition (List[Callable], optional): Search conditions in form of list.
            Each corresponding to a function to test a condition. Defaults to None.

    Returns:
        JSONObject: Json representation of the search line

    Throws:
        AssertionError: If no log line is found meeting the conditions
    """
    if search_condition is None:
        search_condition = {}

    lines = []

    with open(log_file) as f:
        lines = f.readlines()

    filtered_line = None
    for line in lines:
        line = line.strip()
        if "progress" not in line:
            continue
        info_index = line.find(" : ")
        line = line[info_index + 3 :]
        res = json.loads(line)

        meets_condition = True
        for condition_fn in search_condition:
            meets_condition = meets_condition and condition_fn(res)

        if meets_condition:
            filtered_line = res
            break

    assert filtered_line is not None, "No match for search condition in log file"
    return filtered_line
back to top