# early_stopping.py
# Source: https://github.com/facebookresearch/pythia
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
import torch
from mmf.utils.distributed import is_main, is_xla


class EarlyStopping:
    """
    Provides early stopping functionality. Keeps track of an early stop criteria,
    and if it doesn't improve over time restores last best performing
    parameters.
    """

    def __init__(
        self,
        model,
        checkpoint_instance,
        early_stop_criteria="total_loss",
        patience=1000,
        minimize=False,
        should_stop=True,
    ):
        self.minimize = minimize
        self.patience = patience
        self.model = model
        self.checkpoint = checkpoint_instance
        self.early_stop_criteria = early_stop_criteria

        if "val" not in self.early_stop_criteria:
            self.early_stop_criteria = f"val/{self.early_stop_criteria}"

        self.best_monitored_value = -np.inf if not minimize else np.inf
        self.best_monitored_iteration = 0
        self.best_monitored_update = 0
        self.should_stop = should_stop
        self.activated = False
        self.metric = self.early_stop_criteria

    def __call__(self, update, iteration, meter):
        """
        Method to be called everytime you need to check whether to
        early stop or not
        Arguments:
            update {number}: Current update number
            iteration {number}: Current iteration number
        Returns:
            bool -- Tells whether early stopping occurred or not
        """
        # There are operations involving synchronization downstream.
        # For XLA those calls must be executed from all cores,
        # so we do not return early here in the XLA case.
        if not is_main() and not is_xla():
            return False

        value = meter.meters.get(self.early_stop_criteria, None)
        if value is None:
            raise ValueError(
                f"Criteria used for early stopping ({self.early_stop_criteria}) "
                "is not present in meter."
            )

        value = value.global_avg

        if isinstance(value, torch.Tensor):
            value = value.item()

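        # The monitored metric improved: record the new best and checkpoint it.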
        if (self.minimize and value < self.best_monitored_value) or (
            not self.minimize and value > self.best_monitored_value
        ):
            self.best_monitored_value = value
            self.best_monitored_iteration = iteration
            self.best_monitored_update = update
            self.checkpoint.save(update, iteration, update_best=True)

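        # No improvement for more than `patience` updates: activate early stopping.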
        elif self.best_monitored_update + self.patience < update:
            self.activated = True
            if self.should_stop:
                self.checkpoint.restore()
                self.checkpoint.finalize()
                return True
            else:
                return False
        else:
            self.checkpoint.save(update, iteration, update_best=False)

        return False

    def is_activated(self):
        return self.activated

    def init_from_checkpoint(self, load):
        if "best_iteration" in load:
            self.best_monitored_iteration = load["best_iteration"]

        if "best_metric_value" in load:
            self.best_monitored_value = load["best_metric_value"]

    def get_info(self):
        return {
            "best_update": self.best_monitored_update,
            "best_iteration": self.best_monitored_iteration,
            f"best_{self.metric}": f"{self.best_monitored_value:.6f}",
        }
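

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# stub classes below are hypothetical stand-ins for MMF's real Meter and
# Checkpoint objects; only the pieces EarlyStopping actually touches are
# modeled: `meter.meters[criteria].global_avg` and the checkpoint's
# `save`/`restore`/`finalize` methods. Running it assumes `mmf` is importable
# (this module imports from it) and a non-distributed, non-XLA setting in
# which `is_main()` returns True.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _StubMetric:
        # EarlyStopping reads only `.global_avg` from a tracked metric.
        def __init__(self, value):
            self.global_avg = value

    class _StubMeter:
        # EarlyStopping reads only the `.meters` mapping.
        def __init__(self):
            self.meters = {}

    class _StubCheckpoint:
        # EarlyStopping calls save/restore/finalize; the stub just logs them.
        def save(self, update, iteration, update_best=False):
            print(f"save(update={update}, update_best={update_best})")

        def restore(self):
            print("restore() -> best parameters reloaded")

        def finalize(self):
            print("finalize()")

    meter = _StubMeter()
    stopper = EarlyStopping(
        model=None,  # the model is stored but never used by EarlyStopping itself
        checkpoint_instance=_StubCheckpoint(),
        early_stop_criteria="total_loss",  # normalized to "val/total_loss"
        patience=2,  # stop after 2 updates without improvement
        minimize=True,  # a loss should decrease
    )

    # Simulated validation losses: two improvements, then a plateau that
    # exhausts the patience budget on the fifth update.
    for update, loss in enumerate([1.0, 0.8, 0.9, 0.9, 0.95], start=1):
        meter.meters["val/total_loss"] = _StubMetric(loss)
        if stopper(update, iteration=update * 100, meter=meter):
            print(f"early stop triggered at update {update}")
            break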