# Copyright (c) Facebook, Inc. and its affiliates.

from torch import nn

# Map activation names used in configs to the corresponding nn.Module classes.
ACT2FN = {
    "relu": nn.ReLU,
    "sigmoid": nn.Sigmoid,
    "tanh": nn.Tanh,
    "leaky_relu": nn.LeakyReLU,
}
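

# A minimal usage sketch (hypothetical helper, not part of the original file):
# resolve an activation module instance from its config name.
def get_activation(name):
    if name not in ACT2FN:
        raise KeyError(f"unknown activation '{name}', expected one of {list(ACT2FN)}")
    return ACT2FN[name]()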


def get_bert_configured_parameters(module, lr=None):
    """Split a module's parameters into two optimizer groups, following the
    standard BERT recipe: weight decay of 0.01 for most weights and no decay
    for biases and LayerNorm parameters. An optional ``lr`` is applied to
    both groups."""
    param_optimizer = list(module.named_parameters())

    # Parameters whose names contain any of these substrings get no weight decay.
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.01,
        },
        {
            "params": [
                p for n, p in param_optimizer if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    if lr is not None:
        for group in optimizer_grouped_parameters:
            group["lr"] = lr

    return optimizer_grouped_parameters
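

# Usage sketch (a hypothetical example, not part of the original file): the
# returned groups plug directly into any torch optimizer that accepts
# parameter groups; per-group lr and weight_decay are honored.
def _example_build_optimizer(module, lr=5e-5):
    from torch.optim import AdamW

    groups = get_bert_configured_parameters(module, lr)
    return AdamW(groups)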


def get_optimizer_parameters_for_bert(module, config):
    """Build optimizer parameter groups for a BERT-based model, optionally
    using a separate (scaled) learning rate for the pretrained BERT part."""
    # Pretraining uses the same LR for all parts of the model
    if module.config.training_head_type == "pretraining":
        return get_bert_configured_parameters(module)

    # For the finetuning setup, there is an additional classifier head
    lr = config.optimizer.params.lr
    model_config = getattr(config.model_config, config.model, {})
    finetune_lr_multiplier = getattr(model_config, "finetune_lr_multiplier", 1)
    # Train the pretrained BERT part with its LR scaled by finetune_lr_multiplier, if set
    parameters = get_bert_configured_parameters(
        module.bert, lr * finetune_lr_multiplier
    )
    # The classifier head is trained at the base lr
    parameters += get_bert_configured_parameters(module.classifier, lr)

    return parameters
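

# Smoke test (a minimal sketch; the toy module and SimpleNamespace config
# below stand in for a real MMF model and its OmegaConf config, which are
# assumptions here, as are the names ToyBertModel and "toy_bert").
if __name__ == "__main__":
    from types import SimpleNamespace

    class ToyBertModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.config = SimpleNamespace(training_head_type="classification")
            self.bert = nn.Linear(8, 8)
            self.classifier = nn.Linear(8, 2)

    config = SimpleNamespace(
        optimizer=SimpleNamespace(params=SimpleNamespace(lr=5e-5)),
        model_config=SimpleNamespace(),
        model="toy_bert",
    )
    groups = get_optimizer_parameters_for_bert(ToyBertModel(), config)
    # Expect 4 groups: (decay, no-decay) for each of the bert and classifier parts.
    print(f"{len(groups)} parameter groups, base lr={config.optimizer.params.lr}")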