Revision 6db9048c848a178872d1aa1a14ee0009de703750 authored by Vedanuj Goswami on 23 February 2021, 04:55:02 UTC, committed by Facebook GitHub Bot on 23 February 2021, 04:56:33 UTC
Summary:
Pull Request resolved: https://github.com/facebookresearch/mmf/pull/785

Some cleanup of file_io file

Reviewed By: apsdehal, BruceChaun

Differential Revision: D26599901

fbshipit-source-id: 1979248b54ec0d5b2566d158cae4a72028b1f116
1 parent b1fd2a9
Raw File
test_mmbt.py
# Copyright (c) Facebook, Inc. and its affiliates.

import unittest

import tests.test_utils as test_utils
import torch
from mmf.common.sample import Sample, SampleList
from mmf.models.mmbt import MMBT
from mmf.modules.encoders import (
    ImageEncoderFactory,
    ImageEncoderTypes,
    ResNet152ImageEncoder,
    TextEncoderFactory,
    TextEncoderTypes,
)
from mmf.utils.build import build_model
from mmf.utils.configuration import Configuration
from mmf.utils.env import setup_imports
from omegaconf import OmegaConf


class TestMMBTTorchscript(unittest.TestCase):
    """TorchScript-related tests for an MMBT model with a classification head.

    ``setUp`` builds a fresh model from the default ``mmbt`` model config before
    every test, configured with ``training_head_type="classification"`` and two
    labels.
    """

    def setUp(self):
        # Same order as required by the test utilities: proxy, then registry
        # imports, then configuration loading.
        test_utils.setup_proxy()
        setup_imports()

        name = "mmbt"
        cfg = Configuration(test_utils.dummy_args(model=name)).get_config()
        mmbt_cfg = cfg.model_config[name]
        mmbt_cfg["training_head_type"] = "classification"
        mmbt_cfg["num_labels"] = 2
        mmbt_cfg.model = name
        self.finetune_model = build_model(mmbt_cfg)

    def test_load_save_finetune_model(self):
        """The fine-tuning model passes ``verify_torchscript_models``."""
        scripted_ok = test_utils.verify_torchscript_models(self.finetune_model)
        self.assertTrue(scripted_ok)

    def test_finetune_model(self):
        """Eager and ``torch.jit.script`` forward passes give equal ``scores``."""
        self.finetune_model.eval()
        seq_len = 128

        sample = Sample()
        sample.input_ids = torch.randint(low=0, high=30255, size=(seq_len,)).long()
        sample.input_mask = torch.ones(seq_len).long()
        sample.segment_ids = torch.zeros(seq_len).long()
        sample.image = torch.rand((3, 300, 300)).float()

        inner = self.finetune_model.model

        # The eager pass is fed a copy of the sample and the scripted pass the
        # original - presumably so any in-place edits the eager forward makes to
        # its inputs cannot leak into the scripted run (TODO confirm).
        eager_batch = SampleList([sample.copy()])
        with torch.no_grad():
            eager_out = inner(eager_batch)

        script_batch = SampleList([sample])
        scripted = torch.jit.script(inner)
        with torch.no_grad():
            scripted_out = scripted(script_batch)

        self.assertTrue(torch.equal(eager_out["scores"], scripted_out["scores"]))

    def test_modal_end_token(self):
        """``extract_modal_end_token`` yields <sep> and shifts ids/mask left by one."""
        self.finetune_model.eval()

        # Token ids assumed by this test: 0 = <cls>, 1 = <pad>, 2 = <sep>.
        CLS, PAD, SEP = 0, 1, 2
        size = 128

        ids = torch.randint(low=0, high=30255, size=(size,)).long()
        mask = torch.ones(size).long()

        # Build "<cls> ... <sep> <pad> <pad> ..." with <sep> at a random position;
        # the mask is 1 up to and including <sep>, 0 afterwards.
        ids[0] = CLS
        sep_pos = torch.randint(low=2, high=size - 1, size=(1,))
        ids[sep_pos] = SEP
        ids[sep_pos + 1 :] = PAD
        mask[sep_pos + 1 :] = 0

        sample = Sample()
        sample.input_ids = ids.clone()
        sample.input_mask = mask.clone()
        sample.segment_ids = torch.zeros(size).long()
        sample.image = torch.rand((3, 300, 300)).float()
        batch = SampleList([sample])

        base = self.finetune_model.model.bert
        with torch.no_grad():
            found = base.extract_modal_end_token(batch)

        expected = torch.full((1,), SEP, dtype=torch.long)
        self.assertTrue(torch.equal(found, expected))
        # The call is expected to have dropped the leading token from the batch.
        self.assertTrue(torch.equal(batch.input_ids[0, :-1], ids[1:]))
        self.assertTrue(torch.equal(batch.input_mask[0, :-1], mask[1:]))


class TestMMBTConfig(unittest.TestCase):
    """Checks the supported ways of constructing an ``MMBT`` model."""

    @staticmethod
    def _encoder_kwargs():
        # Fresh encoder configs on every call: a non-pretrained ResNet-152 image
        # encoder and an identity text encoder.
        return {
            "modal_encoder": ImageEncoderFactory.Config(
                type=ImageEncoderTypes.resnet152,
                params=ResNet152ImageEncoder.Config(pretrained=False),
            ),
            "text_encoder": TextEncoderFactory.Config(type=TextEncoderTypes.identity),
        }

    def test_mmbt_from_params(self):
        """``MMBT.from_params`` stores a config equal to the structured ``MMBT.Config``."""
        model = MMBT.from_params(**self._encoder_kwargs())
        expected_cfg = OmegaConf.structured(MMBT.Config(**self._encoder_kwargs()))

        self.assertIsNotNone(model)
        # The model's config should be derived from MMBT.Config.
        self.assertEqual(model.config, expected_cfg)

    def test_mmbt_pretrained(self):
        """``MMBT.from_params()`` with all defaults returns a model."""
        # Proxy setup first, presumably because the defaults pull pretrained
        # weights from the network (TODO confirm).
        test_utils.setup_proxy()
        model = MMBT.from_params()
        self.assertIsNotNone(model)

    def test_mmbt_directly_from_config(self):
        """Passing a structured config straight to ``MMBT`` keeps that config."""
        cfg = OmegaConf.structured(MMBT.Config(**self._encoder_kwargs()))
        model = MMBT(cfg)

        self.assertIsNotNone(model)
        # The model's config should be derived from MMBT.Config.
        self.assertEqual(model.config, cfg)
back to top