# optimizers.py, from https://github.com/facebookresearch/pythia
# (revision 63c18a1, "[feat] FSDP support for MMF")
# Copyright (c) Facebook, Inc. and its affiliates.
import math
from typing import Callable
import torch
from mmf.common.registry import registry
from transformers.optimization import AdamW
registry.register_optimizer("adam_w")(AdamW)
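# Note (illustrative, not part of the original file): optimizers registered by
# name above are typically selected from an MMF config. The exact keys below are
# an assumption about the config layout, roughly:
#
#   optimizer:
#     type: adam_w_skip_params_with_zero_grad
#     params:
#       lr: 5.0e-5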
@registry.register_optimizer("adam_w_skip_params_with_zero_grad")
class AdamWSkipParamsWithZeroGrad(AdamW):
    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (:obj:`Callable`, `optional`): A closure that reevaluates
                the model and returns the loss.

        Modified from
        https://github.com/huggingface/transformers/blob/d2f9cb838ec1ed7f62ddfb850dccd223e19441ad/src/transformers/optimization.py#L259-L318 # NoQA
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Skip parameters whose gradient is exactly zero everywhere;
                # this is the behavior that distinguishes this optimizer from
                # plain AdamW.
                if p.grad.abs().sum().item() == 0:
                    continue

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider "
                        "SparseAdam instead"
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if group["correct_bias"]:  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = (
                        step_size * math.sqrt(bias_correction2) / bias_correction1
                    )

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Just adding the square of the weights to the loss function is
                # *not* the correct way of using L2 regularization/weight decay
                # with Adam, since that will interact with the m and v
                # parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't
                # interact with the m/v parameters. This is equivalent to adding
                # the square of the weights to the loss with plain (non-momentum)
                # SGD.
                # Add weight decay at the end (fixed version)
                if group["weight_decay"] > 0.0:
                    p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])

        return loss
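

# A minimal, self-contained sketch (not part of the original module) showing the
# skip-zero-grad behavior. It assumes only PyTorch and the class above: a
# parameter whose gradient is exactly zero is left untouched, while a parameter
# with a nonzero gradient receives a normal AdamW update.
if __name__ == "__main__":
    frozen = torch.nn.Parameter(torch.ones(3))
    active = torch.nn.Parameter(torch.ones(3))

    optimizer = AdamWSkipParamsWithZeroGrad(
        [frozen, active], lr=1e-2, weight_decay=0.0
    )

    # Simulate a backward pass: zero gradient for `frozen`, nonzero for `active`.
    frozen.grad = torch.zeros_like(frozen)
    active.grad = torch.ones_like(active)

    optimizer.step()

    # `frozen` keeps its original values; `active` moves against its gradient.
    assert torch.equal(frozen.data, torch.ones(3))
    assert not torch.equal(active.data, torch.ones(3))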