Revision 95887524ce10a71147304d88aeb16da72b5b399f authored by Amanpreet Singh on 25 March 2021, 03:28:33 UTC, committed by Facebook GitHub Bot on 25 March 2021, 03:29:25 UTC
Summary:
Pull Request resolved: https://github.com/facebookresearch/mmf/pull/825

After this change, the dataset builder inherits from the PL datamodule, since at a high level the two concepts are the same. BaseDatasetBuilder has been adjusted to support datamodule functionality directly as well (see the illustrative sketch below).

Reviewed By: ytsheng, vedanuj

Differential Revision: D26505873

fbshipit-source-id: f3e1393f65bf25a2ec7161a5ef2e450c98a192e0
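
For illustration only, a minimal sketch of a dataset builder doubling as a PL datamodule, assuming pytorch_lightning is installed; the class name, constructor arguments, and dataset-construction logic below are assumptions and not MMF's actual BaseDatasetBuilder API:

# Hypothetical sketch -- only LightningDataModule, setup, and
# train_dataloader are real PL names; everything else is assumed.
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class ToyDatasetBuilder(pl.LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size
        self.train_dataset = None

    def setup(self, stage=None):
        # In MMF, the builder's dataset-construction logic would run here.
        self.train_dataset = TensorDataset(
            torch.randn(128, 4), torch.randint(0, 2, (128,))
        )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size)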
optimizers.py
# Copyright (c) Facebook, Inc. and its affiliates.
import math
from typing import Callable, Optional

import torch
from mmf.common.registry import registry
from transformers.optimization import AdamW


registry.register_optimizer("adam_w")(AdamW)


@registry.register_optimizer("adam_w_skip_params_with_zero_grad")
class AdamWSkipParamsWithZeroGrad(AdamW):
    def step(self, closure: Optional[Callable] = None):
        """
        Performs a single optimization step.
        Arguments:
            closure (:obj:`Callable`, `optional`): A closure that reevaluates the model
                and returns the loss.

        Modified from
        https://github.com/huggingface/transformers/blob/d2f9cb838ec1ed7f62ddfb850dccd223e19441ad/src/transformers/optimization.py#L259-L318  # NoQA
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
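                # Key difference from vanilla AdamW: skip parameters whose
                # gradient is exactly zero everywhere, leaving their weights
                # and moment estimates untouched.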
                if p.grad.abs().sum().item() == 0:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider "
                        "SparseAdam instead"
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if group["correct_bias"]:  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = (
                        step_size * math.sqrt(bias_correction2) / bias_correction1
                    )

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                if group["weight_decay"] > 0.0:
                    p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])

        return loss
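
For illustration only (not part of the file above), a minimal usage sketch of the optimizer; the toy model and hyperparameters are assumptions:

# Usage sketch (assumed, not from the MMF codebase): parameters whose
# gradient is all zeros keep their weights and moment estimates untouched.
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = AdamWSkipParamsWithZeroGrad(
    model.parameters(), lr=1e-3, weight_decay=0.01
)

loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()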