# Copyright (c) Facebook, Inc. and its affiliates.

import logging
import math
import os

import torch
from mmf.common.registry import registry
from mmf.models import BaseModel
from mmf.modules.embeddings import BertVisioLinguisticEmbeddings
from mmf.utils.configuration import get_mmf_cache_dir
from mmf.utils.file_io import PathManager
from mmf.utils.modeling import get_optimizer_parameters_for_bert
from mmf.utils.transform import (
    transform_to_batch_sequence,
    transform_to_batch_sequence_dim,
)
from omegaconf import OmegaConf
from torch import nn
from transformers.modeling_bert import (
    BertConfig,
    BertEncoder,
    BertLayer,
    BertModel,
    BertPooler,
    BertPredictionHeadTransform,
    BertPreTrainedModel,
)


logger = logging.getLogger(__name__)


# This model essentially wraps GraphNetworkModule and multi-modal models
@registry.register_model("krisp")
class KRISP(BaseModel):
    def __init__(self, config):
        super().__init__(config)
        self.build()

    @classmethod
    def config_path(cls):
        return "configs/models/krisp/defaults.yaml"

    # Each model needs to define a build method where the model's modules
    # are actually built and assigned to the model
    def build(self):
        # Get any cross-model info we need for building network
        # (like hidden sizes)
        extra_config = {}
        extra_config["vb_hid_sz"] = self.config.visual_bert.hidden_size
        extra_config["node_hid_dim"] = self.config.graph_module.node_hid_dim

        # Also pass arguments to know if it needs to feed in something
        extra_config["feed_vb_to_graph"] = self.config.feed_vb_to_graph
        extra_config["feed_q_to_graph"] = self.config.feed_q_to_graph
        extra_config["feed_mode"] = self.config.feed_mode
        extra_config["feed_graph_to_vb"] = self.config.feed_graph_to_vb
        extra_config["feed_special_node"] = self.config.feed_special_node
        extra_config["topk_ans_feed"] = self.config.topk_ans_feed
        extra_config["compress_crossmodel"] = self.config.compress_crossmodel
        extra_config["crossmodel_compress_dim"] = self.config.crossmodel_compress_dim
        extra_config["analysis_mode"] = self.config.analysis_mode
        extra_config["noback_vb"] = self.config.noback_vb_to_graph

        # If feed q, make the question module here
        if self.config.feed_q_to_graph:
            # We can just make it a BERT model really easily
            self.q_enc = BertModel.from_pretrained("bert-base-uncased")
            extra_config["q_hid_sz"] = self.q_enc.config.hidden_size

        # Import graph network module
        # Wrapped in try-except to avoid adding KRISP's dependencies to mmf itself
        try:
            from projects.krisp.graphnetwork_module import GraphNetworkModule
        except Exception:
            logger.error(
                "Import error with KRISP dependencies. Fix dependencies if "
                "you want to use KRISP"
            )
            raise
        # Builds the graph network module
        self.graph_module = GraphNetworkModule(self.config.graph_module, extra_config)

        # Make VisualBERT module (without the final hidden logit layer)
        self.vb_module = VisualBERTModule(self.config.visual_bert, extra_config)

        # Final hidden layer for the vb module
        self.vocab_fc = nn.Linear(
            self.vb_module.model.bert.config.hidden_size, self.config.num_labels
        )

        # Two independent choices: whether to use the bilinear (pointer) net
        # to compute graph logits, and whether to add or concat the outputs
        # If graph logits come from the pointer net, build the GraphPtrNet here
        if self.config.graph_logit_mode == "mc4":
            # Bilinear network
            self.graph_ptr_net = GraphPtrNet(
                self.vb_module.model.bert.config.hidden_size,
                self.config.graph_module.node_hid_dim,
            )
        elif self.config.graph_logit_mode == "in_graph":
            # Logits is already computed
            pass
        elif self.config.graph_logit_mode == "logit_fc":
            # Compute logits from single hidden layer
            self.graph_logit_fc = nn.Linear(
                self.config.graph_module.node_hid_dim, self.config.num_labels
            )

        # Answer indices not in the graph
        if self.config.output_combine == "add":
            # Use a boolean mask so the indexing in forward() masks rather
            # than gathers (a 0/1 LongTensor here would do integer indexing)
            self.missing_ans_inds = torch.ones(
                self.config.num_labels, dtype=torch.bool
            )
            # Any index still set to True is missing from the graph
            self.missing_ans_inds[self.graph_module.index_in_ans] = False
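            # Illustrative example (hypothetical sizes): with num_labels=5 and
            # index_in_ans=[0, 2, 3], missing_ans_inds == [False, True, False,
            # False, True]; in forward(), graph_logits columns 1 and 4 are
            # zeroed so only vb_logits contribute for answers not in the graph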

    # Each model in MMF gets a dict called sample_list which contains
    # all of the necessary inputs (question tokens, image features, etc.)
    # for the current batch
    def forward(self, sample_list):
        # If we have different combine modes, may need to call in different order
        if self.config.feed_graph_to_vb:
            # Can't be both (would create circular dep)
            assert not self.config.feed_vb_to_graph

            # Check mode
            # Can be feed_graph_hid_to_vb, where we pass in some vector
            # rep of graph into vb or feed_top_node_to_vb which is similar,
            # but it feeds in k node hidden states
            assert self.config.feed_mode in [
                "feed_graph_hid_to_vb",
                "feed_top_node_to_vb",
            ]
            if self.config.feed_mode == "feed_graph_hid_to_vb":
                assert self.graph_module.gn.output_special_node
            else:
                # feed_top_node_to_vb passes the assert above but is not
                # implemented on this code path
                raise NotImplementedError(
                    "Feed mode %s not implemented" % self.config.feed_mode
                )

            # Forward through graph module
            graph_output = self.graph_module(sample_list)

            # Put graph_output into sample_list
            sample_list["graph_output"] = graph_output

            # Forward through vb module
            vb_hidden = self.vb_module(sample_list)

            # Get vocab logit preds
            vb_logits = self.vocab_fc(vb_hidden)

        else:
            # Check mode
            if self.config.feed_vb_to_graph:
                # Can be feed_vb_hid_to_graph, where we feed the final
                # vb state into the graph as a node input,
                # or feed_vb_logit_to_graph, where we feed vb-predicted
                # logits into the graph
                assert self.config.feed_mode in [
                    "feed_vb_hid_to_graph",
                    "feed_vb_logit_to_graph",
                ]

            # Forward through vb module
            vb_hidden = self.vb_module(sample_list)

            # Get vocab logit preds
            vb_logits = self.vocab_fc(vb_hidden)

            sample_list["vb_hidden"] = vb_hidden
            sample_list["vb_logits"] = vb_logits

            # If we feed separate Q feats into the graph
            if self.config.feed_q_to_graph:
                # Now sample_list has all the processed inputs for us
                attention_mask_q = (sample_list["input_ids"] != 0).float()
                q_enc_out = self.q_enc(
                    input_ids=sample_list["input_ids"],
                    attention_mask=attention_mask_q,
                    token_type_ids=sample_list["token_type_ids"],
                )
                sample_list["q_encoded"] = q_enc_out[1]  # Get pooled output

            # Forward through graph module
            graph_output = self.graph_module(sample_list)

        # Compute graph logits
        if self.config.graph_logit_mode == "mc4":
            # Use bilinear network
            if self.config.noback_vb_to_blinear:
                graph_logits = self.graph_ptr_net(vb_hidden.detach(), graph_output)
            else:
                graph_logits = self.graph_ptr_net(vb_hidden, graph_output)

        elif self.config.graph_logit_mode == "in_graph":
            # Logits is already computed
            graph_logits = graph_output
            assert self.config.graph_module.output_type == "graph_prediction"
        elif self.config.graph_logit_mode == "logit_fc":
            # Compute logits from single hidden layer
            graph_logits = self.graph_logit_fc(graph_output)

        # Now combine outputs
        if self.config.output_combine == "concat":
            # Output order should be alphabetical
            assert self.config.graph_module.output_order == "alpha"

            # Combine both logits
            logits = torch.cat([vb_logits, graph_logits], dim=1)
        elif self.config.output_combine == "add":
            # Output order should be ans
            assert self.config.graph_module.output_order == "ans"

            # Set invalid inds to zero here
            assert graph_logits.size(1) == vb_logits.size(1)
            graph_logits[:, self.missing_ans_inds] = 0
            logits = vb_logits + graph_logits

        # If zerobias is set, shift all logits down by a constant offset
        if self.config.zerobias:
            logits -= 6.58

        # For loss calculations (automatically done by MMF
        # as per the loss defined in the config),
        # we need to return a dict with "scores" key as logits
        output = {"scores": logits}

        # If we're in eval / analysis mode, add more to output
        if self.config.analysis_mode:
            output = self.graph_module.add_analysis_to_output(output)

        # MMF will automatically calculate loss
        return output
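
# Rough wiring summary of the modes handled in KRISP.forward above:
#   feed_graph_to_vb: graph_module -> sample_list["graph_output"] -> vb_module
#   feed_vb_to_graph: vb_module -> sample_list["vb_hidden"/"vb_logits"] -> graph
#   feed_q_to_graph:  separate BERT question encoder -> sample_list["q_encoded"]
#   output_combine:   "concat" (graph outputs in alphabetical order) or "add"
#                     (answer order, with non-graph answers zeroed out)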


class GraphPtrNet(nn.Module):
    def __init__(self, hidden_size, graph_hidden_size):
        super().__init__()

        self.hidden_size = hidden_size
        self.graph_hidden_size = graph_hidden_size

        self.bl_w = nn.Linear(hidden_size, hidden_size)
        self.graph_w = nn.Linear(graph_hidden_size, hidden_size)

    def forward(self, bl_hidden, graph_hidden):
        # Compute Eq. 4 from Iterative Answer Prediction with
        # Pointer-Augmented Multimodal Transformers for TextVQA
        # bl_hidden is bs x hidden_size
        # graph_hidden is bs x num_nodes x graph_hidden_size

        # Compute BL half
        bl_hidden = self.bl_w(bl_hidden)
        assert bl_hidden.dim() == 2
        bl_hidden = bl_hidden.unsqueeze(1)

        # Compute graph hidden half
        # Assume we've already subsampled to only valid answer nodes
        graph_hidden = self.graph_w(graph_hidden)

        # Now we have bl_hidden as a bs x 1 x hid vec
        # graph_hidden as a bs x num_nodes x hid vec

        # Combine
        scores = torch.matmul(bl_hidden, graph_hidden.transpose(-1, -2))

        # Normalize
        scores = scores / math.sqrt(self.hidden_size)
        scores = scores.squeeze(1)

        # Scores is now a bs x #nodes matrix
        return scores
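
# Minimal shape sketch for GraphPtrNet (illustrative only; sizes hypothetical):
#
#     net = GraphPtrNet(hidden_size=768, graph_hidden_size=128)
#     bl_hidden = torch.randn(8, 768)          # bs x hidden_size
#     graph_hidden = torch.randn(8, 500, 128)  # bs x num_nodes x graph_hid_sz
#     scores = net(bl_hidden, graph_hidden)    # -> 8 x 500 logits over nodes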


class VisualBERTBase(BertPreTrainedModel):
    def __init__(
        self,
        config,
        visual_embedding_dim=512,
        embedding_strategy="plain",
        bypass_transformer=False,
        output_attentions=False,
        output_hidden_states=False,
    ):
        super().__init__(config)
        self.config = config

        config.visual_embedding_dim = visual_embedding_dim
        config.embedding_strategy = embedding_strategy
        config.bypass_transformer = bypass_transformer
        config.output_attentions = output_attentions
        config.output_hidden_states = output_hidden_states

        self.embeddings = BertVisioLinguisticEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.bypass_transformer = config.bypass_transformer

        if self.bypass_transformer:
            self.additional_layer = BertLayer(config)

        self.output_attentions = self.config.output_attentions
        self.output_hidden_states = self.config.output_hidden_states
        self.fixed_head_masks = [None for _ in range(len(self.encoder.layer))]
        self.init_weights()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        visual_embeddings=None,
        position_embeddings_visual=None,
        visual_embeddings_type=None,
        image_text_alignment=None,
        graph_input=None,
    ):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # This attention mask is simpler than the triangular masking of
        # causal attention used in OpenAI GPT; we just need to prepare the
        # broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype
        )  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
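        # e.g. an attention_mask row [1, 1, 0] becomes [0.0, 0.0, -10000.0],
        # leaving attended positions unchanged and driving masked positions
        # to near-zero probability after the softmax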

        assert position_embeddings_visual is None
        embedding_output = self.embeddings(
            input_ids,
            token_type_ids,
            visual_embeddings=visual_embeddings,
            visual_embeddings_type=visual_embeddings_type,
            image_text_alignment=image_text_alignment,
        )

        if self.bypass_transformer and visual_embeddings is not None:
            assert (
                not self.output_hidden_states
            )  # Don't support this for the bypass model
            text_length = input_ids.size(1)
            text_embedding_output = embedding_output[:, :text_length, :]
            visual_part = embedding_output[:, text_length:, :]

            text_extended_attention_mask = extended_attention_mask[
                :, :, :text_length, :text_length
            ]

            encoded_layers = self.encoder(
                text_embedding_output,
                text_extended_attention_mask,
                self.fixed_head_masks,
            )
            sequence_output = encoded_layers[0]
            new_input = torch.cat((sequence_output, visual_part), dim=1)
            final_sequence_output = self.additional_layer(
                new_input, extended_attention_mask
            )
            pooled_output = self.pooler(final_sequence_output)
            return final_sequence_output, pooled_output

        else:
            # If it takes graph input(s)
            # Do forward here through its own embedding
            # Then concat to embedding_output
            # And concat onto the extended_attention_mask to include this too
            if graph_input is not None:
                # Concat onto embeddings
                embedding_output = torch.cat([embedding_output, graph_input], dim=1)
                graph_att_mask = torch.zeros(
                    graph_input.size(0),
                    1,
                    1,
                    graph_input.size(1),
                    dtype=extended_attention_mask.dtype,
                    device=extended_attention_mask.device,
                )
                extended_attention_mask = torch.cat(
                    [extended_attention_mask, graph_att_mask], dim=3
                )

            encoded_layers = self.encoder(
                embedding_output, extended_attention_mask, self.fixed_head_masks
            )
            sequence_output = encoded_layers[0]
            pooled_output = self.pooler(sequence_output)
            attn_data_list = []

            if self.output_attentions:
                attn_data_list = encoded_layers[1:]

            return sequence_output, pooled_output, attn_data_list
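
# Shape note for the graph-token path above (hypothetical sizes): with
# embedding_output of bs x (text_len + num_boxes) x hid and graph_input of
# bs x k x hid, the concat yields bs x (text_len + num_boxes + k) x hid, and
# k zero entries are appended to the additive attention mask so the graph
# tokens are always attended to.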


class VisualBERTForClassification(nn.Module):
    def __init__(self, config, extra_config):
        super().__init__()
        self.config = config
        self.output_attentions = self.config.output_attentions
        self.output_hidden_states = self.config.output_hidden_states
        self.pooler_strategy = self.config.get("pooler_strategy", "default")

        # Graph input params
        self.feed_graph_to_vb = extra_config["feed_graph_to_vb"]
        self.graph_node_hid_dim = extra_config["node_hid_dim"]
        self.graph_feed_mode = extra_config["feed_mode"]
        self.graph_topk = extra_config["topk_ans_feed"]

        # If doing graph, make a graph embedding layer
        if self.feed_graph_to_vb:
            self.graph_embedding = nn.Sequential(
                nn.Linear(self.graph_node_hid_dim, config.hidden_size),
                nn.LayerNorm(config.hidden_size, eps=1e-12),
                nn.Dropout(config.hidden_dropout_prob),
            )

        # If bert_model_name is not specified, you will need to specify
        # all of the required parameters for BertConfig, and a pretrained
        # model won't be loaded
        self.bert_model_name = self.config.get("bert_model_name", None)
        self.bert_config = BertConfig.from_dict(
            OmegaConf.to_container(self.config, resolve=True)
        )
        if self.bert_model_name is None or self.bert_model_name == "nopretrain":
            self.bert = VisualBERTBase(
                self.bert_config,
                visual_embedding_dim=self.config.visual_embedding_dim,
                embedding_strategy=self.config.embedding_strategy,
                bypass_transformer=self.config.bypass_transformer,
                output_attentions=self.config.output_attentions,
                output_hidden_states=self.config.output_hidden_states,
            )
        else:
            self.bert = VisualBERTBase.from_pretrained(
                self.config.bert_model_name,
                config=self.bert_config,
                cache_dir=os.path.join(
                    get_mmf_cache_dir(), "distributed_{}".format(-1)
                ),
                visual_embedding_dim=self.config.visual_embedding_dim,
                embedding_strategy=self.config.embedding_strategy,
                bypass_transformer=self.config.bypass_transformer,
                output_attentions=self.config.output_attentions,
                output_hidden_states=self.config.output_hidden_states,
            )

        self.training_head_type = self.config.training_head_type
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        if self.config.training_head_type == "nlvr2":
            self.bert.config.hidden_size *= 2
        self.classifier = nn.Sequential(BertPredictionHeadTransform(self.bert.config))
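        # Note: unlike stock VisualBERT, the classifier here is only the
        # BertPredictionHeadTransform; the final projection to answer logits
        # lives in KRISP.vocab_fc (see "Final hidden layer" in KRISP.build)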

        self.init_weights()

    def init_weights(self):
        if self.config.random_initialize is False:
            if self.bert_model_name is None:
                # No pretrained model, init weights
                self.bert.init_weights()

            # Classifier needs to be initialized always as it is task specific
            self.classifier.apply(self.bert._init_weights)

    def forward(
        self,
        input_ids,
        input_mask,
        attention_mask=None,
        token_type_ids=None,
        visual_embeddings=None,
        position_embeddings_visual=None,
        visual_embeddings_type=None,
        image_text_alignment=None,
        masked_lm_labels=None,
        graph_input=None,
    ):
        # If we have a graph input, do the embedding first
        if self.feed_graph_to_vb:
            # Sanity check sizes
            if self.graph_feed_mode == "feed_graph_hid_to_vb":
                assert (
                    graph_input.dim() == 2
                    and graph_input.size(0) == input_ids.size(0)
                    and graph_input.size(1) == self.graph_node_hid_dim
                )
                graph_input = graph_input.unsqueeze(1)  # Add extra dim
            elif self.graph_feed_mode == "feed_top_node_to_vb":
                assert (
                    graph_input.dim() == 3
                    and graph_input.size(0) == input_ids.size(0)
                    and graph_input.size(1) == self.graph_topk
                    and graph_input.size(2) == self.graph_node_hid_dim
                )
            # Do the graph embedding
            graph_input = self.graph_embedding(graph_input)
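            # graph_input is now bs x k x hidden_size (k = 1 for
            # feed_graph_hid_to_vb, topk for feed_top_node_to_vb) and is
            # appended to the token sequence inside VisualBERTBase.forward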

        sequence_output, pooled_output, attention_weights = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            visual_embeddings,
            position_embeddings_visual,
            visual_embeddings_type,
            image_text_alignment,
            graph_input,
        )

        if self.training_head_type == "nlvr2":
            # 2B * H => B * 2H
            b, h = pooled_output.size()
            pooled_output = torch.cat(
                [pooled_output[: b // 2], pooled_output[b // 2 :]], dim=1
            )

        output_dict = {}
        if self.output_attentions:
            output_dict["attention_weights"] = attention_weights

        if self.output_hidden_states:
            output_dict["sequence_output"] = sequence_output
            output_dict["pooled_output"] = pooled_output

        if self.pooler_strategy == "vqa":
            # In the VQA2 pooling strategy, we use the representation of the
            # second-to-last token (the one just before the final [SEP])
            index_to_gather = input_mask.sum(1) - 2
            pooled_output = torch.gather(
                sequence_output,
                1,
                index_to_gather.unsqueeze(-1)
                .unsqueeze(-1)
                .expand(index_to_gather.size(0), 1, sequence_output.size(-1)),
            )
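            # e.g. input_mask.sum(1) == 20 -> index_to_gather == 18, so the
            # hidden state of token 18 replaces the default pooled output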

        pooled_output = self.dropout(pooled_output)
        output = self.classifier(pooled_output).squeeze(1)
        return output


class VisualBERTModule(nn.Module):
    def __init__(self, config, extra_config=None):
        super().__init__()
        self.config = config
        if extra_config is None:
            self.extra_config = {}
        else:
            self.extra_config = extra_config
        self.build()

    def build(self):
        assert self.config.training_head_type != "pretraining"
        self.model = VisualBERTForClassification(self.config, self.extra_config)

        if self.config.special_visual_initialize:
            self.model.bert.embeddings.initialize_visual_from_pretrained()

        # Initialize from pretrained model
        if self.config.load_from_pretrained:
            # Load the raw checkpoint
            pretrained_file = self.config.pretrained_file
            with PathManager.open(pretrained_file, "rb") as f:
                ckpt = torch.load(f, map_location=lambda storage, loc: storage)
            model_ckpt = ckpt["model"]

            # Remove "model" in fron of keys
            model_ckpt_new = {}
            for key in model_ckpt:
                if "bert" not in key:
                    continue
                model_ckpt_new[key.split("model.")[1]] = model_ckpt[key]
            model_ckpt = model_ckpt_new
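            # e.g. "model.bert.encoder.layer.0.attention.self.query.weight"
            #   -> "bert.encoder.layer.0.attention.self.query.weight"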

            # Load the checkpoint
            incompatible_keys = self.model.load_state_dict(model_ckpt, strict=False)

            # Print any missing / wrong keys for debug
            if len(incompatible_keys.missing_keys) != 0:
                logger.warning(
                    f"Missing keys {incompatible_keys.missing_keys} in the"
                    + " checkpoint.\n"
                    + "If this is not your checkpoint, please open up an "
                    + "issue on MMF GitHub. \n"
                    + f"Unexpected keys if any: {incompatible_keys.unexpected_keys}"
                )
            if len(incompatible_keys.unexpected_keys) != 0:
                logger.warning(
                    "Unexpected keys in state dict: "
                    + f"{incompatible_keys.unexpected_keys} \n"
                    + "This is usually not a problem with pretrained models, but "
                    + "if this is your own model, please double check. \n"
                    + "If you think this is an issue, please open up a "
                    + "bug at MMF GitHub."
                )

        if getattr(self.config, "freeze_base", False):
            for p in self.model.bert.parameters():
                p.requires_grad = False

        # Graph input params
        self.feed_graph_to_vb = self.extra_config["feed_graph_to_vb"]
        self.graph_node_hid_dim = self.extra_config["node_hid_dim"]
        self.graph_feed_mode = self.extra_config["feed_mode"]

        # compress_crossmodel is not implemented together with feed_graph_to_vb
        if self.feed_graph_to_vb and self.extra_config["compress_crossmodel"]:
            raise NotImplementedError(
                "compress_crossmodel is not supported when feeding graph to vb"
            )

    def flatten(self, sample_list, to_be_flattened=None, to_be_flattened_dim=None):
        if to_be_flattened is None:
            to_be_flattened = []
        if to_be_flattened_dim is None:
            to_be_flattened_dim = []
        for key in to_be_flattened:
            # Make sure these keys are present; otherwise set them to None
            sample_list[key] = getattr(sample_list, key, None)
            sample_list[key] = transform_to_batch_sequence(sample_list[key])
        for key in to_be_flattened_dim:
            sample_list[key] = getattr(sample_list, key, None)
            sample_list[key] = transform_to_batch_sequence_dim(sample_list[key])

        if sample_list.visual_embeddings_type is None:
            if sample_list.image_mask is not None:
                sample_list.visual_embeddings_type = torch.zeros_like(
                    sample_list.image_mask
                )

        if sample_list.image_mask is not None:
            attention_mask = torch.cat(
                (sample_list.input_mask, sample_list.image_mask), dim=-1
            )
            if sample_list.masked_lm_labels is not None:
                assert sample_list.masked_lm_labels.size(
                    -1
                ) == sample_list.input_mask.size(-1)
                new_lm_labels = torch.ones_like(attention_mask) * -1
                size_masked_lm_labels = sample_list.masked_lm_labels.size()
                assert len(size_masked_lm_labels) == 2
                new_lm_labels[
                    : size_masked_lm_labels[0], : size_masked_lm_labels[1]
                ] = sample_list.masked_lm_labels
                sample_list.masked_lm_labels = new_lm_labels
        else:
            attention_mask = sample_list.input_mask

        sample_list.attention_mask = attention_mask

        return sample_list

    def get_optimizer_parameters(self, config):
        return get_optimizer_parameters_for_bert(self.model, config)

    def flatten_for_bert(self, sample_list, **kwargs):
        to_be_flattened = [
            "input_ids",
            "token_type_ids",
            "input_mask",
            "image_mask",
            "masked_lm_labels",
            # "position_embeddings_visual",
            # "visual_embeddings_type",
        ]
        to_be_flattened_dim = ["visual_embeddings"]  # "image_text_alignment",

        # We want to convert everything into: batch x sequence_length x (dim).
        flattened = self.flatten(sample_list, to_be_flattened, to_be_flattened_dim)
        return flattened

    def update_sample_list_based_on_head(self, sample_list):
        bert_input_ids = sample_list.input_ids
        bert_input_mask = sample_list.input_mask
        bert_input_type_ids = sample_list.segment_ids

        if self.config.training_head_type == "nlvr2":
            bert_input_ids = torch.cat([bert_input_ids, bert_input_ids])
            bert_input_mask = torch.cat([bert_input_mask, bert_input_mask])
            bert_input_type_ids = torch.cat([bert_input_type_ids, bert_input_type_ids])

            # image input
            img0 = getattr(sample_list, "img0", {})
            image_info = getattr(img0, "image_info_0", {})
            image_dim_variable_0 = getattr(image_info, "max_features", None)
            image_feat_variable_0 = getattr(img0, "image_feature_0", None)

            img1 = getattr(sample_list, "img1", {})
            image_info = getattr(img1, "image_info_0", {})
            image_dim_variable_1 = getattr(image_info, "max_features", None)
            image_feat_variable_1 = getattr(img1, "image_feature_0", None)

            image_feat_variable = torch.cat(
                [image_feat_variable_0, image_feat_variable_1]
            )
            image_dim_variable = torch.cat([image_dim_variable_0, image_dim_variable_1])
        else:
            image_info = getattr(sample_list, "image_info_0", {})
            image_dim_variable = getattr(image_info, "max_features", None)
            image_feat_variable = getattr(sample_list, "image_feature_0", None)

        sample_list.visual_embeddings = image_feat_variable
        sample_list.image_dim = image_dim_variable
        sample_list.input_ids = bert_input_ids
        sample_list.input_mask = bert_input_mask
        sample_list.token_type_ids = bert_input_type_ids
        return sample_list

    def add_custom_params(self, sample_list):
        visual_embeddings = getattr(sample_list, "visual_embeddings", None)
        image_dim = getattr(sample_list, "image_dim", None)
        # pretraining labels
        sample_list.masked_lm_labels = getattr(sample_list, "lm_label_ids", None)
        # image_feat_variable = batch x ( num_choice x ) image_feature_length x dim
        # Prepare Mask
        if visual_embeddings is not None and image_dim is not None:
            image_mask = torch.arange(
                visual_embeddings.size(-2), device=visual_embeddings.device
            ).expand(*visual_embeddings.size()[:-1])
            if len(image_dim.size()) < len(image_mask.size()):
                image_dim = image_dim.unsqueeze(-1)
                assert len(image_dim.size()) == len(image_mask.size())
            image_mask = image_mask < image_dim
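            # e.g. 100 feature slots with image_dim == 37 gives a mask of 37
            # ones followed by 63 zeros, so padded boxes are ignored downstream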
            sample_list.image_mask = image_mask.long()
        else:
            sample_list.image_mask = None

        sample_list.position_embeddings_visual = None
        sample_list.visual_embeddings_type = None
        sample_list.image_text_alignment = None
        return sample_list

    # Backward compatibility for code from original VisualBERT
    @classmethod
    def format_state_key(cls, key):
        return (
            key.replace("bert.bert", "model.bert")
            .replace("bert.cls", "model.cls")
            .replace("bert.classifier", "model.classifier")
        )
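    # e.g. "bert.bert.encoder.layer.0..." -> "model.bert.encoder.layer.0..."
    #      "bert.classifier.weight"       -> "model.classifier.weight"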

    def forward(self, sample_list):
        sample_list = self.update_sample_list_based_on_head(sample_list)
        sample_list = self.add_custom_params(sample_list)
        sample_list = self.flatten_for_bert(sample_list)

        if self.feed_graph_to_vb:
            if self.graph_feed_mode == "feed_graph_hid_to_vb":
                assert "graph_special_node_out" in sample_list
                graph_input = sample_list["graph_special_node_out"]
            else:
                # Only feed_graph_hid_to_vb is implemented on this path
                raise NotImplementedError(
                    "Feed mode %s not implemented" % self.graph_feed_mode
                )
        else:
            graph_input = None

        output = self.model(
            sample_list.input_ids,
            sample_list.input_mask,
            sample_list.attention_mask,
            sample_list.token_type_ids,
            sample_list.visual_embeddings,
            sample_list.position_embeddings_visual,
            sample_list.visual_embeddings_type,
            sample_list.image_text_alignment,
            sample_list.masked_lm_labels,
            graph_input,
        )

        return output