Revision b752db6fd8fb9b68003effd545023cb0dd0a9fde authored by Vedanuj Goswami on 19 March 2021, 23:50:52 UTC, committed by Facebook GitHub Bot on 19 March 2021, 23:52:20 UTC
Summary:
Pull Request resolved: https://github.com/facebookresearch/mmf/pull/816

Adding a check to final validation to run only if val is in run_type

Reviewed By: ytsheng, apsdehal

Differential Revision: D27130474

fbshipit-source-id: 2793a1dd3f77d4d56d737781a68973634d271157
1 parent 62a50f1
Raw File
visdial_multi_modal.py
# Copyright (c) Facebook, Inc. and its affiliates.

from mmf.models.pythia import Pythia
from mmf.modules.decoders import VisDialDiscriminator


class VisDialMultiModalModel(Pythia):
    def __init__(self, config):
        super().__init__(config)

    def build(self):
        self._init_text_embedding()
        self._init_image_encoders()
        self._init_image_embeddings()
        self._init_combine_layer()
        self._init_decoder()
        self._init_extras()

    def _init_text_embedding(self):
        parent = super()
        parent._init_text_embedding("text_embeddings", False)
        parent._init_text_embedding("history_embeddings", True)

    def get_optimizer_parameters(self, config):
        # TODO: Update after implementing decoder
        params = [
            {"params": self.img_embeddings_list.parameters()},
            {"params": self.text_embeddings.parameters()},
            {"params": self.multi_modal_combine_layer.parameters()},
            {"params": self.decoder.projection_layer.parameters()},
            {
                "params": self.img_feat_encoders.parameters(),
                "lr": (config.optimizer.params.lr * 0.1),
            },
        ]

        return params

    def _update_text_embedding_args(self, args):
        parent = super()
        parent._update_text_embedding_args(args)
        # Add embedding vectors to args
        args.embedding_vectors = self.config.embedding_vectors

    def _init_decoder(self):
        embedding = self.text_embeddings[0].module
        embedding_dim = self.text_embeddings[0].embedding_dim
        hidden_dim = self.multi_modal_combine_layer.out_dim

        self.decoder = VisDialDiscriminator(
            {"embedding_dim": embedding_dim, "hidden_dim": hidden_dim}, embedding
        )

    def combine_embeddings(self, *args):
        return self.multi_modal_combine_layer(*args)

    def calculate_logits(self, joint_embedding, **kwargs):
        return self.decoder(joint_embedding, kwargs)

    def forward(
        self, texts, answer_options, histories, image_features, image_dims, **kwargs
    ):

        texts = texts.view(-1, texts.size(2))
        histories = histories.view(-1, histories.size(2))
        text_embedding_total = self.process_text_embedding(texts)
        histories_total = self.process_text_embedding(histories, "history_embeddings")

        for idx, image_feature in enumerate(image_features):
            feature_size = image_feature.size()[2:]
            image_features[idx] = image_feature.view(-1, *feature_size)

        size = image_dims.size()[2:]
        image_dims = image_dims.view(-1, *size)

        assert len(image_features) == len(
            self.img_feat_encoders
        ), "number of image feature model doesnot equal \
                 to number of image features"

        image_embedding_total = self.process_image_embedding(
            image_features, image_dims, text_embedding_total
        )

        if self.inter_model is not None:
            image_embedding_total = self.inter_model(image_embedding_total)

        joint_embedding = self.combine_embeddings(
            image_embedding_total, text_embedding_total, histories_total
        )

        decoder_info = {
            "answer_options": answer_options,
            "answer_options_len": kwargs["answer_options_len"],
        }
        return self.calculate_logits(joint_embedding, **decoder_info)
back to top