https://github.com/facebookresearch/pythia
Tip revision: 48c0d58529a2ea71ac03b13bc9dcd7a2d6aad388 authored by omkar on 01 July 2020, 17:58:35 UTC
Working training
Working training
Tip revision: 48c0d58
fusions.py
# Copyright (c) Facebook, Inc. and its affiliates.
from copy import deepcopy
import torch
from mmf.common.registry import registry
from mmf.models.base_model import BaseModel
from mmf.modules.encoders import MultiModalEncoderBase
from mmf.utils.build import build_classifier_layer
from mmf.utils.modeling import get_bert_configured_parameters
class FusionBase(MultiModalEncoderBase):
    """Shared text + modal encoder backbone for the fusion models.

    ``build`` creates a text encoder and a modal encoder through
    ``MultiModalEncoderBase._build_encoders``; ``forward`` runs both and
    returns their outputs flattened from dimension 1 onward.
    """

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)

    def build(self):
        encoders = self._build_encoders(self.config)
        text_encoder, modal_encoder = encoders[0], encoders[1]

        self._modal_encoder_config = self.config.modal_encoder
        self._is_direct_features_input = self.config.direct_features_input
        # None when the text encoder exposes no ``config`` attribute.
        self._encoder_config = getattr(text_encoder, "config", None)
        self.text = text_encoder
        self.modal = modal_encoder

    def forward(
        self,
        text,
        modal,
        text_args=None,
        modal_args=None,
        text_kwargs=None,
        modal_kwargs=None,
    ):
        """Encode ``text`` and ``modal`` and flatten each from dim 1 onward.

        Args:
            text: Input passed as the first argument to the text encoder.
            modal: Input passed as the first argument to the modal encoder.
            text_args: Extra positional args for the text encoder
                (default: none).
            modal_args: Extra positional args for the modal encoder
                (default: none).
            text_kwargs: Extra keyword args for the text encoder
                (default: none).
            modal_kwargs: Extra keyword args for the modal encoder
                (default: none).

        Returns:
            Tuple ``(text, modal)`` of tensors whose leading dimension
            (presumably the batch) is kept and all others are flattened.
        """
        if text_args is None:
            text_args = []
        if modal_args is None:
            modal_args = []
        if text_kwargs is None:
            text_kwargs = {}
        if modal_kwargs is None:
            modal_kwargs = {}

        text = self.text(text, *text_args, **text_kwargs)

        # BERT-style encoders return a (sequence_output, pooled_output) pair
        # and only the pooled output is needed. The tensor check matters: a
        # plain tensor whose first dimension is 2 also has len() == 2, and
        # indexing it would silently keep only its second row.
        if not torch.is_tensor(text) and len(text) == 2:
            text = text[1]

        modal = self.modal(modal, *modal_args, **modal_kwargs)
        modal = torch.flatten(modal, start_dim=1)
        text = torch.flatten(text, start_dim=1)
        return text, modal
@registry.register_model("concat_bert")
class ConcatBERT(BaseModel):
    """Concatenate text and modal embeddings and score them with one classifier.

    Both embeddings come from ``FusionBase``. The classifier's ``in_dim`` is
    ``num_features * modal_hidden_size + text_hidden_size``.
    """

    def __init__(self, config, *args, **kwargs):
        super().__init__(config)
        self._is_direct_features_input = config.direct_features_input

    @classmethod
    def config_path(cls):
        return "configs/models/fusions/concat_bert.yaml"

    def build(self):
        """Create the encoder backbone and classifier; optionally freeze encoders."""
        self.base = FusionBase(self.config)
        num_features = self.config.num_features
        if not self._is_direct_features_input:
            num_features = self.config.modal_encoder.params.num_output_features

        # As the in_dim is dynamically calculated we need to copy classifier_config
        classifier_config = deepcopy(self.config.classifier)
        classifier_config.params.in_dim = num_features * self.config.modal_hidden_size
        classifier_config.params.in_dim += self.config.text_hidden_size
        self.classifier = build_classifier_layer(classifier_config)

        # freeze_complete_base freezes both encoders; the specific flags
        # freeze just one of them.
        if self.config.freeze_text or self.config.freeze_complete_base:
            for p in self.base.text.parameters():
                p.requires_grad = False

        if self.config.freeze_modal or self.config.freeze_complete_base:
            for p in self.base.modal.parameters():
                p.requires_grad = False

    def get_optimizer_parameters(self, config):
        """Return parameter groups: backbone at a scaled LR, classifier at base LR.

        The backbone's learning rate is ``lr * finetune_lr_multiplier``. The
        multiplier defaults to 1 when the model config does not set it.
        """
        # For finetuning setup, we have classifier
        lr = config.optimizer.params.lr
        model_config = getattr(config.model_config, config.model, {})
        finetune_lr_multiplier = getattr(model_config, "finetune_lr_multiplier", 1)
        # Finetune the bert pretrained part with finetune_lr_multiplier if it is set
        parameters = get_bert_configured_parameters(
            self.base, lr * finetune_lr_multiplier
        )
        parameters += get_bert_configured_parameters(self.classifier, lr)
        return parameters

    def forward(self, sample_list):
        """Return ``{"scores": ...}`` for the concatenated text/modal embedding."""
        text = sample_list.input_ids
        mask = sample_list.input_mask
        segment = sample_list.segment_ids

        if self._is_direct_features_input:
            # Same feature key as the concat_bow and late_fusion models.
            modal = sample_list.image_feature_0
        else:
            modal = sample_list.image

        text_embedding, modal_embedding = self.base(text, modal, [mask, segment])
        embedding = torch.cat([text_embedding, modal_embedding], dim=-1)
        output = {}
        output["scores"] = self.classifier(embedding)
        return output
@registry.register_model("concat_bow")
class ConcatBoW(BaseModel):
    """Concatenation baseline fed by ``sample_list.text``.

    The text and modal embeddings produced by ``FusionBase`` are joined
    along the last dimension and scored by a single classifier head.
    """

    def __init__(self, config, *args, **kwargs):
        super().__init__(config)
        self._is_direct_features_input = config.direct_features_input

    @classmethod
    def config_path(cls):
        return "configs/models/fusions/concat_bow.yaml"

    def build(self):
        """Create the encoder backbone and a classifier sized to its outputs."""
        cfg = self.config
        self.base = FusionBase(cfg)

        direct_count = cfg.num_features
        feature_count = (
            direct_count
            if self._is_direct_features_input
            else cfg.modal_encoder.params.num_output_features
        )

        # in_dim depends on the encoder sizes, so edit a private copy of the
        # classifier config rather than the shared one.
        head_config = deepcopy(cfg.classifier)
        head_config.params.in_dim = (
            feature_count * cfg.modal_hidden_size + cfg.text_hidden_size
        )
        self.classifier = build_classifier_layer(head_config)

    def forward(self, sample_list):
        """Return ``{"scores": ...}`` for the concatenated text/modal embedding."""
        tokens = sample_list.text
        visual = (
            sample_list.image_feature_0
            if self._is_direct_features_input
            else sample_list.image
        )
        text_out, modal_out = self.base(tokens, visual)
        fused = torch.cat([text_out, modal_out], dim=-1)
        return {"scores": self.classifier(fused)}
@registry.register_model("late_fusion")
class LateFusion(BaseModel):
    """Score text and modal embeddings separately and average the two scores."""

    def __init__(self, config, *args, **kwargs):
        super().__init__(config)
        self._is_direct_features_input = config.direct_features_input

    @classmethod
    def config_path(cls):
        return "configs/models/fusions/late_fusion.yaml"

    def build(self):
        """Create the encoder backbone plus one classifier head per modality."""
        cfg = self.config
        self.base = FusionBase(cfg)

        direct_count = cfg.num_features
        feature_count = (
            direct_count
            if self._is_direct_features_input
            else cfg.modal_encoder.params.num_output_features
        )

        # Each head's in_dim is computed here, so edit private copies of the
        # configs. The modal head is built before the text head, as before.
        modal_head_cfg = deepcopy(cfg.modal_classifier)
        modal_head_cfg.params.in_dim = feature_count * cfg.modal_hidden_size
        self.modal_classifier = build_classifier_layer(modal_head_cfg)

        text_head_cfg = deepcopy(cfg.text_classifier)
        text_head_cfg.params.in_dim = cfg.text_hidden_size
        self.text_classifier = build_classifier_layer(text_head_cfg)

    def forward(self, sample_list):
        """Return ``{"scores": ...}``, the mean of the text and modal head scores."""
        input_ids = sample_list.input_ids
        input_mask = sample_list.input_mask
        segment_ids = sample_list.segment_ids
        visual = (
            sample_list.image_feature_0
            if self._is_direct_features_input
            else sample_list.image
        )

        text_embedding, modal_embedding = self.base(
            input_ids, visual, [input_mask, segment_ids]
        )
        text_scores = self.text_classifier(text_embedding)
        modal_scores = self.modal_classifier(modal_embedding)
        return {"scores": (text_scores + modal_scores) / 2}