https://github.com/facebookresearch/pythia
Tip revision: c33bdc2d2ac9594402e19991a28974c425c50003 authored by Vedanuj Goswami on 08 September 2021, 04:52:49 UTC
skip optimizer update when nan loss
lxmert.py
# Copyright 2019 project LXMERT.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from mmf.common.registry import registry
from mmf.models import BaseModel
from mmf.utils.configuration import get_mmf_cache_dir
from mmf.utils.modeling import get_optimizer_parameters_for_bert
from omegaconf import OmegaConf
from torch import nn
from torch.nn import CrossEntropyLoss, SmoothL1Loss
from transformers.modeling_bert import (
ACT2FN,
BertAttention,
BertConfig,
BertEmbeddings,
BertIntermediate,
BertLayer,
BertOutput,
BertPooler,
BertPredictionHeadTransform,
BertPreTrainedModel,
BertSelfAttention,
BertSelfOutput,
)
class GeLU(nn.Module):
    """Module wrapper around the functional GELU so it can be used in nn.Sequential."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return ACT2FN["gelu"](x)
class BertCrossattLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.att = BertSelfAttention(config)
self.output = BertSelfOutput(config)
def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None):
output = self.att(
input_tensor,
encoder_hidden_states=ctx_tensor,
encoder_attention_mask=ctx_att_mask,
)[0]
attention_output = self.output(output, input_tensor)
return attention_output
class BertClassificationHead(nn.Module):
def __init__(self, num_labels, hid_dim, training_head_type):
super().__init__()
if training_head_type == "nlvr2":
in_dim = hid_dim * 2
out_dim = 2
else:
in_dim = hid_dim
out_dim = num_labels
self.logit_fc = nn.Sequential(
nn.Linear(in_dim, hid_dim * 2),
GeLU(),
nn.LayerNorm(hid_dim * 2, eps=1e-12),
nn.Linear(hid_dim * 2, out_dim),
)
def forward(self, x):
logit = self.logit_fc(x)
return logit
class BertLMPredictionHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super().__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(
bert_model_embedding_weights.size(1),
bert_model_embedding_weights.size(0),
bias=False,
)
self.decoder.weight = bert_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
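        # Because `decoder.weight` is shared with the word-embedding matrix, the
        # head projects hidden states back onto the vocabulary with the transposed
        # embedding table (e.g. 30522 x 768 for bert-base-uncased); only the bias
        # above is a newly created parameter.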
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states) + self.bias
return hidden_states
class BertVisualAnswerHead(nn.Module):
def __init__(self, config, num_labels):
super().__init__()
hid_dim = config.hidden_size
if config.training_head_type == "nlvr2":
in_dim = hid_dim * 2
out_dim = 2
else:
in_dim = hid_dim
out_dim = config.num_labels
add_gqa = isinstance(num_labels, list)
if add_gqa:
self.logit_gqa = nn.Sequential(
nn.Linear(in_dim, hid_dim * 2),
GeLU(),
nn.LayerNorm(hid_dim * 2, eps=1e-12),
nn.Linear(hid_dim * 2, num_labels[1]),
)
out_dim = num_labels[0]
self.logit_fc = nn.Sequential(
nn.Linear(in_dim, hid_dim * 2),
GeLU(),
nn.LayerNorm(hid_dim * 2, eps=1e-12),
nn.Linear(hid_dim * 2, out_dim),
)
def forward(self, hidden_states, name=None):
if name is None or "gqa" not in name:
return self.logit_fc(hidden_states)
else:
return self.logit_gqa(hidden_states)
class BertVisualObjHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = BertPredictionHeadTransform(config)
self.visual_losses = config.visual_losses
        # One decoder head per visual pre-training loss; the keys (e.g. "obj",
        # "attr", "feat") and their output dimensions come from the config.
self.decoder_dict = nn.ModuleDict(
{
key: nn.Linear(config.hidden_size, config.visual_loss_config[key][0])
for key in self.visual_losses
}
)
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
output = {}
for key in self.visual_losses:
output[key] = self.decoder_dict[key](hidden_states)
return output
class BertPreTrainingHeads(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super().__init__()
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class VisualFeatEncoder(nn.Module):
def __init__(self, config):
super().__init__()
feat_dim = config.visual_feat_dim
pos_dim = config.visual_pos_dim
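        # In the standard LXMERT setup these are 2048-d Faster R-CNN region
        # features and 4-d normalized bounding-box coordinates, respectively.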
# Object feature encoding
self.visn_fc = nn.Linear(feat_dim, config.hidden_size)
self.visn_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
# Box position encoding
self.box_fc = nn.Linear(pos_dim, config.hidden_size)
self.box_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, visn_input):
feats, boxes = visn_input
x = self.visn_fc(feats)
x = self.visn_layer_norm(x)
if boxes is not None:
y = self.box_fc(boxes)
y = self.box_layer_norm(y)
            output = (x + y) / 2  # average the appearance and position embeddings
else:
output = x
output = self.dropout(output)
return output
class LXMERTXLayer(nn.Module):
def __init__(self, config):
super().__init__()
# The cross-attention Layer
self.visual_attention = BertCrossattLayer(config)
# Self-attention Layers
self.lang_self_att = BertAttention(config)
self.visn_self_att = BertAttention(config)
# Intermediate and Output Layers (FFNs)
self.lang_inter = BertIntermediate(config)
self.lang_output = BertOutput(config)
self.visn_inter = BertIntermediate(config)
self.visn_output = BertOutput(config)
def cross_att(
self, lang_input, lang_attention_mask, visn_input, visn_attention_mask
):
# Cross Attention
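        # The same cross-attention module is used for both directions below,
        # so its parameters are shared between language->vision and
        # vision->language attention.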
lang_att_output = self.visual_attention(
lang_input, visn_input, ctx_att_mask=visn_attention_mask
)
visn_att_output = self.visual_attention(
visn_input, lang_input, ctx_att_mask=lang_attention_mask
)
return lang_att_output, visn_att_output
def self_att(
self, lang_input, lang_attention_mask, visn_input, visn_attention_mask
):
# Self Attention
lang_att_output = self.lang_self_att(lang_input, lang_attention_mask)[0]
visn_att_output = self.visn_self_att(visn_input, visn_attention_mask)[0]
return lang_att_output, visn_att_output
def output_fc(self, lang_input, visn_input):
# FC layers
lang_inter_output = self.lang_inter(lang_input)
visn_inter_output = self.visn_inter(visn_input)
# Layer output
lang_output = self.lang_output(lang_inter_output, lang_input)
visn_output = self.visn_output(visn_inter_output, visn_input)
return lang_output, visn_output
def forward(self, lang_feats, lang_attention_mask, visn_feats, visn_attention_mask):
lang_att_output = lang_feats
visn_att_output = visn_feats
lang_att_output, visn_att_output = self.cross_att(
lang_att_output, lang_attention_mask, visn_att_output, visn_attention_mask
)
lang_att_output, visn_att_output = self.self_att(
lang_att_output, lang_attention_mask, visn_att_output, visn_attention_mask
)
lang_output, visn_output = self.output_fc(lang_att_output, visn_att_output)
return lang_output, visn_output
class LXMERTEncoder(nn.Module):
def __init__(self, config):
super().__init__()
# Obj-level image embedding layer
self.visn_fc = VisualFeatEncoder(config)
# Number of layers
self.num_l_layers = config.l_layers
self.num_x_layers = config.x_layers
self.num_r_layers = config.r_layers
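        # The original LXMERT configuration uses l_layers=9 (language),
        # x_layers=5 (cross-modality) and r_layers=5 (object-relationship).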
self.layer = nn.ModuleList(
[BertLayer(config) for _ in range(self.num_l_layers)]
)
self.x_layers = nn.ModuleList(
[LXMERTXLayer(config) for _ in range(self.num_x_layers)]
)
self.r_layers = nn.ModuleList(
[BertLayer(config) for _ in range(self.num_r_layers)]
)
def forward(
self, lang_feats, lang_attention_mask, visn_feats, visn_attention_mask=None
):
# Run visual embedding layer
# Note: Word embedding layer was executed outside this module.
# Keep this design to allow loading BERT weights.
visn_feats = self.visn_fc(visn_feats)
# Run language layers
for layer_module in self.layer:
lang_feats = layer_module(lang_feats, lang_attention_mask)[0]
# Run relational layers
for layer_module in self.r_layers:
visn_feats = layer_module(visn_feats, visn_attention_mask)[0]
# Run cross-modality layers
for layer_module in self.x_layers:
lang_feats, visn_feats = layer_module(
lang_feats, lang_attention_mask, visn_feats, visn_attention_mask
)
return lang_feats, visn_feats
class LXMERTBase(BertPreTrainedModel):
"""LXMERT Model."""
def __init__(self, config):
super().__init__(config)
self.embeddings = BertEmbeddings(config)
self.encoder = LXMERTEncoder(config)
self.pooler = BertPooler(config)
self.init_weights()
def forward(
self,
input_ids,
token_type_ids=None,
attention_mask=None,
visual_feats=None,
visual_loc=None,
visual_attention_mask=None,
output_all_attention_masks=False,
output_all_encoded_layers=False,
):
if output_all_encoded_layers:
raise NotImplementedError
if output_all_attention_masks:
raise NotImplementedError
visual_feats = (visual_feats, visual_loc)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is simpler than the triangular masking of causal
        # attention used in OpenAI GPT; we just need to prepare the broadcast
        # dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(
dtype=next(self.parameters()).dtype
) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
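        # For example, a padding mask [1, 1, 0] becomes [0.0, 0.0, -10000.0]
        # (broadcast over heads and query positions), so padded tokens receive
        # effectively zero attention weight after the softmax.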
# Process the visual attention mask
if visual_attention_mask is not None:
extended_visual_attention_mask = visual_attention_mask.unsqueeze(
1
).unsqueeze(2)
extended_visual_attention_mask = extended_visual_attention_mask.to(
dtype=next(self.parameters()).dtype
) # fp16 compatibility
extended_visual_attention_mask = (
1.0 - extended_visual_attention_mask
) * -10000.0
else:
extended_visual_attention_mask = None
# Positional Word Embeddings
embedding_output = self.embeddings(input_ids, token_type_ids)
# Run LXMERT backbone
lang_feats, visn_feats = self.encoder(
embedding_output,
extended_attention_mask,
visn_feats=visual_feats,
visn_attention_mask=extended_visual_attention_mask,
)
pooled_output = self.pooler(lang_feats)
return (lang_feats, visn_feats), pooled_output
class LXMERTForPretraining(nn.Module):
def __init__(self, config):
super().__init__()
# Configuration
self.config = config
# LXMERT backbone
self.bert = LXMERTBase.from_pretrained(
self.config.bert_model_name,
config=BertConfig.from_dict(
OmegaConf.to_container(self.config, resolve=True)
),
cache_dir=os.path.join(get_mmf_cache_dir(), "distributed_{}".format(-1)),
)
self.num_labels = config.num_labels
self.gqa_labels = config.gqa_labels
self.task_mask_lm = config.task_mask_lm
self.task_obj_predict = config.task_obj_predict
self.task_matched = config.task_matched
self.task_qa = config.task_qa
self.visual_losses = config.visual_losses
self.visual_loss_config = config.visual_loss_config
# Pre-training heads
self.cls = BertPreTrainingHeads(
config, self.bert.embeddings.word_embeddings.weight
)
if self.task_obj_predict:
self.obj_predict_head = BertVisualObjHead(config)
if self.task_qa:
self.answer_head = BertVisualAnswerHead(
config, [self.num_labels, self.gqa_labels]
)
# loss functions
self.loss_fcts = {
"l2": SmoothL1Loss(reduction="none"),
"ce": CrossEntropyLoss(ignore_index=-1, reduction="none"),
"ce_lang": CrossEntropyLoss(ignore_index=-1),
}
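        # Note: the "l2" key actually maps to SmoothL1Loss. "l2" and "ce" keep
        # per-element losses (reduction="none") so forward() can mask and
        # re-weight them per visual region before averaging.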
def init_weights(self):
if self.config.random_initialize is False:
if self.config.bert_model_name is None:
# No pretrained model, init weights
self.bert.init_weights()
self.cls.apply(self.bert._init_weights)
self.tie_weights()
def tie_weights(self):
"""Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning
them instead.
"""
        # `_tie_or_clone_weights` is provided by the BertPreTrainedModel base
        # class of LXMERTBase, not by this plain nn.Module.
        self.bert._tie_or_clone_weights(
            self.cls.predictions.decoder, self.bert.embeddings.word_embeddings
        )
def forward(
self,
input_ids, # tokens
token_type_ids=None,
attention_mask=None,
visual_feats=None,
visual_pos=None,
visual_attention_mask=None,
masked_lm_labels=None,
masked_image_labels=None,
obj_labels=None,
        matched_label=None,  # whether the image and sentence are matched
ans=None, # qa answer
num_features=None, # max num of objects
name=None,
output_all_attention_masks=False,
output_all_encoded_layers=False,
):
(lang_output, visn_output), pooled_output = self.bert(
input_ids,
token_type_ids,
attention_mask,
visual_feats,
visual_pos,
visual_attention_mask,
output_all_attention_masks,
output_all_encoded_layers,
)
lang_prediction_scores, cross_relationship_score = self.cls(
lang_output, pooled_output
)
        # Collect all task outputs and losses here
output = {}
if output_all_attention_masks:
raise NotImplementedError
if ans is not None and self.task_qa:
answer_score = self.answer_head(pooled_output, name)
if name is None or "gqa" not in name:
num_labels = self.config.num_labels
else:
num_labels = self.config.gqa_labels
answer_loss = self.loss_fcts["ce_lang"](
answer_score.view(-1, num_labels), ans.argmax(-1)
)
output["answer_loss"] = answer_loss
if masked_lm_labels is not None and self.task_mask_lm:
masked_lm_loss = self.loss_fcts["ce_lang"](
lang_prediction_scores.view(-1, lang_prediction_scores.size(-1)),
masked_lm_labels.view(-1),
)
output["masked_lm_loss"] = masked_lm_loss
if matched_label is not None and self.task_matched:
matched_label = matched_label.to(cross_relationship_score).long()
matched_loss = self.loss_fcts["ce_lang"](
cross_relationship_score.view(-1, 2), matched_label
)
output["matched_loss"] = matched_loss
if obj_labels is not None and self.task_obj_predict:
total_visn_loss = 0.0
visn_prediction_scores_dict = self.obj_predict_head(visn_output)
for key in self.visual_losses:
visn_prediction_scores = visn_prediction_scores_dict[key]
(
output_dim,
loss_fct_name,
label_shape,
weight,
) = self.visual_loss_config[key]
if key == "attr":
continue
elif key == "obj":
temp_obj_labels_dict = obj_labels.max(-1)
mask_conf = temp_obj_labels_dict.values
visn_loss = self.loss_fcts[loss_fct_name](
visn_prediction_scores.view(-1, output_dim),
temp_obj_labels_dict.indices.view(-1),
)
elif key == "feat":
                    if masked_image_labels is None:
continue
mask_conf = (masked_image_labels == 1).float()
visn_loss = self.loss_fcts[loss_fct_name](
visn_prediction_scores.view(-1, output_dim),
visual_feats.view(-1, output_dim),
)
if visn_loss.dim() > 1: # Regression Losses
visn_loss = visn_loss.mean(1)
visn_loss = (visn_loss * mask_conf.view(-1)).mean() * weight
total_visn_loss += visn_loss
output["visn_loss"] = total_visn_loss
return output
class LXMERTForClassification(nn.Module):
def __init__(self, config, mode="lxr"):
super().__init__()
self.config = config
self.num_labels = config.num_labels
self.gqa_labels = config.gqa_labels
self.mode = config.mode
self.bert = LXMERTBase.from_pretrained(
self.config.bert_model_name,
config=BertConfig.from_dict(
OmegaConf.to_container(self.config, resolve=True)
),
cache_dir=os.path.join(get_mmf_cache_dir(), "distributed_{}".format(-1)),
)
self.classifier = BertVisualAnswerHead(
config, [self.num_labels, self.gqa_labels]
)
self.init_weights()
def init_weights(self):
if self.config.random_initialize is False:
if self.config.bert_model_name is None:
# No pretrained model, init weights
self.bert.init_weights()
# Classifier needs to be initialized always as it is task specific
self.classifier.apply(self.bert._init_weights)
def forward(
self,
input_ids,
token_type_ids=None,
attention_mask=None,
visual_feats=None,
visual_pos=None,
visual_attention_mask=None,
masked_lm_labels=None,
obj_labels=None, # is img_labels in vilbert
matched_label=None, # next_sent_label in VilBERT
ans=None,
max_features=None,
output_all_attention_masks=False,
output_all_encoded_layers=False,
):
(lang_output, visn_output), pooled_output = self.bert(
input_ids,
token_type_ids,
attention_mask,
visual_feats,
visual_pos,
visual_attention_mask,
            output_all_attention_masks,
            output_all_encoded_layers,
)
output = {}
if output_all_attention_masks:
raise NotImplementedError
if self.config.training_head_type == "nlvr2":
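            # NLVR2 provides two images per sentence: consecutive pooled vectors
            # are concatenated pairwise, halving the batch dimension and doubling
            # the feature dimension.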
pooled_output = pooled_output.view(-1, pooled_output.size(1) * 2)
logits = self.classifier(pooled_output)
reshaped_logits = logits.contiguous().view(-1, self.config.num_labels)
output["scores"] = reshaped_logits
return output
@registry.register_model("lxmert")
class LXMERT(BaseModel):
def __init__(self, config):
super().__init__(config)
@classmethod
def config_path(cls):
return "configs/models/lxmert/pretrain.yaml"
def build(self):
if self.config.training_head_type == "pretraining":
self.model = LXMERTForPretraining(self.config)
else:
self.model = LXMERTForClassification(self.config)
if getattr(self.config, "freeze_base", False):
for p in self.model.bert.parameters():
p.requires_grad = False
def get_image_and_text_features(self, sample_list, device):
# bert input
bert_input_ids = sample_list.input_ids
bert_input_mask = sample_list.input_mask
bert_input_type_ids = sample_list.segment_ids
masked_lm_labels = sample_list.lm_label_ids
# image input
image_info = getattr(sample_list, "image_info_0", {})
image_dim_variable = getattr(image_info, "max_features", None)
image_feature_variable = getattr(sample_list, "image_feature_0", None)
max_features = torch.tensor(
image_feature_variable.shape[1], dtype=torch.int
).to(device)
image_location_variable = getattr(image_info, "bbox", None)
image_location_variable = image_location_variable[:, : max_features.item(), :4]
# aux data
image_label_variable = getattr(sample_list, "image_labels", None)
if image_label_variable is not None:
image_label_variable = image_label_variable[:, : max_features.item(), None]
image_label_variable = image_label_variable.unsqueeze(-1).to(device)
cls_prob = getattr(image_info, "cls_prob", None)
if cls_prob is not None:
cls_prob = torch.tensor(cls_prob)[:, : max_features.item(), None].to(device)
answers = getattr(sample_list, "targets", None)
if answers is None:
answers = getattr(sample_list, "answers", None)
if answers is not None:
if not isinstance(answers, torch.Tensor):
answers = torch.tensor(answers)
answers = answers.to(device)
is_correct = getattr(sample_list, "is_correct", None)
if is_correct is not None:
if isinstance(is_correct, torch.Tensor):
is_correct = is_correct.to(device)
else:
is_correct = torch.tensor(is_correct).to(device)
return {
"input_ids": bert_input_ids,
"token_type_ids": bert_input_mask,
"attention_mask": bert_input_type_ids,
"masked_lm_labels": masked_lm_labels,
"visual_feats": image_feature_variable,
"pos": image_location_variable,
"masked_image_labels": image_label_variable,
"obj_labels": cls_prob,
"matched_label": is_correct,
"ans": answers,
"image_dim": image_dim_variable,
"max_features": max_features,
"dataset_name": str(sample_list.dataset_name),
}
def get_optimizer_parameters(self, config):
return get_optimizer_parameters_for_bert(self.model, config)
def forward(self, sample_list):
device = registry.get("config").training.device
params = self.get_image_and_text_features(sample_list, device)
if params["visual_feats"] is not None and params["image_dim"] is not None:
device = params["visual_feats"].device
image_mask = (
torch.arange(params["visual_feats"].size(-2))
.expand(*params["visual_feats"].size()[:-1])
.to(device)
)
if len(params["image_dim"].size()) < len(image_mask.size()):
params["image_dim"] = params["image_dim"].unsqueeze(-1)
assert len(params["image_dim"].size()) == len(image_mask.size())
image_mask = image_mask < params["image_dim"]
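            # e.g. with 100 candidate regions and image_dim == 36, positions
            # 0-35 stay valid and the remaining padded regions are masked out.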
params["image_attention_mask"] = image_mask.long()
else:
params["image_attention_mask"] = None
if self.config.training_head_type == "pretraining":
output_dict = self.model(
input_ids=params["input_ids"],
token_type_ids=params["token_type_ids"],
attention_mask=params["attention_mask"],
visual_feats=params["visual_feats"],
visual_pos=params["pos"],
visual_attention_mask=params["image_attention_mask"],
masked_lm_labels=params["masked_lm_labels"],
masked_image_labels=params["masked_image_labels"],
obj_labels=params["obj_labels"],
matched_label=params["matched_label"],
ans=params["ans"],
num_features=params["max_features"],
name=params["dataset_name"],
)
loss_key = "{}/{}".format(
sample_list.dataset_name, sample_list.dataset_type
)
output_dict["losses"] = {}
if "masked_lm_loss" in output_dict.keys():
output_dict["losses"][loss_key + "/masked_lm_loss"] = output_dict.pop(
"masked_lm_loss"
)
if "matched_loss" in output_dict.keys():
output_dict["losses"][loss_key + "/matched_loss"] = output_dict.pop(
"matched_loss"
)
if "visn_loss" in output_dict.keys():
output_dict["losses"][loss_key + "/visn_loss"] = output_dict.pop(
"visn_loss"
)
if "answer_loss" in output_dict.keys():
output_dict["losses"][loss_key + "/answer_loss"] = output_dict.pop(
"answer_loss"
)
else:
output_dict = self.model(
input_ids=params["input_ids"],
token_type_ids=params["token_type_ids"],
attention_mask=params["attention_mask"],
visual_feats=params["visual_feats"],
visual_pos=params["pos"],
visual_attention_mask=params["image_attention_mask"],
)
return output_dict
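# Minimal usage sketch (hypothetical, for illustration only; in practice MMF
# instantiates and drives this model through its trainer, using the YAML config
# returned by LXMERT.config_path()):
#
#     from mmf.common.registry import registry
#
#     model_cls = registry.get_model_class("lxmert")  # -> LXMERT
#     model = model_cls(config)  # `config`: the OmegaConf node for this model
#     model.build()
#     report = model(sample_list)  # `sample_list`: an MMF SampleList batch
#     # pretraining head: report["losses"]; classification head: report["scores"]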