# modeling.py
# Source: https://github.com/facebookresearch/pythia (revision 7a67d7b)
# Copyright (c) Facebook, Inc. and its affiliates.

import logging

from torch import nn


logger = logging.getLogger(__name__)

ACT2FN = {
    "relu": nn.ReLU,
    "sigmoid": nn.Sigmoid,
    "tanh": nn.Tanh,
    "leaky_relu": nn.LeakyReLU,
}
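
# Example lookup (a sketch, not from the original file): the dict maps config
# strings to activation classes, so the result still needs to be instantiated:
#
#     activation = ACT2FN["leaky_relu"]()  # equivalent to nn.LeakyReLU()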


def get_bert_configured_parameters(module, lr=None, weight_decay=0.01):
    """Split parameters into BERT-style groups with and without weight decay.

    ``module`` can either be an nn.Module or a list of (name, parameter)
    tuples, i.e. the output of ``nn.Module.named_parameters()``.
    """
    if isinstance(module, nn.Module):
        param_optimizer = list(module.named_parameters())
    elif isinstance(module, list):
        param_optimizer = module
    else:
        raise TypeError(
            "module must be an nn.Module or a list of named parameters, "
            f"got {type(module)}"
        )
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": weight_decay,
        },
        {
            "params": [
                p for n, p in param_optimizer if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    if lr is not None:
        for group in optimizer_grouped_parameters:
            group["lr"] = lr

    return optimizer_grouped_parameters
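
# Usage sketch (an assumption, not part of the original file; `model` stands for
# any nn.Module): the groups built above can be handed straight to a torch
# optimizer, e.g.
#
#     torch.optim.AdamW(get_bert_configured_parameters(model, lr=5e-5))
#
# where per-group "lr" and "weight_decay" keys override the optimizer defaults.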


def get_optimizer_parameters_for_bert(module, config):
    """Build BERT-style parameter groups, scaling the LR of every submodule
    except the classifier by ``finetune_lr_multiplier`` when finetuning."""
    lr = config.optimizer.params.lr
    model_config = config.model_config.get(config.model, {})
    finetune_lr_multiplier = model_config.get("finetune_lr_multiplier", 1)

    # For pretraining, or when finetune_lr_multiplier == 1, all modules will be
    # trained with the default lr.
    if (
        module.config.training_head_type == "pretraining"
        or finetune_lr_multiplier == 1
    ):
        return get_bert_configured_parameters(module)

    # For non-pretraining heads, where finetune_lr_multiplier != 1, all modules
    # other than the classifier will be trained with (lr * finetune_lr_multiplier).
    parameters = []
    for name, submodule in module.named_children():
        if name == "classifier":
            continue
        parameters += get_bert_configured_parameters(
            submodule, lr * finetune_lr_multiplier
        )
        logger.info(f"Overriding {name} module's LR to {lr * finetune_lr_multiplier}")
    # Classifier will be trained with default lr.
    parameters += get_bert_configured_parameters(module.classifier)

    return parameters
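

if __name__ == "__main__":
    # Minimal sanity-check sketch, not part of the original MMF file. It fakes
    # only the config fields the helpers above actually read (via omegaconf,
    # which MMF already depends on) and uses a toy BERT-like model; names such
    # as ToyBertModel and "toy_bert" are hypothetical.
    import torch
    from omegaconf import OmegaConf

    class ToyBertModel(nn.Module):
        def __init__(self, training_head_type):
            super().__init__()
            # Only `config.training_head_type` is read by the helpers above.
            self.config = OmegaConf.create({"training_head_type": training_head_type})
            self.dense = nn.Linear(8, 8)
            self.LayerNorm = nn.LayerNorm(8)  # BERT-style attribute name
            self.classifier = nn.Linear(8, 2)

    config = OmegaConf.create(
        {
            "model": "toy_bert",
            "model_config": {"toy_bert": {"finetune_lr_multiplier": 0.1}},
            "optimizer": {"params": {"lr": 5e-5}},
        }
    )

    model = ToyBertModel("classification")
    groups = get_optimizer_parameters_for_bert(model, config)

    # Groups that carry their own "lr" (everything except the classifier here)
    # override the optimizer-level default passed below.
    optimizer = torch.optim.AdamW(groups, lr=config.optimizer.params.lr)
    for group in optimizer.param_groups:
        print(f"lr={group['lr']}, weight_decay={group['weight_decay']}")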