# Source: https://github.com/facebookresearch/pythia
# Tip revision: 2192818943499b83c2b0702a01384ba5fe3b5f61, authored by sash on
# 30 June 2021, 18:34:36 UTC — "[fix]: clean up lightning test code"
# File: mmbt.py
# Copyright (c) Facebook, Inc. and its affiliates.
# MMBTModel, ModalEmbeddings is copied from [1]
# as we have internal dependency on transformers v2.3.
# These will be removed when we upgrade to package v2.5+.
# [1]: https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_mmbt.py # noqa
import os
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, Optional, Union
import torch
from mmf.common.registry import registry
from mmf.models.base_model import BaseModel
from mmf.models.interfaces.mmbt import MMBTGridHMInterface
from mmf.modules.encoders import (
EncoderFactory,
ImageEncoderFactory,
ImageEncoderTypes,
MultiModalEncoderBase,
ResNet152ImageEncoder,
TextEncoderFactory,
TextEncoderTypes,
TransformerEncoder,
)
from mmf.modules.hf_layers import replace_with_jit
from mmf.utils.checkpoint import load_pretrained_model
from mmf.utils.configuration import get_mmf_cache_dir
from mmf.utils.modeling import get_optimizer_parameters_for_bert
from omegaconf import II, DictConfig, OmegaConf
from torch import Tensor, nn
from transformers.modeling_bert import BertForPreTraining, BertPredictionHeadTransform
# TODO: Remove after transformers package upgrade to 2.5
class MMBTConfig:
    """Configuration wrapper for an `MMBT Model`.

    Aliases the underlying transformer config's attribute dictionary so all
    of its values are visible through this object, then layers the
    MMBT-specific options on top.

    Args:
        config (:obj:`~transformers.PreTrainedConfig`):
            Config of the underlying Transformer model. Its attribute dict is
            shared (not copied), so attributes set here also appear on
            ``config``.
        num_labels (:obj:`int` or :obj:`None`, optional, defaults to `None`):
            Size of the final Linear layer for classification.
        modal_hidden_size (:obj:`int`, optional, defaults to 2048):
            Embedding dimension of the non-text modality encoder.
    """

    def __init__(self, config, num_labels=None, modal_hidden_size=2048):
        # Share (do not copy) the transformer config's attribute dict so that
        # every transformer setting is readable from this object as well.
        self.__dict__ = config.__dict__
        # Only record num_labels when a truthy value was provided; the
        # classification head is optional.
        if num_labels:
            self.num_labels = num_labels
        self.modal_hidden_size = modal_hidden_size
# TODO: Remove after transformers package upgrade to 2.5
class ModalEmbeddings(nn.Module):
    """Generic Modal Embeddings which takes in an encoder, and a transformer embedding.

    The encoder output is projected into the transformer's hidden size and
    combined with the transformer's own position / token-type embedding
    tables, so modal tokens live in the same embedding space as text tokens.
    """

    def __init__(self, config, encoder, embeddings):
        super().__init__()
        self.config = config
        # Modality encoder; its output's last dim is assumed to be
        # config.modal_hidden_size (consumed by the projection below).
        self.encoder = encoder
        # Project encoder output (modal_hidden_size) -> transformer hidden_size.
        self.proj_embeddings = nn.Linear(config.modal_hidden_size, config.hidden_size)
        # Reuse the text transformer's embedding tables and LayerNorm so the
        # modal embeddings share the transformer's embedding space.
        self.position_embeddings = embeddings.position_embeddings
        self.token_type_embeddings = embeddings.token_type_embeddings
        self.word_embeddings = embeddings.word_embeddings
        self.LayerNorm = embeddings.LayerNorm
        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)

    def forward(
        self,
        input_modal: Tensor,
        start_token: Optional[Tensor] = None,
        end_token: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
    ):
        """Embed the modality input, optionally bracketed by start/end tokens.

        Args:
            input_modal: modality input fed to the encoder; batch dim first
                (encoder output is assumed (batch, seq, modal_hidden_size) —
                seq_length is read from dim 1).
            start_token: optional token id(s) to prepend (embedded via the
                word-embedding table).
            end_token: optional token id(s) to append.
            position_ids: optional explicit position ids; defaults to
                0..seq_length-1 per batch row.
            token_type_ids: optional token-type ids; defaults to zeros.

        Returns:
            Normalized, dropout-regularized embeddings of shape
            (batch, seq_length, hidden_size).
        """
        # Encode the modality and project into the transformer's hidden size.
        token_embeddings = self.proj_embeddings(self.encoder(input_modal))
        seq_length = token_embeddings.size(1)
        # Optionally prepend a start-token embedding (grows seq_length by 1).
        if start_token is not None:
            start_token_embeds = self.word_embeddings(start_token)
            seq_length += 1
            token_embeddings = torch.cat(
                [start_token_embeds.unsqueeze(1), token_embeddings], dim=1
            )
        # Optionally append an end-token embedding (grows seq_length by 1).
        if end_token is not None:
            end_token_embeds = self.word_embeddings(end_token)
            seq_length += 1
            token_embeddings = torch.cat(
                [token_embeddings, end_token_embeds.unsqueeze(1)], dim=1
            )
        # Default position ids: 0..seq_length-1, broadcast across the batch.
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_modal.device
            )
            position_ids = position_ids.unsqueeze(0).expand(
                input_modal.size(0), seq_length
            )
        # Default token type: segment 0 for the whole modal sequence.
        if token_type_ids is None:
            token_type_ids = torch.zeros(
                (input_modal.size(0), seq_length),
                dtype=torch.long,
                device=input_modal.device,
            )
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        # Standard BERT-style sum of the three embedding components.
        embeddings = token_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
# TODO: Remove after transformers package upgrade to 2.5
class MMBTModel(nn.Module):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration
    (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape
            ``(batch_size, sequence_length, hidden_size)``. Sequence of
            hidden-states at the output of the last layer of the model.
        **pooler_output**: ``torch.FloatTensor`` of shape
            ``(batch_size, hidden_size)``. Last layer hidden-state of the
            first token of the sequence (classification token) further processed
            by a Linear layer and a Tanh activation function. The Linear
            layer weights are trained from the next sentence prediction
            (classification) objective during Bert pretraining. This output
            is usually *not* a good summary of the semantic content of the
            input, you're often better with averaging or pooling
            the sequence of hidden-states for the whole input sequence.
        **hidden_states**: (`optional`, returned when
            ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer +
            the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the
            initial embedding outputs.
        **attentions**: (`optional`, returned when
            ``config.output_attentions=True``) list of ``torch.FloatTensor``
            (one for each layer) of shape ``(batch_size, num_heads,
            sequence_length, sequence_length)``: Attentions weights after
            the attention softmax, used to compute the weighted average in the
            self-attention heads.

    Examples::

        # For example purposes. Not runnable.
        transformer = BertModel.from_pretrained('bert-base-uncased')
        encoder = ImageEncoder(args)
        mmbt = MMBTModel(config, transformer, encoder)
    """

    def __init__(self, config, transformer, encoder):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.num_hidden_layers = config.num_hidden_layers
        self.transformer = transformer
        # Wraps the modal encoder together with the transformer's embedding
        # tables so both modalities share one embedding space.
        self.modal_encoder = ModalEmbeddings(config, encoder, transformer.embeddings)

    def forward(
        self,
        input_modal: Tensor,
        input_ids: Tensor,
        modal_start_tokens: Optional[Tensor] = None,
        modal_end_tokens: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        modal_token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        modal_position_ids: Optional[Tensor] = None,
        head_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        encoder_hidden_states: Optional[Tensor] = None,
        encoder_attention_mask: Optional[Tensor] = None,
    ):
        """Embed modal + text inputs, concatenate them (modal first), build
        the extended attention masks and run the shared transformer encoder.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        Returns ``(sequence_output, pooled_output, encoder_outputs[1:])``.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_txt_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_txt_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = inputs_embeds.device if inputs_embeds is not None else input_ids.device

        modal_embeddings = self.modal_encoder(
            input_modal,
            start_token=modal_start_tokens,
            end_token=modal_end_tokens,
            position_ids=modal_position_ids,
            token_type_ids=modal_token_type_ids,
        )

        input_modal_shape = modal_embeddings.size()[:-1]

        # Text defaults to token-type 1 (modal part uses type 0 by default).
        if token_type_ids is None:
            token_type_ids = torch.ones(
                input_txt_shape, dtype=torch.long, device=device
            )

        txt_embeddings = self.transformer.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        # Modal tokens precede text tokens in the fused sequence.
        embedding_output = torch.cat([modal_embeddings, txt_embeddings], 1)

        input_shape = embedding_output.size()[:-1]

        # Extend the text attention mask with all-ones for the modal prefix.
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        else:
            attention_mask = torch.cat(
                [
                    torch.ones(input_modal_shape, device=device, dtype=torch.long),
                    attention_mask,
                ],
                dim=1,
            )

        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(input_shape, device=device)
        else:
            encoder_attention_mask = torch.cat(
                [torch.ones(input_modal_shape, device=device), encoder_attention_mask],
                dim=1,
            )

        # We can provide a self-attention mask of dimensions
        # [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            attention_mask = attention_mask[:, None, :, :]

        # Provided a padding mask of dimensions [batch_size, seq_length]
        # - if the model is a decoder, apply a causal mask in addition to the
        #   padding mask
        # - if the model is an encoder, make the mask broadcastable to
        #   [batch_size, num_heads, seq_length, seq_length]
        if attention_mask.dim() == 2:
            if self.is_decoder:
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = (
                    seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
                    <= seq_ids[None, :, None]
                )
                attention_mask = (
                    causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
                )
            else:
                attention_mask = attention_mask[:, None, None, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        # Python builtin next is currently not supported in Torchscript
        if not torch.jit.is_scripting():
            attention_mask = attention_mask.to(
                dtype=next(self.parameters()).dtype
            )  # fp16 compatibility
        attention_mask = (1.0 - attention_mask) * -10000.0

        # If a 2D ou 3D attention mask is provided for the cross-attention
        # we need to make broadcastabe to
        # [batch_size, num_heads, seq_length, seq_length]
        if encoder_attention_mask.dim() == 3:
            encoder_attention_mask = encoder_attention_mask[:, None, :, :]
        if encoder_attention_mask.dim() == 2:
            encoder_attention_mask = encoder_attention_mask[:, None, None, :]

        # Python builtin next is currently not supported in Torchscript
        if not torch.jit.is_scripting():
            encoder_attention_mask = encoder_attention_mask.to(
                dtype=next(self.parameters()).dtype
            )  # fp16 compatibility
        encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0

        encoder_outputs = self.transformer.encoder(
            embedding_output,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.transformer.pooler(sequence_output)

        outputs = (
            sequence_output,
            pooled_output,
            encoder_outputs[1:],
        )  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)

    def get_input_embeddings(self):
        # Fix: the word embeddings live on the wrapped transformer. This class
        # never defines `self.embeddings`, so the previous
        # `self.embeddings.word_embeddings` access raised AttributeError.
        return self.transformer.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # Fix: mirror get_input_embeddings — write through the transformer.
        self.transformer.embeddings.word_embeddings = value
class MMBTBase(MultiModalEncoderBase):
    """Builds the text/modal encoders and wraps them in :class:`MMBTModel`.

    Consumes MMF ``sample_list`` dicts of tensors (TorchScript-compatible
    signature) and returns the raw MMBT transformer outputs.
    """

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        # Replace transformer layers with scriptable JIT layers
        replace_with_jit()

    def build(self):
        """Instantiate encoders and the fused MMBT model from ``self.config``."""
        # (text, modal) order comes from MultiModalEncoderBase._build_encoders.
        encoders = self._build_encoders(self.config)
        text_encoder, modal_encoder = encoders[0], encoders[1]
        self._encoder_config = text_encoder.config

        self._mmbt_config = MMBTConfig(
            self._encoder_config,
            num_labels=self.config.num_labels,
            modal_hidden_size=self.config.modal_hidden_size,
        )
        self.use_modal_start_token = self.config.use_modal_start_token
        self.use_modal_end_token = self.config.use_modal_end_token
        # Number of token-type segments the text encoder supports (default 2).
        self.num_max_segment = self.config.text_encoder.params.get("num_segments", 2)

        self.mmbt = MMBTModel(self._mmbt_config, text_encoder, modal_encoder)

    def extract_modal_end_token(self, sample_list: Dict[str, Tensor]):
        """Pop the trailing separator token id from each text sequence.

        NOTE: mutates ``sample_list`` in place — shifts ``input_ids`` left by
        one position and shortens ``input_mask`` accordingly, so the freed
        slot can be used for the modal end token. Returns the extracted
        per-sample end-token ids.
        """
        # compute the position of the last non-masked token, which is <sep>
        gather_index = sample_list["input_mask"].sum(1, keepdim=True) - 1
        modal_end_token = (
            torch.gather(sample_list["input_ids"], 1, gather_index)
            .squeeze(1)
            .clone()
            .detach()
        )

        batch_size = sample_list["input_ids"].size(0)
        device = sample_list["input_ids"].device
        # remove start_token in input_ids
        sample_list["input_ids"] = torch.cat(
            [sample_list["input_ids"][:, 1:], sample_list["input_ids"][:, -1:]], dim=1
        )
        # update input_mask
        sample_list["input_mask"] = torch.cat(
            [
                sample_list["input_mask"][:, 1:],
                torch.zeros([batch_size, 1], dtype=torch.long, device=device),
            ],
            dim=1,
        )

        return modal_end_token

    def forward(self, sample_list: Dict[str, Tensor]):
        """Assemble MMBT inputs from ``sample_list`` and run the model."""
        # Choose the modal input: precomputed features vs. raw image.
        if self._is_direct_features_input:
            if "input_modal" in sample_list:
                input_modal = sample_list["input_modal"]
            else:
                input_modal = sample_list["image_feature_0"]
        else:
            input_modal = sample_list["image"]

        # Optional start token for the modal sequence: reuse the text's first
        # token (e.g. [CLS]).
        modal_start_token: Optional[Tensor] = None
        if self.use_modal_start_token:
            modal_start_token = sample_list["input_ids"][:, 0].clone().detach()

        modal_end_token: Optional[Tensor] = None
        if self.use_modal_end_token:
            # Mutates sample_list (shifts input_ids/input_mask); see helper.
            modal_end_token = self.extract_modal_end_token(sample_list)

        if "modal_token_type_ids" in sample_list:
            modal_token_type_ids = sample_list["modal_token_type_ids"]
        else:
            # Infer which segment id the modal tokens should use from the
            # text segment ids.
            token_value = 0
            segment_ids = sample_list["segment_ids"]
            max_id = segment_ids.max()
            min_id = segment_ids.min()
            # Case of only one segment
            if max_id == min_id:
                # If max_id is greater than 0, that means text is at 0 segment
                # which means modal will be at 1
                # In other case, it will be zero, which it already is
                # NOTE: We compare with tensor here due to TorchScript compliance
                if max_id == torch.tensor(0, dtype=max_id.dtype):
                    token_value = 1
            else:
                max_segment = self.num_max_segment - 1
                # If max id is not equal to max_segment, it means
                # text segments start from 0 which means modal will
                # be last, otherwise, it is 0, which it already is
                if max_id != torch.tensor(max_segment, dtype=max_id.dtype):
                    token_value = max_segment
            modal_token_type_ids = torch.full(
                (input_modal.size(0), 1),
                fill_value=token_value,
                dtype=torch.long,
                device=input_modal.device,
            )

        # In case of XRAY, there might be only two dims
        if input_modal.dim() == 2:
            input_modal = input_modal.unsqueeze(dim=1)

        # See details of inputs at
        # https://github.com/huggingface/transformers/blob/1789c7/src/transformers/modeling_mmbt.py#L101 # noqa
        output = self.mmbt(
            input_modal,
            input_ids=sample_list["input_ids"],
            modal_start_tokens=modal_start_token,
            modal_end_tokens=modal_end_token,
            attention_mask=sample_list["input_mask"],
            token_type_ids=sample_list["segment_ids"],
            modal_token_type_ids=modal_token_type_ids,
            position_ids=None,
            modal_position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
        )

        return output
class MMBTForPreTraining(nn.Module):
    """MMBT head for BERT-style pretraining.

    Wraps :class:`MMBTBase` and attaches the prediction heads (masked-LM and
    sequence-relationship) copied from a pretrained ``BertForPreTraining``
    checkpoint. Computes masked-LM loss over the text portion and an optional
    image-text alignment loss over the pooled output.
    """

    def __init__(self, config, *args, **kwargs):
        super().__init__()
        self.config = config
        self.bert = MMBTBase(config, *args, **kwargs)
        self.encoder_config = self.bert.encoder_config

        # TODO : Switch to AutoModelForPreTraining after transformers
        # package upgrade to 2.5
        pretraining_module = BertForPreTraining.from_pretrained(
            self.config.bert_model_name,
            config=self.encoder_config,
            cache_dir=os.path.join(get_mmf_cache_dir(), "distributed_{}".format(-1)),
        )

        # Deep-copy the heads so tie_weights() can clone rather than share
        # parameters (TorchScript export can't handle shared parameters).
        self.cls = deepcopy(pretraining_module.cls)
        # ignore_index=-1: positions without an LM label are marked -1 and
        # skipped by the loss.
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        self.tie_weights()

    def tie_weights(self):
        """Make sure we are sharing the input and output embeddings.
        Export to TorchScript can't handle parameter sharing so we
        are cloning them instead.
        """
        if hasattr(self, "cls"):
            self.bert.mmbt.transformer._tie_or_clone_weights(
                self.cls.predictions.decoder,
                self.bert.mmbt.transformer.embeddings.word_embeddings,
            )

    def forward(self, sample_list):
        """Run MMBT and compute pretraining losses.

        Returns a dict with ``logits``/``seq_relationship_logits`` and a
        ``losses`` sub-dict keyed by ``{dataset_name}/{dataset_type}``.
        """
        module_output = self.bert(sample_list)
        sequence_output, pooled_output = module_output[0], module_output[1]
        prediction_scores, seq_relationship_score = self.cls(
            sequence_output, pooled_output
        )

        output = {}
        if (
            self.encoder_config.output_hidden_states
            or self.encoder_config.output_attentions
        ):
            output["extras"] = module_output[2:]

        loss_key = f"{sample_list.dataset_name}/{sample_list.dataset_type}"

        if "lm_label_ids" in sample_list and sample_list.lm_label_ids is not None:
            output["logits"] = prediction_scores
            lm_label_ids = sample_list.lm_label_ids
            # Only take last scores which are text's scores and ignore image scores
            text_scores = (
                prediction_scores[:, -(lm_label_ids.size(1)) :]
                .contiguous()
                .view(-1, self.encoder_config.vocab_size)
            )
            masked_lm_loss = self.loss_fct(
                text_scores, sample_list.lm_label_ids.contiguous().view(-1)
            )
            output["losses"] = {}
            output["losses"][f"{loss_key}/masked_lm_loss"] = masked_lm_loss

        # Add alignment loss if present
        if (
            "image_text_alignment" in sample_list
            and sample_list.image_text_alignment is not None
        ):
            output["seq_relationship_logits"] = seq_relationship_score
            # Fix: keep the class dimension — CrossEntropyLoss expects logits
            # of shape (N, C). The previous `.view(-1)` flattened the
            # (batch, 2) scores to 1-D and broke the loss for batch sizes > 1.
            alignment_loss = self.loss_fct(
                seq_relationship_score.contiguous().view(-1, 2),
                sample_list.image_text_alignment.contiguous().view(-1),
            )
            # Fix: "losses" was only created in the LM branch above, causing a
            # KeyError when alignment labels are present without LM labels.
            output.setdefault("losses", {})
            output["losses"][f"{loss_key}/alignment_loss"] = alignment_loss

        return output
class MMBTForClassification(nn.Module):
    """MMBT head for classification over the pooled multimodal representation.

    Applies dropout, a BERT prediction-head transform, and a linear layer to
    the pooled MMBT output; TorchScript-compatible forward.
    """

    def __init__(self, config, *args, **kwargs):
        super().__init__()
        self.config = config
        self.bert = MMBTBase(config, *args, **kwargs)
        self.encoder_config = self.bert.encoder_config

        self.num_labels = self.config.num_labels
        self.output_hidden_states = self.encoder_config.output_hidden_states
        self.output_attentions = self.encoder_config.output_attentions
        # When True, forward() returns only the transformed (pre-logits)
        # fused feature instead of classification scores.
        self.fused_feature_only = self.config.get("fused_feature_only", False)

        self.dropout = nn.Dropout(self.encoder_config.hidden_dropout_prob)
        # classifier[0] is the transform; classifier[1] produces the logits.
        self.classifier = nn.Sequential(
            BertPredictionHeadTransform(self.encoder_config),
            nn.Linear(self.encoder_config.hidden_size, self.config.num_labels),
        )

    def forward(self, sample_list: Dict[str, Tensor]):
        """Return a dict with "scores" (or "fused_feature"), plus optional
        "extras" (hidden states / attentions) in eager mode."""
        module_output = self.bert(sample_list)
        pooled_output = module_output[1]
        output = {}

        # "extras" is only supported in eager mode; in script mode the config
        # must not request hidden states or attentions.
        if not torch.jit.is_scripting():
            if self.output_hidden_states or self.output_attentions:
                output["extras"] = module_output[2:]
        else:
            assert not (
                self.output_hidden_states or self.output_attentions
            ), "output_attentions or output_hidden_states not supported in script mode"

        pooled_output = self.dropout(pooled_output)

        if self.fused_feature_only:
            # Only run the transform (classifier[0]); skip the logits layer.
            output["fused_feature"] = self.classifier[0](pooled_output)
            return output

        logits = self.classifier(pooled_output)
        reshaped_logits = logits.contiguous().view(-1, self.num_labels)
        output["scores"] = reshaped_logits

        return output
@registry.register_model("mmbt")
class MMBT(BaseModel):
    """Top-level MMBT model, registered with MMF as "mmbt".

    Builds either :class:`MMBTForPreTraining` or
    :class:`MMBTForClassification` based on ``config.training_head_type``,
    and handles encoder freezing plus checkpoint-key backward compatibility.
    """

    @dataclass
    class Config(BaseModel.Config):
        # Structured (OmegaConf) config schema for MMBT.
        model: str = "mmbt"
        # classification or pretraining
        training_head_type: str = "pretraining"
        bert_model_name: str = "bert-base-uncased"
        # If True, inputs are precomputed features instead of raw images.
        direct_features_input: bool = False
        freeze_text: bool = False
        freeze_modal: bool = False
        freeze_complete_base: bool = False
        finetune_lr_multiplier: float = 1
        # Dimension of the embedding finally returned by the modal encoder
        modal_hidden_size: int = 2048
        text_hidden_size: int = 768
        num_labels: int = 2
        # This actually is Union[ImageEncoderConfig, ImageFeatureEncoderConfig]
        modal_encoder: EncoderFactory.Config = ImageEncoderFactory.Config(
            type=ImageEncoderTypes.resnet152, params=ResNet152ImageEncoder.Config()
        )
        # II("bert_model_name") interpolates the top-level bert_model_name.
        text_encoder: EncoderFactory.Config = TextEncoderFactory.Config(
            type=TextEncoderTypes.transformer,
            params=TransformerEncoder.Config(bert_model_name=II("bert_model_name")),
        )
        use_modal_start_token: bool = True
        use_modal_end_token: bool = True
        # If True, the classification head returns only the fused feature.
        fused_feature_only: bool = False
        output_dim: int = 768

    def __init__(self, config: Union[DictConfig, Config], *args, **kwargs):
        super().__init__(config)

    def build(self):
        # Select the head based on the configured training objective.
        if self.config.training_head_type == "pretraining":
            self.model = MMBTForPreTraining(self.config)
        else:
            self.model = MMBTForClassification(self.config)

        # Optionally freeze the text transformer and/or the modal encoder.
        if self.config.freeze_complete_base or self.config.freeze_text:
            for p in self.model.bert.mmbt.transformer.parameters():
                p.requires_grad = False

        if self.config.freeze_complete_base or self.config.freeze_modal:
            for p in self.model.bert.mmbt.modal_encoder.parameters():
                p.requires_grad = False

    # Backward compatibility for code from older mmbt
    @classmethod
    def format_state_key(cls, key):
        # Remap checkpoint keys saved under the old "base.*" prefix.
        return (
            key.replace("base.bert", "model.bert")
            .replace("base.cls", "model.cls")
            .replace("base.classifier", "model.classifier")
        )

    @classmethod
    def from_pretrained(cls, model_name, *args, **kwargs):
        """Load a pretrained MMBT; wraps the hateful-memes image checkpoint
        (or any load with ``interface=True``) in the HM grid interface."""
        model = super().from_pretrained(model_name, *args, **kwargs)
        config = load_pretrained_model(model_name)["full_config"]
        OmegaConf.set_struct(config, True)
        if model_name == "mmbt.hateful_memes.images" or kwargs.get("interface"):
            return MMBTGridHMInterface(model, config)
        return model

    @classmethod
    def config_path(cls):
        # Default config shipped with MMF for this model.
        return "configs/models/mmbt/pretrain.yaml"

    def forward(self, sample_list: Dict[str, Tensor]):
        # Delegate to whichever head build() selected.
        return self.model(sample_list)

    def get_optimizer_parameters(self, config):
        # BERT-style parameter groups (weight-decay exclusions, LR multipliers).
        return get_optimizer_parameters_for_bert(self.model, config)