Revision b752db6fd8fb9b68003effd545023cb0dd0a9fde authored by Vedanuj Goswami on 19 March 2021, 23:50:52 UTC, committed by Facebook GitHub Bot on 19 March 2021, 23:52:20 UTC
Summary:
Pull Request resolved: https://github.com/facebookresearch/mmf/pull/816

Add a check so that the final validation runs only if val is in run_type (see the sketch below).

Reviewed By: ytsheng, apsdehal

Differential Revision: D27130474

fbshipit-source-id: 2793a1dd3f77d4d56d737781a68973634d271157
1 parent 62a50f1
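The check itself lives in the MMF trainer rather than in the file shown below. A minimal sketch of the kind of guard the summary describes, using hypothetical names (maybe_run_final_validation and trainer.validation_loop are illustrative, not the actual MMF API):

def maybe_run_final_validation(trainer, run_type: str) -> None:
    # Hypothetical helper: only run the post-training validation pass when
    # "val" is part of the configured run_type (e.g. "train_val").
    if "val" in run_type:
        trainer.validation_loop()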
mmf_bert.py
# Copyright (c) Facebook, Inc. and its affiliates.

import torch
from mmf.common.registry import registry
from mmf.models.pythia import Pythia
from mmf.modules.embeddings import ProjectionEmbedding
from mmf.utils.transform import transform_to_batch_sequence
from torch import nn
from transformers.modeling_bert import (
    BertConfig,
    BertEmbeddings,
    BertForPreTraining,
    BertPooler,
    BertPredictionHeadTransform,
    BertPreTrainingHeads,
)


@registry.register_model("mmf_bert")
class MMFBert(Pythia):
    def __init__(self, config):
        super().__init__(config)

    @classmethod
    def config_path(cls):
        return "configs/models/mmf_bert/defaults.yaml"

    def build(self):
        super().build()
        self.tie_weights()

        if getattr(self.config, "freeze_base", False):
            for n, p in self.named_parameters():
                if "classifier" not in n:
                    p.requires_grad = False

    def _build_word_embedding(self):
        self.bert_config = BertConfig.from_pretrained(self.config.bert_model_name)
        if self.config.pretrained_bert:
            bert_model = BertForPreTraining.from_pretrained(self.config.bert_model_name)
            self.word_embedding = bert_model.bert.embeddings
            self.pooler = bert_model.bert.pooler
            self.pooler.apply(self.init_weights)

        else:
            self.pooler = BertPooler(self.bert_config)
            self.word_embedding = BertEmbeddings(self.bert_config)

    def _init_classifier(self, hidden_size):
        if "pretraining" in self.config.training_head_type:
            self.classifier = BertPreTrainingHeads(self.bert_config)
        if "vqa" in self.config.training_head_type:
            self.dropout = nn.Dropout(self.bert_config.hidden_dropout_prob)
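            # 3129 matches the standard VQA v2 answer vocabulary size used by
            # Pythia-style models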
            self.answer_space_size = 3129
            self.classifier = nn.Sequential(
                BertPredictionHeadTransform(self.bert_config),
                nn.Linear(self.bert_config.hidden_size, self.answer_space_size),
            )
            # self.classifier = nn.Linear(self.bert_config.hidden_size,
            # self.answer_space_size)
        elif "vizwiz" in self.config.training_head_type:
            self.dropout = nn.Dropout(self.bert_config.hidden_dropout_prob)
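            # 7371 is the answer vocabulary size used for the VizWiz head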
            self.answer_space_size = 7371
            self.classifier = nn.Sequential(
                BertPredictionHeadTransform(self.bert_config),
                nn.Linear(self.bert_config.hidden_size, self.answer_space_size),
            )
            # self.classifier = nn.Linear(self.bert_config.hidden_size,
            # self.answer_space_size)
        elif self.config.training_head_type == "visual_entailment":
            self.dropout = nn.Dropout(self.bert_config.hidden_dropout_prob)
            self.classifier = nn.Sequential(
                BertPredictionHeadTransform(self.bert_config),
                nn.Linear(self.bert_config.hidden_size, 3),
            )
            # self.classifier = nn.Linear(self.bert_config.hidden_size, 3)

    def _init_text_embeddings(self, attr="text"):
        self.text_embeddings_out_dim = self.bert_config.hidden_size
        self.text_embedding = nn.MultiheadAttention(**self.config.text_embeddings[0])

    def _tie_or_clone_weights(self, first_module, second_module):
        """Tie or clone module weights depending of weither we are using
        TorchScript or not
        """
        if self.config.torchscript:
            first_module.weight = nn.Parameter(second_module.weight.clone())
        else:
            first_module.weight = second_module.weight

    def tie_weights(self):
        """Make sure we are sharing the input and output embeddings.
        Export to TorchScript can't handle parameter sharing so we are cloning
        them instead.
        """
        if hasattr(self, "cls"):
            self._tie_or_clone_weights(
                self.cls.predictions.decoder, self.word_embedding.word_embeddings
            )

    def _init_feature_embeddings(self, attr):
        feature_embeddings_list = []
        num_feature_feat = len(getattr(self.config, f"{attr}_feature_encodings"))

        self.image_feature_projection = ProjectionEmbedding(
            **self.config.image_feature_projection
        )
        self.feature_embeddings_out_dim = 0

        if self.config.image_intra_attention:
            self.image_feature_intra_attention = nn.MultiheadAttention(
                **self.config.image_feature_attentions[0]
            )

        for _ in range(num_feature_feat):
            feature_embeddings = []
            feature_attn_model_list = self.config[attr + "_feature_embeddings"]

            for feature_attn_model_params in feature_attn_model_list:
                feature_embedding = nn.MultiheadAttention(**feature_attn_model_params)
                feature_embeddings.append(feature_embedding)
                self.feature_embeddings_out_dim += feature_attn_model_params[
                    "embed_dim"
                ]

            feature_embeddings = nn.ModuleList(feature_embeddings)
            feature_embeddings_list.append(feature_embeddings)

        setattr(
            self, attr + "_feature_embeddings_out_dim", self.feature_embeddings_out_dim
        )
        del self.feature_embeddings_out_dim
        setattr(
            self,
            attr + "_feature_embeddings_list",
            nn.ModuleList(feature_embeddings_list),
        )

    def get_optimizer_parameters(self, config):
        param_optimizer = list(self.named_parameters())
        image_feature_encoders_params = [
            n for n in param_optimizer if "image_feature_encoders" in n[0]
        ]
        param_optimizer = [
            n for n in param_optimizer if "image_feature_encoders" not in n[0]
        ]

        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.01,
            },
            {
                "params": [
                    p for n, p in param_optimizer if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
            {
                "params": [p for _, p in image_feature_encoders_params],
                "lr": (config.optimizer.params.lr * 0.1),
                "weight_decay": 0.01,
            },
        ]

        return optimizer_grouped_parameters

    # WARNING(ASG): This doesn't have finetune_lr_multiplier option enabled yet
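    # Example (illustrative): the parameter groups returned above can be fed
    # directly to a torch optimizer, for instance:
    #   optimizer = torch.optim.AdamW(
    #       model.get_optimizer_parameters(config), lr=config.optimizer.params.lr
    #   )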

    def process_text_embedding(self, text_embedding, key_padding_mask=None):
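        # nn.MultiheadAttention (batch_first=False by default) expects inputs of
        # shape (seq_len, batch, embed_dim), hence the transposes around it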
        text_embedding = text_embedding.transpose(0, 1)
        embedding, _ = self.text_embedding(
            text_embedding,
            text_embedding,
            text_embedding,
            key_padding_mask=key_padding_mask,
        )

        return embedding.transpose(0, 1)

    def process_feature_embedding(
        self,
        attr,
        sample_list,
        text_embedding_total,
        key_padding_mask=None,
        attn_mask=None,
        extra=None,
        batch_size_t=None,
    ):
        if extra is None:
            extra = []
        feature_embeddings = []
        feature_attentions = []
        features = []
        batch_size_t = (
            sample_list.get_batch_size() if batch_size_t is None else batch_size_t
        )

        # Convert list of keys to the actual values
        extra = sample_list.get_fields(extra)

        feature_idx = 0

        # Get all of the features, which have keys of the form "image_feature_0",
        # "image_feature_1", ...
        while True:
            feature = getattr(sample_list, f"{attr}_feature_{feature_idx:d}", None)
            if feature is None:
                break
            feature_idx += 1
            feature = feature[:batch_size_t]
            features.append(feature)

        feature_encoders = getattr(self, attr + "_feature_encoders")
        # Each feature should have a separate image feature encoder
        assert len(features) == len(feature_encoders), (
            "Number of feature encoders, {}, is not equal "
            "to number of features, {}.".format(len(feature_encoders), len(features))
        )

        # Now, iterate to get final attended image features
        for i, feature in enumerate(features):
            # Get info related to the current feature. The info is generally
            # stored under a key of the form "image_info_0" for the 0th feature
            feature_info = getattr(sample_list, f"{attr}_info_{i:d}", {})
            # For Pythia, we need max_features to mask attention
            feature_dim = getattr(feature_info, "max_features", None)
            if feature_dim is not None:
                feature_dim = feature_dim[:batch_size_t]

            # Attribute in which the encoders are saved; for "image" it
            # will be "image_feature_encoders", another example is
            # "context_feature_encoders"
            encoders_attr = attr + "_feature_encoders"
            feature_encoder = getattr(self, encoders_attr)[i]

            # Encode the features
            encoded_feature = feature_encoder(feature)

            # Get all of the feature embeddings
            list_attr = attr + "_feature_embeddings_list"
            feature_embedding_models = getattr(self, list_attr)[i]
            encoded_feature = self.image_feature_projection(encoded_feature)
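            # nn.MultiheadAttention expects (seq_len, batch, embed_dim) inputs,
            # so transpose both the image and text streams before attending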
            encoded_feature = encoded_feature.transpose(0, 1)
            text_embedding_total = text_embedding_total.transpose(0, 1)

            if self.config.image_intra_attention:
                encoded_feature, _ = self.image_feature_intra_attention(
                    encoded_feature,
                    encoded_feature,
                    encoded_feature,
                    key_padding_mask=key_padding_mask,
                    attn_mask=attn_mask,
                )
            # Forward through these embeddings one by one
            for feature_embedding_model in feature_embedding_models:
                inp = (text_embedding_total, encoded_feature, encoded_feature)
                embedding, attention = feature_embedding_model(
                    *inp, key_padding_mask=key_padding_mask, attn_mask=attn_mask
                )
                feature_embeddings.append(embedding.transpose(0, 1))
                feature_attentions.append(attention.squeeze(-1))

        # Concatenate all features embeddings and return along with attention
        feature_embedding_total = torch.cat(feature_embeddings, dim=1)
        return feature_embedding_total, feature_attentions

    def init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal
            # for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, sample_list):
        # bert text input
        input_ids = sample_list.input_ids
        input_mask = sample_list.input_mask
        input_type_ids = sample_list.segment_ids
        input_ids = transform_to_batch_sequence(input_ids)
        input_type_ids = transform_to_batch_sequence(input_type_ids)
        input_mask = transform_to_batch_sequence(input_mask)

        if input_mask is None:
            input_mask = torch.ones_like(input_ids)
        if input_type_ids is None:
            input_type_ids = torch.zeros_like(input_ids)
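        # Broadcast the 2D mask to (batch, 1, 1, seq_len) so it can be applied
        # across attention heads and query positions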
        attention_mask = input_mask.unsqueeze(1).unsqueeze(2)

        # pretraining labels
        masked_lm_labels = getattr(sample_list, "lm_label_ids", None)
        masked_lm_labels = transform_to_batch_sequence(masked_lm_labels)
        # pretraining labels
        # is_random_next = getattr(sample_list, "is_correct", None)
        # TODO(aps): Fix later on dataset side
        is_random_next = None

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        # fp16 compatibility
        attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)
        attention_mask = (1.0 - attention_mask) * -10000.0

        text_embedding = self.word_embedding(input_ids, input_type_ids)
        text_embedding_total = self.process_text_embedding(
            text_embedding, input_mask == 0
        )

        image_embedding_total, _ = self.process_feature_embedding(
            "image", sample_list, text_embedding_total
        )

        if self.inter_model is not None:
            image_embedding_total = self.inter_model(image_embedding_total)

        # image_embedding_total = image_embedding_total *
        # input_mask.unsqueeze(-1).float()
        # text_embedding_total = text_embedding_total *
        # input_mask.unsqueeze(-1).float()

        if self.config.combine_embeddings:
            joint_embedding = self.combine_embeddings(
                ["image", "text"], [image_embedding_total, text_embedding_total]
            )
        else:
            joint_embedding = image_embedding_total

        output_dict = {}

        pooled_output = self.pooler(joint_embedding)

        if "pretraining" in self.config.training_head_type:
            prediction_scores, seq_relationship_score = self.classifier(
                joint_embedding, pooled_output
            )
            output_dict["logits"] = prediction_scores

            if masked_lm_labels is not None:
                loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
                masked_lm_loss = loss_fct(
                    prediction_scores.contiguous().view(
                        -1, self.bert_config.vocab_size
                    ),
                    masked_lm_labels.contiguous().view(-1),
                )
                # print(seq_relationship_score.argmax(dim=1), is_random_next)
                loss_key = "{}/{}".format(
                    sample_list.dataset_name, sample_list.dataset_type
                )

                output_dict["losses"] = {}
                output_dict["losses"][loss_key + "/masked_lm_loss"] = masked_lm_loss

                if is_random_next is not None:
                    output_dict["seq_relationship_score"] = seq_relationship_score

                    next_sentence_loss = loss_fct(
                        seq_relationship_score.contiguous().view(-1, 2),
                        is_random_next.contiguous().view(-1),
                    )
                    output_dict["losses"][
                        loss_key + "/next_sentence_loss"
                    ] = next_sentence_loss
            return output_dict
        elif (
            "vqa" in self.config.training_head_type
            or self.config.training_head_type == "vizwiz"
        ):
            index_to_gather = input_mask.sum(1) - 2
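            # input_mask.sum(1) - 1 is the final [SEP] position, so this gathers
            # the hidden state of the last content token before [SEP]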

            pooled_output = torch.gather(
                joint_embedding,
                1,
                index_to_gather.unsqueeze(-1)
                .unsqueeze(-1)
                .expand(index_to_gather.size(0), 1, joint_embedding.size(-1)),
            )

            pooled_output = self.dropout(pooled_output)
            logits = self.classifier(pooled_output)
            reshaped_logits = logits.contiguous().view(-1, self.answer_space_size)

            output_dict["scores"] = reshaped_logits
            return output_dict
        elif (
            self.config.training_head_type == "nlvr2"
            or self.config.training_head_type == "visual_entailment"
        ):
            pooled_output = self.dropout(pooled_output)
            logits = self.classifier(pooled_output)
            output_dict["scores"] = logits
            return output_dict

        return output_dict
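
# Usage sketch (illustrative; assumes an MMF config for this model has been
# loaded, e.g. from configs/models/mmf_bert/defaults.yaml):
#   model_cls = registry.get_model_class("mmf_bert")
#   model = model_cls(config)
#   model.build()
#   output = model(sample_list)  # sample_list provides input_ids, input_mask,
#                                # segment_ids and image_feature_{i} fields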