https://github.com/facebookresearch/pythia
Tip revision: 7a67d7bccebf9c971b7cbe816c24b6fa660ce0a9 authored by Amanpreet Singh on 14 October 2021, 19:31:24 UTC
[fix,chores] Lint and wandb
modeling.py
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from torch import nn
logger = logging.getLogger(__name__)
ACT2FN = {
    "relu": nn.ReLU,
    "sigmoid": nn.Sigmoid,
    "tanh": nn.Tanh,
    "leaky_relu": nn.LeakyReLU,
}
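

# Illustrative usage of ACT2FN (a sketch, not part of the original file): look up an
# activation class from a config-provided string and instantiate it. The "hidden_act"
# name below is an assumption made only for this example.
def _act2fn_example(hidden_act="leaky_relu"):
    import torch

    # Look up the activation class by name and instantiate it.
    activation = ACT2FN[hidden_act]()
    # Apply it like any other nn.Module.
    return activation(torch.randn(4, 8))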
def get_bert_configured_parameters(module, lr=None, weight_decay=0.01):
    # module param can either be a nn.Module or in some cases can also be
    # a list of named parameters for a nn.Module
    if isinstance(module, nn.Module):
        param_optimizer = list(module.named_parameters())
    elif isinstance(module, list):
        param_optimizer = module

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": weight_decay,
        },
        {
            "params": [
                p for n, p in param_optimizer if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    if lr is not None:
        for p in optimizer_grouped_parameters:
            p["lr"] = lr

    return optimizer_grouped_parameters
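

# A minimal sketch (not part of the original file) of how the grouped parameters
# returned above are typically consumed: pass them to a torch optimizer such as
# AdamW, which applies each group's "weight_decay" (and "lr", when set).
def _bert_param_groups_example(model, lr=5e-5):
    from torch.optim import AdamW

    param_groups = get_bert_configured_parameters(model, lr=lr)
    # Groups without an explicit "lr" would fall back to the optimizer default below.
    return AdamW(param_groups, lr=lr)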
def get_optimizer_parameters_for_bert(module, config):
    lr = config.optimizer.params.lr
    model_config = config.model_config.get(config.model, {})
    finetune_lr_multiplier = model_config.get("finetune_lr_multiplier", 1)

    # For pretraining, or when finetune_lr_multiplier == 1, all modules will be
    # trained with the default lr.
    if module.config.training_head_type == "pretraining" or finetune_lr_multiplier == 1:
        return get_bert_configured_parameters(module)

    # For non-pretraining heads, where finetune_lr_multiplier != 1, all modules other
    # than the classifier will be trained with (lr * finetune_lr_multiplier).
    parameters = []
    for name, submodule in module.named_children():
        if name == "classifier":
            continue
        parameters += get_bert_configured_parameters(
            submodule, lr * finetune_lr_multiplier
        )
        logger.info(f"Overriding {name} module's LR to {lr * finetune_lr_multiplier}")

    # The classifier will be trained with the default lr.
    parameters += get_bert_configured_parameters(module.classifier)

    return parameters
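

# A minimal end-to-end sketch (not part of the original file) of how
# get_optimizer_parameters_for_bert expects its arguments to look. The toy module and
# config layout below only mirror the attribute accesses made above
# (config.optimizer.params.lr, config.model_config, module.config.training_head_type)
# and are otherwise assumptions made for illustration.
def _optimizer_parameters_example():
    from types import SimpleNamespace

    class ToyBertModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(16, 16)
            self.classifier = nn.Linear(16, 2)
            # A non-"pretraining" head type triggers the per-module LR override.
            self.config = SimpleNamespace(training_head_type="classification")

    config = SimpleNamespace(
        model="toy_bert",
        model_config={"toy_bert": {"finetune_lr_multiplier": 0.1}},
        optimizer=SimpleNamespace(params=SimpleNamespace(lr=5e-5)),
    )
    # Encoder groups get lr * 0.1; the classifier keeps the default lr.
    return get_optimizer_parameters_for_bert(ToyBertModel(), config)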