Revision b752db6fd8fb9b68003effd545023cb0dd0a9fde authored by Vedanuj Goswami on 19 March 2021, 23:50:52 UTC, committed by Facebook GitHub Bot on 19 March 2021, 23:52:20 UTC
Summary: Pull Request resolved: https://github.com/facebookresearch/mmf/pull/816 Adding a check to final validation to run only if val is in run_type Reviewed By: ytsheng, apsdehal Differential Revision: D27130474 fbshipit-source-id: 2793a1dd3f77d4d56d737781a68973634d271157
1 parent 62a50f1
bottleneck.py
# Copyright (c) Facebook, Inc. and its affiliates.
from typing import Optional, Tuple, Type
import torch
import torch.nn as nn
from torchvision.models.resnet import conv1x1, conv3x3
from torchvision.ops.misc import FrozenBatchNorm2d
class ChannelPool(nn.Module):
    """Collapse the channel axis (dim 1) into a single channel by averaging."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the mean of ``x`` over dim 1, keeping that dim with size 1."""
        pooled = torch.mean(x, dim=1, keepdim=True)
        return pooled
class SEModule(nn.Module):
    """Squeeze-and-Excitation module from https://arxiv.org/pdf/1709.01507.pdf

    Two gates are applied in sequence:

    * ``self.se``: global average pool -> 1x1 conv (dim -> dim // sqrate)
      -> ReLU -> 1x1 conv (back to dim) -> sigmoid. The result has spatial
      size 1x1 and scales every channel of the input.
    * ``self.attn``: ``ChannelPool`` -> 7x7 conv (1 -> 1 channel,
      padding 3) -> sigmoid. The result has a single channel and scales
      every spatial position of the channel-gated features.

    Args:
        dim: the original hidden dim.
        sqrate: the squeeze rate in hidden dim. May be an int or a float;
            the squeezed width is ``int(dim // sqrate)``.

    Returns:
        New features map that channels are gated
        by sigmoid weights from SE module, then gated spatially by the
        attention branch.
    """

    def __init__(self, dim: int, sqrate: float):
        super().__init__()
        # ``dim // sqrate`` is a float whenever ``sqrate`` is a float (as the
        # annotation allows), but nn.Conv2d needs integer channel counts.
        # Compute the squeezed width once and cast it explicitly.
        squeezed_dim = int(dim // sqrate)

        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(dim, squeezed_dim, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed_dim, dim, kernel_size=1, bias=False),
            nn.Sigmoid(),
        )
        self.attn = nn.Sequential(
            ChannelPool(),
            nn.Conv2d(1, 1, kernel_size=7, padding=3, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the channel gate, then the spatial gate, to ``x``."""
        x = x * self.se(x)
        return x * self.attn(x)
class Modulation(nn.Module):
    """Modulate a feature map with a conditioning vector.

    The conditioning vector is projected by a linear layer to
    ``num_features`` values, multiplied channel-wise into the feature map,
    and the product is passed through a 1x1 convolution.

    Args:
        num_features: number of channels of the input feature map.
        num_cond_features: size of the conditioning vector.
        compressed: when True the 1x1 convolution outputs
            ``compressed_features`` channels; otherwise it keeps
            ``num_features`` channels.
        compressed_features: output channels of the 1x1 convolution in
            compressed mode. Defaults to 256, the previously hard-coded
            value; ignored when ``compressed`` is False.
    """

    def __init__(
        self,
        num_features: int,
        num_cond_features: int,
        compressed: bool = True,
        compressed_features: int = 256,
    ):
        super().__init__()
        self.linear = nn.Linear(num_cond_features, num_features)
        out_features = compressed_features if compressed else num_features
        self.conv = nn.Conv2d(num_features, out_features, kernel_size=1)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Return ``conv(x * linear(cond))``.

        Args:
            x: feature map with ``num_features`` channels in dim 1.
            cond: conditioning input whose last dim is ``num_cond_features``;
                presumably shaped (N, num_cond_features) so that the two
                trailing singleton dims added below broadcast over the
                spatial dims of ``x`` -- confirm against callers.
        """
        # Append two singleton dims so the projection broadcasts over the
        # trailing (spatial) dims of ``x``.
        cond = self.linear(cond).unsqueeze(2).unsqueeze(3)
        return self.conv(x * cond)
class MovieBottleneck(nn.Module):
    """
    Standard ResNet bottleneck with MoVie modulation in
    https://arxiv.org/abs/2004.11883
    The code is inspired from
    https://pytorch.org/docs/stable/_modules/torchvision/models/resnet.html

    The constructor only builds the plain bottleneck layers (conv1/bn1,
    conv2/bn2, conv3/bn3). The modulation layer (``self.cond``) and the
    squeeze-excitation layer (``self.se``) are created by
    :meth:`init_layers`, which must be called before running ``forward``
    with conditioning enabled.

    Args:
        inplanes: number of input channels.
        planes: base channel count; the block outputs
            ``planes * expansion`` channels.
        cond_planes: size of the conditioning vector. When falsy
            (``None`` or 0) no modulation is applied.
        stride: stride of the block.
        downsample: optional module applied to the input to produce the
            shortcut branch.
        groups: groups of the 3x3 convolution.
        base_width: width multiplier base; the inner width is
            ``int(planes * (base_width / 64.0)) * groups``.
        dilation: dilation of the 3x3 convolution.
        norm_layer: normalization layer class; defaults to
            ``FrozenBatchNorm2d``.
        stride_in_1x1: apply ``stride`` in the first 1x1 convolution
            instead of the 3x3 convolution.
        compressed: selects how the modulation is combined with the input
            (see :meth:`forward`).
        use_se: whether :meth:`init_layers` builds an ``SEModule``.
    """

    expansion = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        cond_planes: Optional[int] = None,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Type[nn.Module]] = None,
        stride_in_1x1: bool = False,
        compressed: bool = True,
        use_se: bool = True,
    ):
        super().__init__()
        self.norm_layer = FrozenBatchNorm2d if norm_layer is None else norm_layer
        self.cond_planes = cond_planes
        self.planes = planes
        self.inplanes = inplanes

        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
        self.width = int(planes * (base_width / 64.0)) * groups

        # The stride is applied by conv2 (or by conv1 when ``stride_in_1x1``
        # is set); ``downsample`` is expected to resize the shortcut to match.
        self.conv1 = conv1x1(inplanes, self.width, stride_1x1)
        self.bn1 = self.norm_layer(self.width)
        self.conv2 = conv3x3(self.width, self.width, stride_3x3, groups, dilation)
        self.bn2 = self.norm_layer(self.width)
        self.conv3 = conv1x1(self.width, planes * self.expansion)
        self.bn3 = self.norm_layer(self.planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

        # Placeholders for the layers built later by init_layers(). Defining
        # both here lets forward() detect a missing init_layers() call.
        self.cond = None
        self.se = None

        self.compressed = compressed
        self.use_se = use_se

    def init_layers(self):
        """Create the modulation layer (if conditioned) and the SE layer."""
        if self.cond_planes:
            self.cond = Modulation(
                self.inplanes, self.cond_planes, compressed=self.compressed
            )
        self.se = SEModule(self.planes * self.expansion, 4) if self.use_se else None

    def forward(
        self, x: torch.Tensor, cond: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the bottleneck, optionally modulated by ``cond``.

        Args:
            x: input feature map.
            cond: conditioning input passed to the modulation layer; only
                used when ``cond_planes`` is set.

        Returns:
            A tuple ``(out, cond)`` where ``out`` is the block output and
            ``cond`` is the conditioning input, returned unchanged.

        Raises:
            RuntimeError: if ``cond_planes`` is set but the modulation layer
                has not been created with :meth:`init_layers`.
        """
        identity = x

        if self.cond_planes:
            modulation = getattr(self, "cond", None)
            if modulation is None:
                raise RuntimeError(
                    "MovieBottleneck has cond_planes set but no modulation "
                    "layer; call init_layers() before forward()."
                )
            if self.compressed:
                # Modulation output is added to the conv1 output.
                x = self.conv1(x) + modulation(x, cond)
            else:
                # NOTE(review): this add is in place, so it also changes
                # ``identity`` (the same tensor object) and the caller's
                # input tensor. Kept as-is to preserve existing numerics;
                # confirm this is intended.
                x += modulation(x, cond)
                x = self.conv1(x)
        else:
            x = self.conv1(x)

        out = self.bn1(x)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            shortcut = self.downsample(identity)
        else:
            shortcut = identity

        if self.se is not None:
            out = self.se(out)

        out += shortcut
        out = self.relu(out)
        return out, cond
Computing file changes ...