https://github.com/facebookresearch/pythia
Tip revision: 1ea94d293e8a6114402e20b4fbc376940d13f9e5 authored by ryan-qiyu-jiang on 16 December 2021, 23:18:48 UTC
Update on "[docs] Update UNITER project doc"
Update on "[docs] Update UNITER project doc"
Tip revision: 1ea94d2
butd.py
# Copyright (c) Facebook, Inc. and its affiliates.
import torch
from mmf.common.registry import registry
from mmf.models.pythia import Pythia
from mmf.modules.layers import ClassifierLayer
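
# BUTD ("Bottom-Up and Top-Down Attention", Anderson et al., 2018) is an
# image-captioning model. It reuses Pythia's image feature encoders and
# attention embeddings, and decodes captions one token at a time with a
# two-LSTM decoder (the td_hidden/lm_hidden states below).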
@registry.register_model("butd")
class BUTD(Pythia):
    def __init__(self, config):
        super().__init__(config)

    @classmethod
    def config_path(cls):
        return "configs/models/butd/defaults.yaml"

    def build(self):
        self._build_word_embedding()
        self._init_feature_encoders("image")
        self._init_feature_embeddings("image")
        self._init_classifier()
        self._init_extras()

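    # The vocabulary and word embeddings come from the dataset's registered
    # text processor, so the model's output space always matches the
    # captions it is trained on.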
    def _build_word_embedding(self):
        self.text_processor = registry.get(self._datasets[0] + "_text_processor")
        self.vocab = self.text_processor.vocab
        self.vocab_size = self.vocab.get_size()
        self.word_embedding = self.vocab.get_embedding(
            torch.nn.Embedding, embedding_dim=self.config.embedding_dim
        )
        self.text_embeddings_out_dim = self.config.embedding_dim

    def _init_classifier(self):
        self.classifier = ClassifierLayer(
            self.config.classifier.type,
            in_dim=self.config.classifier.params.feature_dim,
            out_dim=self.vocab_size,
            **self.config.classifier.params,
        )

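    # The pretrained image feature encoders are fine-tuned with a 10x smaller
    # learning rate than the rest of the model (see the "lr" override below).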
    def get_optimizer_parameters(self, config):
        params = [
            {"params": self.word_embedding.parameters()},
            {"params": self.image_feature_embeddings_list.parameters()},
            {"params": self.classifier.parameters()},
            {
                "params": self.image_feature_encoders.parameters(),
                "lr": (config.optimizer.params.lr * 0.1),
            },
        ]
        return params

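    # prepare_data() sets up decoding. With teacher forcing, captions are
    # sorted by length (longest first) so the effective batch can shrink as
    # shorter captions finish; without it, every sequence starts from <s>
    # (SOS_INDEX) and runs for text_processor.max_length steps.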
    def prepare_data(self, sample_list, batch_size):
        # turn off teacher forcing during beam search
        # (otherwise one cannot run beam search on val set)
        self.teacher_forcing = self.config.inference.type != "beam_search" and hasattr(
            sample_list, "text"
        )
        data = {}
        if self.teacher_forcing:
            caption_lengths, sort_ind = sample_list.caption_len.sort(
                dim=0, descending=True
            )
            data["decode_lengths"] = (caption_lengths - 1).tolist()
            sample_list.text = sample_list.text[sort_ind]
            sample_list.answers = sample_list.answers[sort_ind]
            sample_list.image_feature_0 = sample_list.image_feature_0[sort_ind]
            data["texts"] = sample_list.text
            timesteps = max(data["decode_lengths"])
            sample_list.add_field("targets", sample_list.text[:, 1:])
        else:
            data["texts"] = sample_list.answers.new_full(
                (batch_size, 1), self.vocab.SOS_INDEX, dtype=torch.long
            )
            timesteps = self.text_processor.max_length
            sample_list.add_field("targets", sample_list.answers[:, 0, 1:])
        return data, sample_list, timesteps

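    # Fresh zero-valued (h, c) LSTM states, created on the same device and
    # with the same batch size as `features`.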
    def init_hidden_state(self, features):
        h = features.new_zeros(
            (features.size(0), self.config.classifier.params.hidden_dim),
            dtype=torch.float,
        )
        c = features.new_zeros(
            (features.size(0), self.config.classifier.params.hidden_dim),
            dtype=torch.float,
        )
        return h, c

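    # get_data_t() assembles the decoder inputs for timestep t. Because
    # captions were sorted by length, slicing every tensor to batch_size_t
    # drops sequences that have already ended while keeping the rest aligned.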
    def get_data_t(self, t, data, batch_size_t, prev_output):
        if self.teacher_forcing:
            # Modify batch_size for timestep t
            batch_size_t = sum([l > t for l in data["decode_lengths"]])
        elif prev_output is not None and self.config.inference.type == "greedy":
            # Add the t-1 output words to data["texts"] for greedy decoding
            output_softmax = torch.log_softmax(prev_output, dim=1)
            _, indices = torch.max(output_softmax, dim=1, keepdim=True)
            data["texts"] = torch.cat(
                (data["texts"], indices.view(batch_size_t, 1)), dim=1
            )

        # Slice data based on batch_size at timestep t
        data["texts"] = data["texts"][:batch_size_t]
        if "state" in data:
            h1 = data["state"]["td_hidden"][0][:batch_size_t]
            c1 = data["state"]["td_hidden"][1][:batch_size_t]
            h2 = data["state"]["lm_hidden"][0][:batch_size_t]
            c2 = data["state"]["lm_hidden"][1][:batch_size_t]
        else:
            h1, c1 = self.init_hidden_state(data["texts"])
            h2, c2 = self.init_hidden_state(data["texts"])
        data["state"] = {"td_hidden": (h1, c1), "lm_hidden": (h2, c2)}
        registry.register(f"{h1.device}_lstm_state", data["state"])
        return data, batch_size_t

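    # forward() runs the per-timestep decode loop: embed the current token,
    # attend over image features, classify into the vocabulary, then either
    # record the scores (training/greedy) or hand the logits to the
    # beam-search/nucleus-sampling decoder.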
    def forward(self, sample_list):
        # Stores the output probabilities.
        scores = sample_list.answers.new_ones(
            (
                sample_list.answers.size(0),
                self.text_processor.max_length,
                self.vocab_size,
            ),
            dtype=torch.float,
        )

        if self.config["inference"]["type"] in ["beam_search", "nucleus_sampling"]:
            decoder = registry.get_decoder_class(self.config["inference"]["type"])(
                self.vocab, self.config
            )
            sample_list = decoder.init_batch(sample_list)

        batch_size = sample_list.image_feature_0.size(0)
        data, sample_list, timesteps = self.prepare_data(sample_list, batch_size)
        output = None
        batch_size_t = batch_size
        for t in range(timesteps):
            data, batch_size_t = self.get_data_t(t, data, batch_size_t, output)
            if self.config.inference.type in ["beam_search", "nucleus_sampling"]:
                pi_t = data["texts"]
            else:
                pi_t = data["texts"][:, t].unsqueeze(-1)
            embedding = self.word_embedding(pi_t)
            attention_feature, _ = self.process_feature_embedding(
                "image", sample_list, embedding[:, 0, :], batch_size_t=batch_size_t
            )
            output = self.classifier(attention_feature)
            # Compute decoding
            if self.config.inference.type in ["beam_search", "nucleus_sampling"]:
                finish, data, batch_size_t = decoder.decode(t, data, output)
                if finish:
                    break
            else:
                scores[:batch_size_t, t] = output

        model_output = {}
        if self.config.inference.type in ["beam_search", "nucleus_sampling"]:
            results = decoder.get_result()
            results = torch.nn.functional.pad(
                results,
                (0, self.text_processor.max_length - results.size()[-1]),
                "constant",
                0,
            )
            model_output["captions"] = results
            model_output["losses"] = {}
            loss_key = "{}/{}".format(
                sample_list.dataset_name, sample_list.dataset_type
            )
            # Add a dummy loss so that loss calculation is not required
            model_output["losses"][loss_key + "/dummy_loss"] = torch.zeros(
                batch_size, device=sample_list.answers.device
            )
        else:
            model_output["scores"] = scores

        return model_output
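
A minimal, self-contained sketch of the greedy decoding path that get_data_t() and forward() implement together. The Embedding/Linear pair and every constant below are hypothetical stand-ins for BUTD's word embedding and attention/classifier stack, not MMF APIs:

import torch

# Hypothetical stand-ins: a toy vocabulary and an Embedding/Linear pair in
# place of BUTD's word embedding and attention/classifier stack.
SOS_INDEX, VOCAB_SIZE, MAX_LENGTH, BATCH_SIZE = 1, 10, 5, 3
embedding = torch.nn.Embedding(VOCAB_SIZE, 8)
classifier = torch.nn.Linear(8, VOCAB_SIZE)

# Every sequence starts from <s>, as in prepare_data() without teacher forcing.
texts = torch.full((BATCH_SIZE, 1), SOS_INDEX, dtype=torch.long)
output = None
for t in range(MAX_LENGTH):
    if output is not None:
        # Greedy step from get_data_t(): append the argmax of the previous
        # logits (log_softmax preserves the argmax of the raw scores).
        indices = torch.log_softmax(output, dim=1).argmax(dim=1, keepdim=True)
        texts = torch.cat((texts, indices), dim=1)
    # Logits for the next token, given the token at position t.
    output = classifier(embedding(texts[:, t]))

print(texts.shape)  # torch.Size([3, 5]), i.e. (BATCH_SIZE, MAX_LENGTH)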