Shortcuts

Source code for ignite.handlers.early_stopping

from collections import OrderedDict
from collections.abc import Callable, Mapping
from typing import Any, cast, Literal
import warnings

from ignite.base import Serializable, ResettableHandler
from ignite.engine import Engine, Events
from ignite.utils import setup_logger

# Public API of this module: the only name exported by ``from ... import *``.
__all__ = ["EarlyStopping"]


class EarlyStopping(Serializable, ResettableHandler):
    """EarlyStopping handler can be used to stop the training if no improvement after a given number of events.

    Args:
        patience: Number of events to wait if no improvement and then stop the training.
        score_function: It should be a function taking a single argument, an :class:`~ignite.engine.engine.Engine`
            object, and return a score ``float``. An improvement is considered if the score is higher
            (for ``mode='max'``) or lower (for ``mode='min'``).
        trainer: Trainer engine to stop the run if no improvement.
        threshold: A minimum change in the score to qualify as an improvement. For ``mode='max'``, it is a
            minimum increase; for ``mode='min'``, it is a minimum decrease. An improvement is only considered if
            the change exceeds the threshold determined by ``threshold`` and ``threshold_mode``.
        cumulative: If True, ``threshold`` defines the change since the last ``patience`` reset, otherwise it
            defines the change after the last event. Default value is False.
        threshold_mode: Determines whether ``threshold`` is an absolute change or a relative change.

            - In ``'abs'`` mode:

              - For ``mode='max'``: improvement if ``score > best_score + threshold``
              - For ``mode='min'``: improvement if ``score < best_score - threshold``

            - In ``'rel'`` mode:

              - For ``mode='max'``: improvement if ``score > best_score + abs(best_score) * threshold``
              - For ``mode='min'``: improvement if ``score < best_score - abs(best_score) * threshold``

              For a non-negative ``best_score`` these are ``best_score * (1 + threshold)`` and
              ``best_score * (1 - threshold)``. Scaling by ``abs(best_score)`` keeps the required margin
              pointing in the "better" direction when scores are negative (e.g. ``-val_loss``).

            Possible values are ``"abs"`` and ``"rel"``. Default value is ``"abs"``.
        mode: Whether to maximize (``'max'``) or minimize (``'min'``) the score. Default is ``'max'``.
        min_delta: Deprecated alias of ``threshold``; emits a ``DeprecationWarning`` and overrides ``threshold``.
        min_delta_mode: Deprecated alias of ``threshold_mode``; emits a ``DeprecationWarning`` and overrides
            ``threshold_mode``.
        cumulative_delta: Deprecated alias of ``cumulative``; emits a ``DeprecationWarning`` and overrides
            ``cumulative``.

    Raises:
        TypeError: if ``score_function`` is not callable or ``trainer`` is not an ``Engine``.
        ValueError: if ``patience < 1``, ``threshold < 0``, or ``threshold_mode``/``mode`` is not an allowed value.

    Examples:
        .. code-block:: python

            from ignite.engine import Engine, Events
            from ignite.handlers import EarlyStopping

            def score_function(engine):
                val_loss = engine.state.metrics["nll"]
                return -val_loss

            handler = EarlyStopping(
                patience=10,
                score_function=score_function,
                trainer=trainer,
            )
            # Note: the handler is attached to an *Evaluator*
            evaluator.add_event_handler(Events.COMPLETED, handler)

    .. versionchanged:: 0.6.0
        Renamed ``min_delta_mode`` to ``threshold_mode``.
        Renamed ``min_delta`` to ``threshold``.
        Renamed ``cumulative_delta`` to ``cumulative``.
        Added the :meth:`get_default_score_fn` (static) and :meth:`get_default_event_filter` helpers.

    .. versionchanged:: 0.5.4
        Added `mode` parameter to support minimization in addition to maximization.
        Added `min_delta_mode` parameter to support both absolute and relative improvements.
    """

    # Only ``counter`` and ``best_score`` are mandatory when loading state. ``threshold_mode`` is optional so
    # that checkpoints written without it can still be loaded (``load_state_dict`` falls back to the current
    # value in that case).
    _state_dict_all_req_keys = (
        "counter",
        "best_score",
    )

    def __init__(
        self,
        patience: int,
        score_function: Callable,
        trainer: Engine,
        threshold: float = 0.0,
        cumulative: bool = False,
        threshold_mode: Literal["abs", "rel"] = "abs",
        mode: Literal["min", "max"] = "max",
        # Deprecated args for BC
        min_delta: float | None = None,
        min_delta_mode: Literal["abs", "rel"] | None = None,
        cumulative_delta: bool | None = None,
    ):
        if not callable(score_function):
            raise TypeError("Argument score_function should be a function.")

        if patience < 1:
            raise ValueError("Argument patience should be positive integer.")

        if not isinstance(trainer, Engine):
            raise TypeError("Argument trainer should be an instance of Engine.")

        # Backward compatibility for deprecated args: each one, when given, warns and overrides its new name.
        if min_delta is not None:
            warnings.warn(
                "'min_delta' is deprecated and will be removed in a future version. Please use 'threshold' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            threshold = min_delta

        if min_delta_mode is not None:
            warnings.warn(
                "'min_delta_mode' is deprecated and will be removed in a future version. "
                "Please use 'threshold_mode' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            threshold_mode = min_delta_mode

        if cumulative_delta is not None:
            warnings.warn(
                "'cumulative_delta' is deprecated and will be removed in a future version. "
                "Please use 'cumulative' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            cumulative = cumulative_delta

        # Validation happens after the aliases are resolved so deprecated values are checked too.
        if threshold < 0.0:
            raise ValueError("Argument threshold should not be a negative number.")

        if threshold_mode not in ("abs", "rel"):
            raise ValueError("Argument threshold_mode should be either 'abs' or 'rel'.")

        if mode not in ("min", "max"):
            raise ValueError("Argument mode should be either 'min' or 'max'.")

        self.score_function = score_function
        self.patience = patience
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.cumulative = cumulative
        self.trainer = trainer
        # Number of consecutive checks without improvement.
        self.counter = 0
        # Best score seen so far; ``None`` until the first check.
        self.best_score: float | None = None
        self.logger = setup_logger(__name__ + "." + self.__class__.__name__)
        self.mode = mode

    @property
    def min_delta(self) -> float:
        """Deprecated alias of :attr:`threshold`."""
        warnings.warn(
            "min_delta is deprecated and will be removed in a future version. Please use 'threshold' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.threshold

    @min_delta.setter
    def min_delta(self, value: float) -> None:
        warnings.warn(
            "min_delta is deprecated and will be removed in a future version. Please use 'threshold' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.threshold = value

    @property
    def min_delta_mode(self) -> str:
        """Deprecated alias of :attr:`threshold_mode`."""
        warnings.warn(
            "min_delta_mode is deprecated and will be removed in a future version. Please use 'threshold_mode' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.threshold_mode

    @min_delta_mode.setter
    def min_delta_mode(self, value: str) -> None:
        warnings.warn(
            "min_delta_mode is deprecated and will be removed in a future version. Please use 'threshold_mode' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.threshold_mode = value

    @property
    def cumulative_delta(self) -> bool:
        """Deprecated alias of :attr:`cumulative`."""
        warnings.warn(
            "cumulative_delta is deprecated and will be removed in a future version. Please use 'cumulative' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.cumulative

    @cumulative_delta.setter
    def cumulative_delta(self, value: bool) -> None:
        warnings.warn(
            "cumulative_delta is deprecated and will be removed in a future version. Please use 'cumulative' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.cumulative = value

    def __call__(self, engine: Engine) -> None:
        """Evaluate the score on ``engine`` and terminate ``self.trainer`` once patience is exhausted.

        The first call only records the score as ``best_score``. Afterwards, each call without a sufficient
        improvement increments ``counter``; an improvement resets it and updates ``best_score``.
        """
        score = self.score_function(engine)

        if self.best_score is None:
            self.best_score = score
            return

        best_score = self.best_score

        # Minimum margin a score must beat to count as an improvement. In relative mode the margin is scaled
        # by |best_score| so it always points in the "better" direction, even for negative scores such as
        # ``-val_loss`` (``best_score * (1 + threshold)`` would otherwise *lower* the bar in max mode).
        if self.threshold_mode == "abs":
            margin = self.threshold
        else:
            margin = abs(best_score) * self.threshold

        if self.mode == "max":
            no_improvement = score <= best_score + margin
        else:
            no_improvement = score >= best_score - margin

        if no_improvement:
            # Non-cumulative mode: the reference moves to the better of the two scores, so ``threshold``
            # measures change relative to the last event rather than to the last patience reset.
            if not self.cumulative:
                self.best_score = max(score, best_score) if self.mode == "max" else min(score, best_score)
            self.counter += 1
            self.logger.debug("EarlyStopping: %i / %i", self.counter, self.patience)
            if self.counter >= self.patience:
                self.logger.info("EarlyStopping: Stop training")
                self.trainer.terminate()
        else:
            self.best_score = score
            self.counter = 0

    def reset(self) -> None:
        """Reset the early stopping state, including the counter and best score.

        .. versionadded:: 0.5.4
        """
        self.counter = 0
        self.best_score = None

    def attach(  # type: ignore[override]
        self,
        engine: Engine,
        event: Any = Events.COMPLETED,
        reset_engine: Engine | None = None,
        reset_event: Any = Events.STARTED,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Attaches the early stopping handler to an engine and registers its reset callback.

        This method will:

        1. Add the early stopping evaluation logic (``self``) to ``engine`` on the given ``event``.
        2. Add the ``reset`` method to ``reset_engine`` (or ``engine`` if not provided) on the given
           ``reset_event``.

        Args:
            engine: The engine to attach the early stopping evaluation to (typically an evaluator).
            event: The event on ``engine`` that triggers the early stopping check.
                Default is :attr:`~ignite.engine.events.Events.COMPLETED`.
            reset_engine: The engine to attach the reset callback to (typically the trainer).
                If ``None``, defaults to ``engine``.
            reset_event: The event on ``reset_engine`` that triggers the handler state reset.
                Default is :attr:`~ignite.engine.events.Events.STARTED`.
            args: Unused; accepted for signature compatibility.
            kwargs: Unused; accepted for signature compatibility.

        .. versionadded:: 0.5.4
        """
        engine.add_event_handler(event, self)
        # Explicit ``None`` check instead of relying on the truthiness of an Engine object.
        target_reset_engine = reset_engine if reset_engine is not None else engine
        target_reset_engine.add_event_handler(reset_event, self.reset)

    def state_dict(self) -> "OrderedDict[str, Any]":
        """Method returns state dict with ``counter``, ``best_score`` and ``threshold_mode``.

        Can be used to save internal state of the class.
        """
        return OrderedDict(
            [
                ("counter", self.counter),
                ("best_score", cast(float, self.best_score)),
                ("threshold_mode", self.threshold_mode),
            ]
        )

    def load_state_dict(self, state_dict: Mapping) -> None:
        """Method replace internal state of the class with provided state dict data.

        Args:
            state_dict: a dict with "counter" and "best_score" keys/values and, optionally, "threshold_mode".
                When "threshold_mode" is absent, the current ``threshold_mode`` is kept.
        """
        # Validates the mapping type and the presence of ``_state_dict_all_req_keys``.
        super().load_state_dict(state_dict)
        self.counter = state_dict["counter"]
        self.best_score = state_dict["best_score"]
        self.threshold_mode = state_dict.get("threshold_mode", self.threshold_mode)

    @staticmethod
    def get_default_score_fn(metric_name: str, score_sign: float = 1.0) -> Callable:
        """Helper method to build a score function from an engine metric name.

        The returned callable reads ``engine.state.metrics[metric_name]`` and multiplies it by ``score_sign``.
        Use ``score_sign=-1.0`` for error-like metrics (smaller is better) when the handler is configured with
        ``mode="max"`` so that decreases in the metric register as score improvements.

        Args:
            metric_name: name of the metric in ``engine.state.metrics``.
            score_sign: ``1.0`` (default) or ``-1.0``. For error-like metrics where smaller is better,
                use ``-1.0``.

        Returns:
            A callable taking an :class:`~ignite.engine.engine.Engine` and returning ``float``.

        Raises:
            ValueError: if ``score_sign`` is neither ``1.0`` nor ``-1.0``.

        Examples:
            .. code-block:: python

                from ignite.handlers import EarlyStopping

                # Validation accuracy: larger is better, default mode="max"
                score_fn = EarlyStopping.get_default_score_fn("accuracy")
                handler = EarlyStopping(patience=5, score_function=score_fn, trainer=trainer)

                # Validation loss: smaller is better, flip the sign so larger is better
                neg_loss_fn = EarlyStopping.get_default_score_fn("loss", -1.0)
                handler = EarlyStopping(patience=5, score_function=neg_loss_fn, trainer=trainer)

        .. versionadded:: 0.6.0
        """
        if score_sign not in (1.0, -1.0):
            raise ValueError("Argument score_sign should be 1.0 or -1.0")

        def wrapper(engine: Engine) -> float:
            return score_sign * engine.state.metrics[metric_name]

        return wrapper

    def get_default_event_filter(self, after: int) -> Callable[[Engine, int], bool]:
        """Build an event filter that delays early-stopping checks until the trainer epoch exceeds ``after``.

        This implements a warmup window for early stopping without coupling the warmup logic to the handler
        itself, so it composes with any event the handler is attached to (epoch, iteration, custom).

        The filter consults the trainer's epoch counter (``self.trainer.state.epoch``) rather than the host
        engine's event count, so the warmup is well-defined even when the handler is attached to an evaluator
        that re-runs from scratch each epoch. The filter returns ``True`` only while
        ``self.trainer.state.epoch > after``.

        Args:
            after: number of initial trainer epochs during which the early-stopping handler is skipped.
                Must be a non-negative integer.

        Returns:
            A callable matching the ``event_filter`` signature ``(engine, event) -> bool``.

        Raises:
            ValueError: if ``after`` is not a non-negative integer.

        Examples:
            .. code-block:: python

                from ignite.engine import Events
                from ignite.handlers import EarlyStopping

                handler = EarlyStopping(patience=5, score_function=score_fn, trainer=trainer)

                # Skip the first 3 trainer epochs, then enforce early stopping.
                evaluator.add_event_handler(
                    Events.COMPLETED(event_filter=handler.get_default_event_filter(after=3)),
                    handler,
                )

        .. versionadded:: 0.6.0
        """
        if not isinstance(after, int) or after < 0:
            raise ValueError("Argument after should be a non-negative integer.")

        trainer = self.trainer

        def event_filter(engine: Engine, event: int) -> bool:
            # The host engine's `event` counter resets per `run()` (e.g. evaluators
            # called repeatedly), so look at the trainer's epoch counter instead.
            return trainer.state.epoch > after

        return event_filter

© Copyright 2026, PyTorch-Ignite Contributors. Last updated on 04/25/2026, 6:35:16 PM.

Built with Sphinx using a theme provided by Read the Docs.
×

Search Docs