Raw File
test_configuration.py
# Copyright (c) Facebook, Inc. and its affiliates.
import gc
import os
import unittest

from mmf.utils.configuration import Configuration, get_zoo_config
from mmf.utils.env import setup_imports, teardown_imports
from mmf.utils.general import get_mmf_root
from tests.test_utils import dummy_args


class TestUtilsConfiguration(unittest.TestCase):
    """Tests for zoo-config lookup and command-line overrides of a Configuration."""

    def setUp(self):
        # Run MMF's import setup before every test.
        setup_imports()

    def tearDown(self):
        # Reverse the import setup, then force a garbage-collection pass.
        teardown_imports()
        gc.collect()

    def _lookup_and_check(self, key, *, found=True, **lookup_kwargs):
        """Call ``get_zoo_config(key, **lookup_kwargs)`` and check the result.

        Args:
            key: Zoo key passed straight through to ``get_zoo_config``.
            found: When true, both returned values must be non-None; when
                false, both must be None.
            **lookup_kwargs: Extra keyword arguments for ``get_zoo_config``
                (e.g. ``variation``, ``zoo_type``, ``zoo_config_path``).
        """
        version, resources = get_zoo_config(key, **lookup_kwargs)
        check = self.assertIsNotNone if found else self.assertIsNone
        check(version)
        check(resources)

    def test_get_zoo_config(self):
        # A fully qualified dotted key.
        self._lookup_and_check("textvqa.ocr_en")

        # A bare key, resolved through its default variation.
        self._lookup_and_check("textvqa")

        # A bare key with an explicit, non-default variation.
        self._lookup_and_check("textvqa", variation="ocr_en")

        # An unknown key yields (None, None) rather than raising.
        self._lookup_and_check("some_random", found=False)

        # An unknown variation of a known key is expected to fail an assertion.
        with self.assertRaises(AssertionError):
            get_zoo_config("textvqa", variation="some_random")

        # Lookup in a zoo type other than the default.
        self._lookup_and_check("visual_bert.pretrained", zoo_type="models")

        # Lookup with an explicitly supplied zoo config file.
        models_zoo_yaml = os.path.join("configs", "zoo", "models.yaml")
        self._lookup_and_check(
            "visual_bert.pretrained", zoo_config_path=models_zoo_yaml
        )

    def test_config_overrides(self):
        # The m4c TextVQA project config sits one level above the MMF root.
        project_config = os.path.abspath(
            os.path.join(
                get_mmf_root(),
                "..",
                "projects",
                "m4c",
                "configs",
                "textvqa",
                "defaults.yaml",
            )
        )

        # Dotted-path overrides: one indexes into a numeric list and the other
        # into a list of strings. The assertions below expect the surrounding
        # quotes on "test" to be stripped during parsing.
        overrides = [
            f"config={project_config}",
            "training.lr_steps[1]=10000",
            'dataset_config.textvqa.zoo_requirements[0]="test"',
        ]
        args = dummy_args(model="m4c", dataset="textvqa")
        args.opts += overrides

        configuration = Configuration(args)
        configuration.freeze()
        frozen = configuration.get_config()

        self.assertEqual(frozen.training.lr_steps[1], 10000)
        self.assertEqual(frozen.dataset_config.textvqa.zoo_requirements[0], "test")
back to top