# Copyright (c) Facebook, Inc. and its affiliates.
import os
import pickle
import re
from collections import OrderedDict
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional

import torch
import torchvision
from mmf.common.registry import registry
from mmf.models.frcnn import GeneralizedRCNN
from mmf.modules.embeddings import ProjectionEmbedding, TextEmbedding
from mmf.modules.hf_layers import BertModelJit
from mmf.modules.layers import Identity
from mmf.utils.build import build_image_encoder, build_text_encoder
from mmf.utils.download import download_pretrained_model
from mmf.utils.file_io import PathManager
from mmf.utils.general import get_absolute_path
from mmf.utils.logger import log_class_usage
from omegaconf import MISSING, OmegaConf
from torch import Tensor, nn
from transformers.configuration_auto import AutoConfig
from transformers.modeling_auto import AutoModel


try:
    from detectron2.modeling import ShapeSpec, build_resnet_backbone
except ImportError:
    pass


class Encoder(nn.Module):
    @dataclass
    class Config:
        name: str = MISSING

    def __init__(self):
        super().__init__()
        log_class_usage("Encoder", self.__class__)

    @classmethod
    def from_params(cls, **kwargs):
        config = OmegaConf.structured(cls.Config(**kwargs))
        return cls(config)


class EncoderFactory(nn.Module):
    @dataclass
    class Config:
        type: str = MISSING
        params: Encoder.Config = MISSING


class ImageFeatureEncoderTypes(Enum):
    default = "default"
    identity = "identity"
    projection = "projection"
    frcnn_fc7 = "finetune_faster_rcnn_fpn_fc7"


class ImageFeatureEncoder(Encoder):
    @dataclass
    class Config(Encoder.Config):
        in_dim: int = MISSING


class ImageFeatureEncoderFactory(EncoderFactory):
    @dataclass
    class Config(EncoderFactory.Config):
        type: ImageFeatureEncoderTypes = MISSING
        params: ImageFeatureEncoder.Config = MISSING

    def __init__(self, config: Config, *args, **kwargs):
        super().__init__()
        encoder_type = config.type
        if isinstance(encoder_type, ImageFeatureEncoderTypes):
            encoder_type = encoder_type.value

        assert (
            "in_dim" in config.params
        ), "ImageFeatureEncoder requires 'in_dim' param in config"
        params = config.params

        if encoder_type == "default" or encoder_type == "identity":
            self.module = Identity()
            self.module.in_dim = params.in_dim
            self.module.out_dim = params.in_dim
        elif encoder_type == "projection":
            if "module" not in params:
                params = deepcopy(params)
                params.module = "linear"
            self.module = ProjectionEmbedding(**params)
        elif encoder_type == "finetune_faster_rcnn_fpn_fc7":
            self.module = FinetuneFasterRcnnFpnFc7(params)
        else:
            raise NotImplementedError("Unknown Image Encoder: %s" % encoder_type)

        self.out_dim = self.module.out_dim

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
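
# Config sketch (values are hypothetical): the factory is typically driven by
# an OmegaConf/YAML config such as
#
#     modal_encoder:
#       type: projection
#       params:
#         in_dim: 2048
#         out_dim: 512
#
# which this factory turns into a ProjectionEmbedding over the input features.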


@registry.register_encoder("finetune_faster_rcnn_fpn_fc7")
class FinetuneFasterRcnnFpnFc7(ImageFeatureEncoder):
    @dataclass
    class Config(ImageFeatureEncoder.Config):
        name: str = "finetune_faster_rcnn_fpn_fc7"
        in_dim: int = MISSING
        weights_file: str = "fc7_w.pkl"
        bias_file: str = "fc7_b.pkl"
        model_data_dir: str = MISSING

    def __init__(self, config: Config, *args, **kwargs):
        super().__init__()
        model_data_dir = get_absolute_path(config.model_data_dir)

        weights_file = config.weights_file
        bias_file = config.bias_file
        if not os.path.isabs(weights_file):
            weights_file = os.path.join(model_data_dir, weights_file)
        if not os.path.isabs(bias_file):
            bias_file = os.path.join(model_data_dir, bias_file)

        if not PathManager.exists(bias_file) or not PathManager.exists(weights_file):
            download_path = download_pretrained_model("detectron.vmb_weights")
            weights_file = get_absolute_path(os.path.join(download_path, "fc7_w.pkl"))
            bias_file = get_absolute_path(os.path.join(download_path, "fc7_b.pkl"))

        with PathManager.open(weights_file, "rb") as w:
            weights = pickle.load(w)
        with PathManager.open(bias_file, "rb") as b:
            bias = pickle.load(b)
        out_dim = bias.shape[0]

        self.lc = nn.Linear(config.in_dim, out_dim)
        self.lc.weight.data.copy_(torch.from_numpy(weights))
        self.lc.bias.data.copy_(torch.from_numpy(bias))
        self.out_dim = out_dim

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        old_prefix = prefix + "module."
        for k in list(state_dict.keys()):
            if k.startswith(old_prefix):
                new_k = k.replace(old_prefix, prefix)
                state_dict[new_k] = state_dict.pop(k)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, image):
        i2 = self.lc(image)
        i3 = nn.functional.relu(i2)
        return i3
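
# Input sketch (an illustration based on the shapes above, not an mmf
# contract): this encoder re-applies a Faster R-CNN fc7 layer to
# pre-extracted region features.
#
#     feats = torch.randn(batch_size, num_boxes, in_dim)  # e.g. in_dim = 2048
#     out = encoder(feats)  # (batch_size, num_boxes, out_dim), ReLU-activated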


@registry.register_encoder("identity")
class IdentityEncoder(Encoder):
    @dataclass
    class Config(Encoder.Config):
        name: str = "identity"
        # Arbitrary default in_dim used when none is specified
        in_dim: int = 100

    def __init__(self, config: Config):
        super().__init__()
        self.module = nn.Identity()
        self.in_dim = config.get("in_dim", 100)
        self.out_dim = self.in_dim

    def forward(self, x):
        return self.module(x)


class ImageEncoderTypes(Enum):
    default = "default"
    identity = "identity"
    torchvision_resnet = "torchvision_resnet"
    resnet152 = "resnet152"
    detectron2_resnet = "detectron2_resnet"
    frcnn = "frcnn"


class ImageEncoderFactory(EncoderFactory):
    @dataclass
    class Config(EncoderFactory.Config):
        type: ImageEncoderTypes = MISSING

    def __init__(self, config: Config, *args, **kwargs):
        super().__init__()
        self._type = config.type

        if isinstance(self._type, ImageEncoderTypes):
            self._type = self._type.value

        params = config.params

        if self._type == "default" or self._type == "identity":
            self.module = nn.Identity()
            self.module.out_dim = params.in_dim
        elif self._type == "resnet152":
            self.module = ResNet152ImageEncoder(params)
        elif self._type == "torchvision_resnet":
            self.module = TorchvisionResNetImageEncoder(params)
        elif self._type == "detectron2_resnet":
            self.module = Detectron2ResnetImageEncoder(params)
        elif self._type == "frcnn":
            self.module = FRCNNImageEncoder(params)
        else:
            raise NotImplementedError("Unknown Image Encoder: %s" % self._type)

    @property
    def out_dim(self):
        return self.module.out_dim

    def forward(self, image):
        return self.module(image)


# Taken from facebookresearch/mmbt with some modifications
@registry.register_encoder("resnet152")
class ResNet152ImageEncoder(Encoder):
    @dataclass
    class Config(Encoder.Config):
        name: str = "resnet152"
        pretrained: bool = True
        # "avg" or "adaptive"
        pool_type: str = "avg"
        num_output_features: int = 1

    def __init__(self, config: Config, *args, **kwargs):
        super().__init__()
        self.config = config
        model = torchvision.models.resnet152(pretrained=config.get("pretrained", True))
        modules = list(model.children())[:-2]
        self.model = nn.Sequential(*modules)

        pool_func = (
            nn.AdaptiveAvgPool2d if config.pool_type == "avg" else nn.AdaptiveMaxPool2d
        )

        # -1 will keep the original feature size
        if config.num_output_features == -1:
            self.pool = nn.Identity()
        elif config.num_output_features in [1, 2, 3, 5, 7]:
            self.pool = pool_func((config.num_output_features, 1))
        elif config.num_output_features == 4:
            self.pool = pool_func((2, 2))
        elif config.num_output_features == 6:
            self.pool = pool_func((3, 2))
        elif config.num_output_features == 8:
            self.pool = pool_func((4, 2))
        elif config.num_output_features == 9:
            self.pool = pool_func((3, 3))
        else:
            raise ValueError(
                f"Unsupported num_output_features: {config.num_output_features}"
            )

        self.out_dim = 2048

    def forward(self, x):
        # Bx3x224x224 -> Bx2048x7x7 -> Bx2048xN -> BxNx2048
        out = self.pool(self.model(x))
        out = torch.flatten(out, start_dim=2)
        out = out.transpose(1, 2).contiguous()
        return out  # BxNx2048
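
# Usage sketch (pretrained=False to avoid a weight download; illustrative):
#
#     config = OmegaConf.structured(
#         ResNet152ImageEncoder.Config(pretrained=False, num_output_features=1)
#     )
#     encoder = ResNet152ImageEncoder(config)
#     out = encoder(torch.randn(2, 3, 224, 224))  # shape: (2, 1, 2048)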


@registry.register_encoder("torchvision_resnet")
class TorchvisionResNetImageEncoder(Encoder):
    @dataclass
    class Config(Encoder.Config):
        name: str = "resnet50"
        pretrained: bool = False
        zero_init_residual: bool = True
        num_output_features: int = -1
        pool_type: str = "avg"

    def __init__(self, config: Config, *args, **kwargs):
        super().__init__()
        self.config = config

        model = getattr(torchvision.models, config.name)(
            pretrained=config.pretrained, zero_init_residual=config.zero_init_residual
        )

        # `use_avgpool` is a legacy flag; when present, it overrides
        # num_output_features and pool_type to maintain the old behavior
        self.use_avgpool = config.get("use_avgpool", None)
        if self.use_avgpool:  # use_avgpool is True
            config.num_output_features = 1
            config.pool_type = "avg"
        elif self.use_avgpool is False:
            config.num_output_features = -1

        if config.pretrained:
            model = self._load_pretrained(model, config)

        modules = list(model.children())[:-2]
        self.model = nn.Sequential(*modules)
        self.pool = self._pool_func(config)
        self.out_dim = config.get("out_dim", 2048)

    def _load_pretrained(self, model, config: Config):
        pretrained_model = config.get("pretrained_model", "supervised")
        if pretrained_model == "supervised":
            pass  # this is already loaded via torchvision using pretrained=True
        elif os.path.exists(pretrained_model):
            model.load_state_dict(torch.load(pretrained_model))
        else:
            try:
                with PathManager.open(pretrained_model, "rb") as f:
                    model.load_state_dict(
                        torch.load(f, map_location=lambda storage, loc: storage),
                        strict=False,
                    )
            except Exception as e:
                raise ValueError(
                    f"unknown pretrained ResNet model: {pretrained_model}"
                ) from e
        return model

    def _pool_func(self, config: Config):
        pool_func = (
            nn.AdaptiveAvgPool2d if config.pool_type == "avg" else nn.AdaptiveMaxPool2d
        )
        # -1 will keep the original feature size
        if config.num_output_features == -1:
            pool = nn.Identity()
        elif config.num_output_features in [1, 2, 3, 5, 7]:
            pool = pool_func((config.num_output_features, 1))
        elif config.num_output_features == 4:
            pool = pool_func((2, 2))
        elif config.num_output_features == 6:
            pool = pool_func((3, 2))
        elif config.num_output_features == 8:
            pool = pool_func((4, 2))
        elif config.num_output_features == 9:
            pool = pool_func((3, 3))
        else:
            raise ValueError(
                f"Unsupported num_output_features: {config.num_output_features}"
            )

        return pool

    def forward(self, x):
        # B x 3 x 224 x 224 -> B x out_dim x 7 x 7
        out = self.pool(self.model(x))
        if self.use_avgpool is None:
            out = torch.flatten(out, start_dim=2)
            out = out.transpose(1, 2).contiguous()  # BxNxout_dim
        else:
            out = torch.flatten(out, start_dim=1)  # BxN*out_dim
        return out


@registry.register_encoder("detectron2_resnet")
class Detectron2ResnetImageEncoder(Encoder):
    @dataclass
    class Config(Encoder.Config):
        name: str = "detectron2_resnet"
        pretrained: bool = True
        pretrained_path: Optional[str] = None

    def __init__(self, config: Config, *args, **kwargs):
        super().__init__()
        self.config = config
        pretrained = config.get("pretrained", False)
        pretrained_path = config.get("pretrained_path", None)

        self.resnet = build_resnet_backbone(config, ShapeSpec(channels=3))

        if pretrained:
            state_dict = torch.hub.load_state_dict_from_url(
                pretrained_path, progress=False
            )
            new_state_dict = OrderedDict()
            replace_layer = {"backbone.": ""}

            for key, value in state_dict["model"].items():
                new_key = re.sub(
                    r"(backbone\.)", lambda x: replace_layer[x.groups()[0]], key
                )
                new_state_dict[new_key] = value
            self.resnet.load_state_dict(new_state_dict, strict=False)

        self.out_dim = 2048

    def forward(self, x):
        x = self.resnet(x)
        return x["res5"]


@registry.register_encoder("frcnn")
class FRCNNImageEncoder(Encoder):
    @dataclass
    class Config(Encoder.Config):
        name: str = "frcnn"
        pretrained: bool = True
        pretrained_path: Optional[str] = None

    def __init__(self, config: Config, *args, **kwargs):
        super().__init__()
        self.config = config
        pretrained = config.get("pretrained", False)
        pretrained_path = config.get("pretrained_path", None)
        self.frcnn = GeneralizedRCNN(config)
        if pretrained:
            state_dict = torch.load(pretrained_path)
            self.frcnn.load_state_dict(state_dict)
            self.frcnn.eval()

    def forward(
        self,
        x: torch.Tensor,
        sizes: Optional[torch.Tensor] = None,
        scales_yx: Optional[torch.Tensor] = None,
        padding: Optional[torch.Tensor] = None,
        max_detections: int = 0,
        return_tensors: str = "pt",
    ):
        x = self.frcnn(
            x,
            sizes,
            scales_yx=scales_yx,
            padding=padding,
            max_detections=max_detections,
            return_tensors=return_tensors,
        )
        return x


class TextEncoderTypes(Enum):
    identity = "identity"
    transformer = "transformer"
    embedding = "embedding"


class TextEncoderFactory(EncoderFactory):
    @dataclass
    class Config(EncoderFactory.Config):
        # identity, transformer or embedding as of now
        type: TextEncoderTypes = MISSING
        params: Encoder.Config = MISSING

    def __init__(self, config: Config, *args, **kwargs):
        super().__init__()
        self._type = config.type
        if isinstance(self._type, TextEncoderTypes):
            self._type = self._type.value

        if self._type == "identity":
            self.module = nn.Identity()
        elif self._type == "transformer":
            self._module = TransformerEncoder(config.params)
            self.module = self._module.module
        elif self._type == "embedding":
            self.module = TextEmbeddingEncoder(config.params)
        else:
            raise NotImplementedError(f"Unknown Text Encoder {self._type}")

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


@registry.register_encoder("text_embedding")
class TextEmbeddingEncoder(Encoder):
    @dataclass
    class Config(Encoder.Config):
        name: str = "text_embedding"
        operator: str = MISSING
        # Keeping this Any for now as this
        # needs a separate refactor PR.
        embedding_params: Any = MISSING

    def __init__(self, config: Config):
        super().__init__()
        self._operator = config.operator
        self._embedding_params = config.embedding_params

        self.module = TextEmbedding(
            self._embedding_params.type, **self._embedding_params.params
        )

    def forward(self, x):
        x = self.module(x)
        if self._operator == "sum":
            x = x.sum(dim=1)
        elif self._operator == "concat":
            x = torch.cat(x, dim=1)
        elif self._operator == "mul":
            x = torch.prod(x, dim=1)

        return x.squeeze()
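
# Shape sketch (assumes the wrapped TextEmbedding returns a
# (batch, seq_len, dim) tensor, which depends on the embedding configured):
# "sum" reduces over the sequence to (batch, dim), "mul" takes the
# elementwise product over the sequence, and "concat" concatenates the
# embedding outputs along dim=1.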


@registry.register_encoder("transformer")
class TransformerEncoder(Encoder):
    @dataclass
    class Config(Encoder.Config):
        name: str = "transformer"
        num_segments: int = 2
        bert_model_name: str = "bert-base-uncased"
        # Options below can be overridden to update the BERT configuration used
        # to initialize the BERT encoder. If some option is missing, or if you
        # are using an encoder other than BERT, add extra parameters by
        # inheriting and extending this config. Those options will
        # automatically override the options of your transformer encoder's
        # configuration. For example, vocab_size is missing here; just add
        # vocab_size: x to update the size of the vocabulary with which the
        # encoder is initialized. If you update the default values, the
        # transformer you get will be initialized from scratch.
        hidden_size: int = 768
        num_hidden_layers: int = 12
        num_attention_heads: int = 12
        output_attentions: bool = False
        output_hidden_states: bool = False
        random_init: bool = False

    def __init__(self, config: Config, *args, **kwargs):
        super().__init__()
        self.config = config
        hf_params = {"config": self._build_encoder_config(config)}
        should_random_init = self.config.get("random_init", False)

        # For BERT models, initialize using Jit version
        if self.config.bert_model_name.startswith("bert-"):
            if should_random_init:
                self.module = BertModelJit(**hf_params)
            else:
                self.module = BertModelJit.from_pretrained(
                    self.config.bert_model_name, **hf_params
                )
        else:
            if should_random_init:
                self.module = AutoModel.from_config(**hf_params)
            else:
                self.module = AutoModel.from_pretrained(
                    self.config.bert_model_name, **hf_params
                )

        self.embeddings = self.module.embeddings
        self.original_config = self.config
        self.config = self.module.config
        self._init_segment_embeddings()

    def _init_segment_embeddings(self):
        if self.original_config.get("num_segments", None):
            num_segments = self.original_config.num_segments
            if hasattr(self.embeddings, "token_type_embeddings"):
                new_embeds = nn.Embedding(num_segments, self.config.hidden_size)
                new_embeds.weight.data[:2].copy_(
                    self.embeddings.token_type_embeddings.weight
                )
                # Initialize every extra segment row with the mean of the
                # pretrained token type embeddings
                for idx in range(2, num_segments):
                    new_embeds.weight.data[idx].copy_(
                        self.embeddings.token_type_embeddings.weight.data.mean(dim=0)
                    )
                self.embeddings.token_type_embeddings = new_embeds

    def _build_encoder_config(self, config: Config):
        return AutoConfig.from_pretrained(
            config.bert_model_name, **OmegaConf.to_container(config)
        )

    def forward(self, *args, return_sequence=False, **kwargs) -> Tensor:
        # Return the full sequence output if requested, else the pooled output
        output = self.module(*args, **kwargs)
        return output[0] if return_sequence else output[1]
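
# Usage sketch (hypothetical values; the token ids below correspond to
# "[CLS] hello world [SEP]" under the standard bert-base-uncased vocabulary):
#
#     encoder = TransformerEncoder.from_params(bert_model_name="bert-base-uncased")
#     input_ids = torch.tensor([[101, 7592, 2088, 102]])
#     pooled = encoder(input_ids)                          # (1, 768)
#     sequence = encoder(input_ids, return_sequence=True)  # (1, 4, 768)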


class MultiModalEncoderBase(Encoder):
    __jit_unused_properties__ = ["encoder_config"]

    @dataclass
    class Config(Encoder.Config):
        # This actually is Union[ImageEncoderConfig, ImageFeatureEncoderConfig]
        modal_encoder: EncoderFactory.Config = ImageEncoderFactory.Config(
            type=ImageEncoderTypes.resnet152, params=ResNet152ImageEncoder.Config()
        )
        text_encoder: EncoderFactory.Config = TextEncoderFactory.Config(
            type=TextEncoderTypes.transformer, params=TransformerEncoder.Config()
        )
        direct_features_input: bool = False
        modal_hidden_size: int = 2048
        text_hidden_size: int = 768

    def __init__(self, config: Config, *args, **kwargs):
        super().__init__()
        self.config = config

        self._modal_encoder_config = self.config.get("modal_encoder", None)

        self._is_direct_features_input = self.config.get("direct_features_input", False)

        self.build()
        self.modal_hidden_size = self.config.get("modal_hidden_size", None)
        self.text_hidden_size = self.config.get("text_hidden_size", None)

    def build(self):
        encoders = self._build_encoders(self.config)
        self.text_encoder, self.modal_encoder = encoders[0], encoders[1]

        self._encoder_config = None
        if self.text_encoder:
            self._encoder_config = self.text_encoder.config

    @property
    def encoder_config(self):
        return self._encoder_config

    def _build_encoders(self, config):
        text_encoder = None
        if config.get("text_encoder", None):
            text_encoder = build_text_encoder(config.text_encoder)

        modal_encoder = None
        if config.get("modal_encoder", None):
            modal_encoder = self._build_modal_encoder(config.modal_encoder)

        return (text_encoder, modal_encoder)

    def _build_modal_encoder(self, config):
        return build_image_encoder(
            config, direct_features=self._is_direct_features_input
        )
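
# Note: `build_text_encoder` / `build_image_encoder` dispatch through the mmf
# registry, so any encoder registered above (or in user code) can be selected
# purely from config via these factory hooks.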


class PooledEncoder(Encoder):
    """
    Standard pooled encoder class which takes in an input, encodes it with an encoder
    implemented and returned from `self.build_encoder` function, pools it based
    `pool_type` and `num_output_features` specified, flattens it and returns it
    back as a tensor.
    """

    @dataclass
    class Config(Encoder.Config):
        num_output_features: int = 1  # How many output features need to be returned.
        pool_type: str = "avg"  # type of pooling to apply "avg" | "adaptive"
        out_dim: int = MISSING  # size of out dim expected
        three_d: bool = False  # if input requires 3D pooling (for video)

    def __init__(self, config: Config, *args, **kwargs):
        super().__init__()
        self.encoder = self.build_encoder(config)
        pool_func = (
            nn.AdaptiveAvgPool2d if config.pool_type == "avg" else nn.AdaptiveMaxPool2d
        )
        params = (config.num_output_features, 1)
        if config.three_d:
            pool_func = (
                nn.AdaptiveAvgPool3d
                if config.pool_type == "avg"
                else nn.AdaptiveMaxPool3d
            )
            params = (config.num_output_features, 1, 1)
        # -1 will keep the original feature size
        if config.num_output_features == -1:
            self.pool = nn.Identity()
        else:
            self.pool = pool_func(params)
        self.out_dim = config.out_dim

    def build_encoder(self, config: Config, *args, **kwargs):
        """Build an encoder on whose output the pooling will be applied.

        Args:
            config (Config): Config parameter required to build the encoder.

        Raises:
            NotImplementedError: Not implemented by default.
        """
        raise NotImplementedError()

    def forward(self, x: Tensor) -> Tensor:
        out = self.encoder(x)
        out = self.pool(out)
        out = torch.flatten(out, start_dim=2)
        out = out.transpose(1, 2).contiguous()
        return out
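
# Subclass sketch (hypothetical encoder, mirroring the registered encoders
# below): only `build_encoder` needs to be implemented.
#
#     @registry.register_encoder("resnet34_image")
#     class ResNet34ImageEncoder(PooledEncoder):
#         @dataclass
#         class Config(PooledEncoder.Config):
#             name: str = "resnet34_image"
#             out_dim: int = 512
#
#         def build_encoder(self, config, *args, **kwargs):
#             model = torchvision.models.resnet34(pretrained=False)
#             return nn.Sequential(*list(model.children())[:-2])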


@registry.register_encoder("r2plus1d_18")
class R2Plus1D18VideoEncoder(PooledEncoder):
    """
    R2Plus1D based video encoder. Returns back a tensor of dim 2048.
    By default, pretrained version is used.
    See https://arxiv.org/abs/1711.11248.
    """

    @dataclass
    class Config(PooledEncoder.Config):
        name: str = "r2plus1d_18"
        out_dim: int = 512  # out dim
        pretrained: bool = True  # if should use pretrained version or not
        three_d: bool = True

    def build_encoder(self, config: Config, *args, **kwargs):
        model = torchvision.models.video.r2plus1d_18(
            pretrained=config.get("pretrained", True)
        )
        modules = list(model.children())[:-2]
        return nn.Sequential(*modules)


@registry.register_encoder("resnet18_audio")
class ResNet18AudioEncoder(PooledEncoder):
    """
    Audio encoder based on ResNet18 used in various audio classification paper
    as a baseline. By default, not pretrained version is used.
    """

    @dataclass
    class Config(PooledEncoder.Config):
        name: str = "resnet18_audio"
        out_dim: int = 512
        pretrained: bool = False

    def build_encoder(self, config: Config, *args, **kwargs):
        model = torchvision.models.resnet18(pretrained=config.get("pretrained", False))
        model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        modules = list(model.children())[:-2]
        return nn.Sequential(*modules)
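
# Usage sketch (illustrative): the replaced conv1 lets the network consume
# single-channel inputs such as log-mel spectrograms treated as 1xHxW images.
#
#     config = OmegaConf.structured(ResNet18AudioEncoder.Config())
#     encoder = ResNet18AudioEncoder(config)
#     spec = torch.randn(4, 1, 128, 256)  # (batch, 1, mel_bins, frames)
#     out = encoder(spec)                 # (4, 1, 512) with num_output_features=1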