https://github.com/facebookresearch/pythia
Tip revision: 77a14f4803a665fb47d2dc9214c0bd310229b522 ("working code") authored by omkar on 10 July 2020, 03:57:59 UTC
attention.py
# Copyright (c) Facebook, Inc. and its affiliates.
import torch
from torch import nn

from mmf.modules.layers import GatedTanh, ModalCombineLayer, TransformLayer


class AttentionLayer(nn.Module):
    def __init__(self, image_dim, question_dim, **kwargs):
        super().__init__()

        combine_type = kwargs["modal_combine"]["type"]
        combine_params = kwargs["modal_combine"]["params"]
        modal_combine_layer = ModalCombineLayer(
            combine_type, image_dim, question_dim, **combine_params
        )

        transform_type = kwargs["transform"]["type"]
        transform_params = kwargs["transform"]["params"]
        transform_layer = TransformLayer(
            transform_type, modal_combine_layer.out_dim, **transform_params
        )

        normalization = kwargs["normalization"]

        self.module = TopDownAttention(
            modal_combine_layer, transform_layer, normalization
        )

        if hasattr(self.module, "out_dim"):
            self.out_dim = self.module.out_dim

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
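
# Usage sketch (illustrative, not part of the original file): AttentionLayer
# is driven by nested config dicts mirroring the model YAML. The combine and
# transform params below are assumptions about one plausible configuration,
# not values taken from this file.
#
#   layer = AttentionLayer(
#       image_dim=2048,
#       question_dim=1024,
#       modal_combine={
#           "type": "non_linear_element_multiply",
#           "params": {"hidden_dim": 5000, "dropout": 0},
#       },
#       transform={"type": "linear", "params": {"out_dim": 1}},
#       normalization="softmax",
#   )
#   attention = layer(image_feat, question_embedding)  # (N, K, out_dim)
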
class ConcatenationAttention(nn.Module):
    def __init__(self, image_feat_dim, txt_rnn_embeding_dim, hidden_size):
        super().__init__()
        self.image_feat_dim = image_feat_dim
        self.txt_embeding_dim = txt_rnn_embeding_dim
        self.fa = GatedTanh(image_feat_dim + txt_rnn_embeding_dim, hidden_size)
        self.lc = nn.Linear(hidden_size, 1)

    def forward(self, image_feat, question_embedding):
        _, num_location, _ = image_feat.shape
        question_embedding_expand = torch.unsqueeze(question_embedding, 1).expand(
            -1, num_location, -1
        )
        concat_feature = torch.cat((image_feat, question_embedding_expand), dim=2)
        raw_attention = self.lc(self.fa(concat_feature))
        # softmax across locations
        attention_weights = nn.functional.softmax(raw_attention, dim=1)
        attention_weights = attention_weights.expand_as(image_feat)
        return attention_weights
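
# Shape sketch (sizes are illustrative assumptions): the question vector is
# tiled across the K image locations, concatenated with each region feature,
# scored to a single logit per location, and softmaxed over K.
#
#   att = ConcatenationAttention(image_feat_dim=2048,
#                                txt_rnn_embeding_dim=1024, hidden_size=512)
#   w = att(torch.randn(4, 36, 2048), torch.randn(4, 1024))
#   # w: (4, 36, 2048) -- per-location weights expanded to the feature shape
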
class ProjectAttention(nn.Module):
    def __init__(self, image_feat_dim, txt_rnn_embeding_dim, hidden_size, dropout=0.2):
        super().__init__()
        self.image_feat_dim = image_feat_dim
        self.txt_embeding_dim = txt_rnn_embeding_dim
        self.fa_image = GatedTanh(image_feat_dim, hidden_size)
        self.fa_txt = GatedTanh(txt_rnn_embeding_dim, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.lc = nn.Linear(hidden_size, 1)

    def compute_raw_att(self, image_feat, question_embedding):
        num_location = image_feat.shape[1]
        image_fa = self.fa_image(image_feat)
        question_fa = self.fa_txt(question_embedding)
        question_fa_expand = torch.unsqueeze(question_fa, 1).expand(
            -1, num_location, -1
        )
        joint_feature = image_fa * question_fa_expand
        joint_feature = self.dropout(joint_feature)
        raw_attention = self.lc(joint_feature)
        return raw_attention

    def forward(self, image_feat, question_embedding):
        raw_attention = self.compute_raw_att(image_feat, question_embedding)
        # softmax across locations
        attention_weights = nn.functional.softmax(raw_attention, dim=1)
        attention_weights = attention_weights.expand_as(image_feat)
        return attention_weights
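
# Usage sketch (illustrative assumptions): unlike ConcatenationAttention, both
# modalities are first projected to a shared hidden size and fused by
# elementwise multiplication before scoring.
#
#   att = ProjectAttention(image_feat_dim=2048,
#                          txt_rnn_embeding_dim=1024, hidden_size=512)
#   w = att(torch.randn(4, 36, 2048), torch.randn(4, 1024))  # (4, 36, 2048)
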
class DoubleProjectAttention(nn.Module):
    def __init__(self, image_feat_dim, txt_rnn_embeding_dim, hidden_size, dropout=0.2):
        super().__init__()
        self.att1 = ProjectAttention(
            image_feat_dim, txt_rnn_embeding_dim, hidden_size, dropout
        )
        self.att2 = ProjectAttention(
            image_feat_dim, txt_rnn_embeding_dim, hidden_size, dropout
        )
        self.image_feat_dim = image_feat_dim
        self.txt_embeding_dim = txt_rnn_embeding_dim

    def forward(self, image_feat, question_embedding):
        att1 = self.att1.compute_raw_att(image_feat, question_embedding)
        att2 = self.att2.compute_raw_att(image_feat, question_embedding)
        raw_attn_weights = att1 + att2
        # softmax across locations
        attention_weights = nn.functional.softmax(raw_attn_weights, dim=1)
        attention_weights = attention_weights.expand_as(image_feat)
        return attention_weights
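
# Note: DoubleProjectAttention sums the raw (pre-softmax) scores of two
# independently parameterized ProjectAttention branches and then applies a
# single softmax, i.e. it ensembles logits rather than probabilities.
#
#   att = DoubleProjectAttention(2048, 1024, 512)
#   w = att(torch.randn(4, 36, 2048), torch.randn(4, 1024))  # (4, 36, 2048)
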
class TopDownAttention(nn.Module):
    EPS = 1.0e-08

    def __init__(self, combination_layer, transform_module, normalization):
        super().__init__()
        self.combination_layer = combination_layer
        self.normalization = normalization
        self.transform = transform_module
        self.out_dim = self.transform.out_dim

    @staticmethod
    def _mask_attentions(attention, image_locs):
        # Zero out attention on padded locations: image_locs[i] holds the
        # number of valid locations in sample i, so every position with
        # index >= image_locs[i] is masked to 0.
        batch_size, num_loc, n_att = attention.size()
        tmp1 = attention.new_zeros(num_loc)
        tmp1[:num_loc] = torch.arange(0, num_loc, dtype=attention.dtype).unsqueeze(
            dim=0
        )
        tmp1 = tmp1.expand(batch_size, num_loc)
        tmp2 = image_locs.type(tmp1.type())
        tmp2 = tmp2.unsqueeze(dim=1).expand(batch_size, num_loc)
        mask = torch.ge(tmp1, tmp2)
        mask = mask.unsqueeze(dim=2).expand_as(attention)
        attention = attention.masked_fill(mask, 0)
        return attention
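
    # Masking sketch (illustrative): for attention of shape (2, 4, 1) and
    # image_locs = tensor([2, 3]), sample 0 keeps locations 0-1 and sample 1
    # keeps locations 0-2; all later positions are zeroed.
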
    def forward(self, image_feat, question_embedding, image_locs=None):
        # N x K x joint_dim
        joint_feature = self.combination_layer(image_feat, question_embedding)
        # N x K x n_att
        raw_attn = self.transform(joint_feature)

        if self.normalization.lower() == "softmax":
            attention = nn.functional.softmax(raw_attn, dim=1)
            if image_locs is not None:
                masked_attention = self._mask_attentions(attention, image_locs)
                masked_attention_sum = torch.sum(masked_attention, dim=1, keepdim=True)
                # Guard against division by zero when every location is masked.
                masked_attention_sum += masked_attention_sum.eq(0).float() + self.EPS
                masked_attention = masked_attention / masked_attention_sum
            else:
                masked_attention = attention
        elif self.normalization.lower() == "sigmoid":
            attention = torch.sigmoid(raw_attn)
            masked_attention = attention
            if image_locs is not None:
                masked_attention = self._mask_attentions(attention, image_locs)
        else:
            # Previously an unknown normalization fell through to a NameError;
            # fail with an explicit message instead.
            raise ValueError(f"Unknown normalization: {self.normalization}")

        return masked_attention
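

# Minimal smoke test: a sketch that assumes mmf is installed; the batch size,
# location count, and feature dims below are illustrative, not from the repo.
if __name__ == "__main__":
    batch_size, num_locs = 4, 36
    image_dim, question_dim, hidden = 2048, 1024, 512

    image_feat = torch.randn(batch_size, num_locs, image_dim)
    question_emb = torch.randn(batch_size, question_dim)

    att = ProjectAttention(image_dim, question_dim, hidden)
    weights = att(image_feat, question_emb)
    assert weights.shape == (batch_size, num_locs, image_dim)

    # Weighted sum over locations gives one attended image vector per sample.
    attended = (weights * image_feat).sum(dim=1)
    print(attended.shape)  # torch.Size([4, 2048])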