https://github.com/facebookresearch/pythia
Raw File
Tip revision: 74ad9d0a385532d77f0bb0452a130986606126d5 authored by ryan-qiyu-jiang on 01 December 2021, 21:23:41 UTC
[docs] Add pytorchvideo docs
Tip revision: 74ad9d0
hf_layers.py
# Copyright (c) Facebook, Inc. and its affiliates.

import math
from typing import List, Optional, Tuple

import torch
from mmf.utils.patch import restore_saved_modules, safecopy_modules
from torch import Tensor, nn
from transformers.modeling_bert import (
    BertAttention,
    BertEmbeddings,
    BertEncoder,
    BertLayer,
    BertModel,
    BertPooler,
    BertSelfAttention,
    BertSelfOutput,
)
from transformers.modeling_roberta import (
    RobertaAttention,
    RobertaEmbeddings,
    RobertaEncoder,
    RobertaLayer,
    RobertaModel,
    RobertaSelfAttention,
)
from transformers.modeling_utils import PreTrainedModel


patch_functions = [
    "BertEmbeddings.forward",
    "BertEncoder.forward",
    "BertLayer.forward",
    "BertAttention.forward",
    "BertSelfAttention.forward",
    "BertSelfAttention.transpose_for_scores",
    "BertModel.forward",
    "RobertaEmbeddings.forward",
    "RobertaEncoder.forward",
    "RobertaLayer.forward",
    "RobertaAttention.forward",
    "RobertaSelfAttention.forward",
    "RobertaSelfAttention.transpose_for_scores",
    "RobertaModel.forward",
]
patch_modules = [p_fun.split(".")[0] for p_fun in patch_functions]


def replace_with_jit():
    """
    Monkey patch some transformer functions to replace with scriptable ones.
    """
    # to revert monkey patch without reload()
    safecopy_modules(patch_functions, _get_modules_dict(patch_modules))

    BertEmbeddings.forward = BertEmbeddingsJit.forward
    BertEncoder.forward = BertEncoderJit.forward
    BertLayer.forward = BertLayerJit.forward
    BertAttention.forward = BertAttentionJit.forward
    BertSelfAttention.forward = BertSelfAttentionJit.forward
    BertSelfAttention.transpose_for_scores = BertSelfAttentionJit.transpose_for_scores
    BertModel.forward = BertModelJit.forward
    PreTrainedModel.__jit_unused_properties__ = [
        "base_model",
        "dummy_inputs",
        "device",
        "dtype",
    ]
    RobertaEmbeddings.forward = RobertaEmbeddingsJit.forward
    RobertaEncoder.forward = BertEncoderJit.forward
    RobertaLayer.forward = BertLayerJit.forward
    RobertaAttention.forward = BertAttentionJit.forward
    RobertaSelfAttention.forward = BertSelfAttentionJit.forward
    RobertaSelfAttention.transpose_for_scores = (
        BertSelfAttentionJit.transpose_for_scores
    )
    RobertaModel.forward = BertModelJit.forward


def undo_replace_with_jit():
    """
    Reload modules to undo monkey patch.
    """
    restore_saved_modules(_get_modules_dict(patch_modules))


def _get_modules_dict(modules_list):
    """
    Expects a list of str module names.
    Returns a dict of module_name: module obj,
    a subset of globals().
    """
    global_table = globals()
    return {module_name: global_table[module_name] for module_name in modules_list}


class BertEmbeddingsJit(BertEmbeddings):
    """
    Torchscriptable version of `BertEmbeddings` from Huggingface transformers v2.3.0
    https://github.com/huggingface/transformers/blob/v2.3.0/transformers/modeling_bert.py # noqa

    Modifies the `forward` function

    Changes to `forward` function ::
        Typed inputs and modified device to be input_ids.device by default
    """

    def forward(
        self,
        input_ids: Tensor,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]
        device = inputs_embeds.device if inputs_embeds is not None else input_ids.device
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(input_shape)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertSelfAttentionJit(BertSelfAttention):
    """
    Torchscriptable version of `BertSelfAttention` from Huggingface transformers v2.3.0
    https://github.com/huggingface/transformers/blob/v2.3.0/transformers/modeling_bert.py # noqa

    Modifies the `forward` function and `transpose_for_scores` function

    Changes to `transpose_for_scores` function ::
        Changes the `new_x_shape` unpacking as static size inference is not supported

    Changes to `forward` function ::
        Uses scriptable `nn.functional.softmax` and also removes several static size
        inference which is not supported.
    """

    def transpose_for_scores(self, x: Tensor) -> Tensor:
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Optional[Tensor] = None,
        head_mask: Optional[Tensor] = None,
        encoder_hidden_states: Optional[Tensor] = None,
        encoder_attention_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        if encoder_hidden_states is not None:
            mixed_key_layer = self.key(encoder_hidden_states)
            mixed_value_layer = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention
        # scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel
            # forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs)
        return outputs


class BertAttentionJit(BertAttention):
    """
    Torchscriptable version of `BertAttention` from Huggingface transformers v2.3.0
    https://github.com/huggingface/transformers/blob/v2.3.0/transformers/modeling_bert.py # noqa

    Modifies the `forward` function as well as uses scriptable `BertSelfAttentionJit`

    Changes to `forward` function ::
        Typed inputs and modifies the output to be a List[Tensor]
    """

    def __init__(self, config):
        super().__init__(config)
        self.self = BertSelfAttentionJit(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Optional[Tensor] = None,
        head_mask: Optional[Tensor] = None,
        encoder_hidden_states: Optional[Tensor] = None,
        encoder_attention_mask: Optional[Tensor] = None,
    ) -> List[Tensor]:
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[
            1:
        ]  # add attentions if we output them
        return outputs


class BertLayerJit(BertLayer):
    """
    Torchscriptable version of `BertLayer` from Huggingface transformers v2.3.0
    https://github.com/huggingface/transformers/blob/v2.3.0/transformers/modeling_bert.py # noqa

    Modifies the `forward` function as well as uses scriptable `BertAttentionJit`

    Changes to `forward` function::
        Typed inputs and modifies the output to be a List[Tensor]
    """

    def __init__(self, config):
        super().__init__(config)
        self.attention = BertAttentionJit(config)
        self.is_decoder = config.is_decoder
        if self.is_decoder:
            self.crossattention = BertAttentionJit(config)

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Optional[Tensor] = None,
        head_mask: Optional[Tensor] = None,
        encoder_hidden_states: Optional[Tensor] = None,
        encoder_attention_mask: Optional[Tensor] = None,
    ) -> List[Tensor]:
        self_attention_outputs = self.attention(
            hidden_states, attention_mask, head_mask
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[
            1:
        ]  # add self attentions if we output attention weights

        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        outputs = (layer_output,) + outputs
        return outputs


class BertEncoderJit(BertEncoder):
    """
    Torchscriptable version of `BertEncoder` from Huggingface transformers v2.3.0
    https://github.com/huggingface/transformers/blob/v2.3.0/transformers/modeling_bert.py # noqa

    Modifies the `forward` function as well as uses scriptable `BertLayerJit`

    Changes to `forward` function::
        Typed inputs and modifies the output to be of Tuple[Tensor] type in scripting
        mode. Due to different possible types when `output_hidden_states` or
        `output_attentions` are enable, we do not support these in scripting mode
    """

    def __init__(self, config):
        super().__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList(
            [BertLayerJit(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Optional[Tensor],
        encoder_hidden_states: Optional[Tensor] = None,
        encoder_attention_mask: Optional[Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = False,
        head_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor]:
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if not torch.jit.is_scripting() and output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                None,
                encoder_hidden_states,
                encoder_attention_mask,
            )
            hidden_states = layer_outputs[0]

            if not torch.jit.is_scripting() and output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if not torch.jit.is_scripting() and output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if not torch.jit.is_scripting():
            if output_hidden_states:
                outputs = outputs + (all_hidden_states,)
            if output_attentions:
                outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)


class BertModelJit(BertModel):
    """
    Torchscriptable version of `BertModel` from Huggingface transformers v2.3.0
    https://github.com/huggingface/transformers/blob/v2.3.0/transformers/modeling_bert.py # noqa

    Modifies the `forward` function

    Changes to `forward` function ::
        Typings for input, modifications to device, change output type to
        Tuple[Tensor, Tensor, List[Tensor]]
    """

    __jit_unused_properties__ = ["base_model", "dummy_inputs", "device", "dtype"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings = BertEmbeddingsJit(config)
        self.encoder = BertEncoderJit(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        head_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        encoder_hidden_states: Optional[Tensor] = None,
        encoder_attention_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, List[Tensor]]:
        """Forward pass on the Model.
        The model can behave as an encoder (with only self-attention) as well
        as a decoder, in which case a layer of cross-attention is added between
        the self-attention layers, following the architecture described in
        `Attention is all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar,
        Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and
        Illia Polosukhin.
        To behave as an decoder the model needs to be initialized with the
        `is_decoder` argument of the configuration set to `True`; an
        `encoder_hidden_states` is expected as an input to the forward pass.
        .. _`Attention is all you need`:
            https://arxiv.org/abs/1706.03762
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = inputs_embeds.device if inputs_embeds is not None else input_ids.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions
        # [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to
            # the padding mask
            # - if the model is an encoder, make the mask broadcastable to
            # [batch_size, num_heads, seq_length, seq_length]
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or "
                + f"attention_mask (shape {attention_mask.shape})"
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        # Python builtin next is currently not supported in Torchscript
        if not torch.jit.is_scripting():
            extended_attention_mask = extended_attention_mask.to(
                dtype=next(self.parameters()).dtype
            )  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # If a 2D ou 3D attention mask is provided for the cross-attention
        # we need to make broadcastabe to
        # [batch_size, num_heads, seq_length, seq_length]
        encoder_extended_attention_mask = None

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        # add hidden_states and attentions if they are here
        outputs = (sequence_output, pooled_output, encoder_outputs[1:])
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)


class RobertaEmbeddingsJit(RobertaEmbeddings):
    """
    Torchscriptable version of `RobertaEmbeddings` from Huggingface transformers v2.3.0
    https://github.com/huggingface/transformers/blob/v2.3.0/transformers/modeling_roberta.py # noqa

    Modifies the `forward` function

    Changes to `forward` function ::
        Typed inputs and modified device to be input_ids.device by default
    """

    def forward(
        self,
        input_ids: Tensor,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]
        device = inputs_embeds.device if inputs_embeds is not None else input_ids.device

        if position_ids is None:
            # Position numbers begin at padding_idx+1. Padding symbols are ignored.
            # cf. fairseq's `utils.make_positions`
            position_ids = torch.arange(
                self.padding_idx + 1,
                seq_length + self.padding_idx + 1,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).expand(input_shape)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
back to top