import copy

import numpy as np

from torchmate.callbacks import Callback


class EarlyStopper(Callback):
    """
    Callback to implement early stopping based on a monitored metric.

    This callback stops the training early if the monitored metric does not improve for a specified number
    of consecutive epochs (patience).

    Parameters:
        monitor (str): Metric to monitor for early stopping. Defaults to ``"val_loss"``.
        patience (int): Number of consecutive epochs with no improvement to wait before stopping.
            Defaults to ``3``.
        mode (str): One of ``"min"`` or ``"max"``. The direction to monitor the metric. For example,
            ``"min"`` for loss, ``"max"`` for accuracy. Defaults to ``"min"``.
        min_delta (float): Minimum change in the monitored metric to qualify as an improvement.
            Defaults to ``0``.
        restore_best_state (bool): Whether to restore the best model state when early stopping.
            Defaults to ``False``.

    Raises:
        AttributeError: If ``mode`` is not ``"min"`` or ``"max"``.

    **Example Usage:**

    .. code-block:: python

        early_stopper = EarlyStopper(
            monitor="val_loss",
            patience=3,
            mode="min"
        )

    """

    def __init__(self, monitor="val_loss", patience=3, mode="min", min_delta=0, restore_best_state=False):
        super().__init__()
        # Validate before deriving any state (e.g. the initial optimum) from ``mode``.
        if mode not in ("min", "max"):
            raise AttributeError("The mode parameter should be set to 'min' or 'max'")
        self.monitor = monitor
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        self.restore_best_state = restore_best_state
        # Number of consecutive epochs without improvement.
        self.counter = 0
        # Number of times ``on_epoch_end`` has run; used to index the metric history.
        self.epoch_count = 0
        # Start from the worst possible value so the first observed metric always counts as an improvement.
        self.optimum_value = np.inf if mode == "min" else -np.inf
        # Snapshot of the model weights at the best epoch seen so far (None until the first improvement).
        self.best_model_state = None

    def on_epoch_end(self, trainer) -> None:
        """Check for early stopping conditions after each epoch and if conditions are met stops training.

        When the monitored metric improves by more than ``min_delta``, the optimum and a copy of the
        model's state dict are recorded and the patience counter is reset. Otherwise the counter is
        incremented, and once it reaches ``patience`` the trainer's ``early_stop`` flag is set (and,
        if ``restore_best_state`` is enabled, the best recorded weights are loaded back).

        Args:
            trainer (Trainer): The trainer object.

        Returns:
            None
        """
        self.epoch_count += 1

        # NOTE(review): assumes ``trainer.history[monitor]`` gains exactly one entry per epoch while
        # this callback is attached, so ``epoch_count - 1`` addresses the latest value — confirm.
        monitor_value = trainer.history[self.monitor][self.epoch_count - 1]

        # ``mode`` is validated in ``__init__``, so anything other than "min" is "max".
        if self.mode == "min":
            improved = (monitor_value + self.min_delta) < self.optimum_value
        else:
            improved = monitor_value > (self.optimum_value + self.min_delta)

        if improved:
            self.optimum_value = monitor_value
            self.counter = 0
            # For a torch ``nn.Module``, ``state_dict()`` holds references to the live parameter and
            # buffer tensors, which training keeps updating in place. Store an independent copy so the
            # best weights are actually preserved until they are restored.
            self.best_model_state = copy.deepcopy(trainer.model.state_dict())
            return None

        self.counter += 1
        if self.counter >= self.patience:
            trainer.early_stop = True

            # ``best_model_state`` is still None if the metric never improved (e.g. it was NaN on
            # every epoch); there is nothing to restore in that case.
            if self.restore_best_state and self.best_model_state is not None:
                trainer.model.load_state_dict(self.best_model_state)

        return None