Raw File
test_configs_for_keys.py
# Copyright (c) Facebook, Inc. and its affiliates.
import contextlib
import gc
import unittest
import warnings
from io import StringIO

from mmf.common.registry import registry
from mmf.utils.configuration import Configuration
from mmf.utils.env import setup_imports, teardown_imports
from tests.test_utils import dummy_args


class TestConfigsForKeys(unittest.TestCase):
    def setUp(self):
        setup_imports()

    def tearDown(self):
        teardown_imports()
        gc.collect()

    def test_model_configs_for_keys(self):
        models_mapping = registry.mapping["model_name_mapping"]

        for model_key, model_cls in models_mapping.items():
            if model_cls.config_path() is None:
                warnings.warn(
                    (
                        "Model {} has no default configuration defined. "
                        + "Skipping it. Make sure it is intentional"
                    ).format(model_key)
                )
                continue

            with contextlib.redirect_stdout(StringIO()):
                args = dummy_args(model=model_key)
                configuration = Configuration(args)
                configuration.freeze()
                config = configuration.get_config()

                if model_key == "mmft":
                    continue

                self.assertTrue(
                    model_key in config.model_config,
                    "Key for model {} doesn't exists in its configuration".format(
                        model_key
                    ),
                )

    def test_dataset_configs_for_keys(self):
        builder_name = registry.mapping["builder_name_mapping"]

        for builder_key, builder_cls in builder_name.items():
            if builder_cls.config_path() is None:
                warnings.warn(
                    (
                        "Dataset {} has no default configuration defined. "
                        + "Skipping it. Make sure it is intentional"
                    ).format(builder_key)
                )
                continue

            with contextlib.redirect_stdout(StringIO()):
                args = dummy_args(dataset=builder_key)
                configuration = Configuration(args)
                configuration.freeze()
                config = configuration.get_config()
                self.assertTrue(
                    builder_key in config.dataset_config,
                    "Key for dataset {} doesn't exists in its configuration".format(
                        builder_key
                    ),
                )
back to top