Source code for torchmate.callbacks.early_stopper

import copy

import numpy as np

from torchmate.callbacks import Callback

class EarlyStopper(Callback):
    """
    Callback to implement early stopping based on a monitored metric.

    This callback stops the training early if the monitored metric does not improve for a specified number
    of epochs (patience).

    Parameters:
        monitor (str): Metric to monitor for early stopping. Defaults to ``"val_loss"``.
        patience (int): Number of epochs with no improvement to wait before stopping. Defaults to ``3``.
        mode (str): One of ``"min"`` or ``"max"``. The direction to monitor the metric. For example,
            ``"min"`` for loss, ``"max"`` for accuracy. Defaults to ``"min"``.
        min_delta (float): Minimum change in the monitored metric to qualify as an improvement. Defaults to ``0``.
        restore_best_state (bool): Whether to restore the best model state when early stopping. Defaults to ``False``.

    Raises:
        AttributeError: If ``mode`` is not ``"min"`` or ``"max"``.

    **Example Usage:**

    .. code-block:: python

        early_stopper = EarlyStopper(
            monitor="val_loss",
            patience=3,
            mode="min"
        )

    """

    def __init__(self, monitor="val_loss", patience=3, mode="min", min_delta=0, restore_best_state=False):
        super().__init__()
        # Validate first so no mode-dependent state is derived from an invalid mode.
        # AttributeError (rather than ValueError) is kept for backward compatibility.
        if mode not in ("min", "max"):
            raise AttributeError("The mode parameter should be set to 'min' or 'max'")
        self.monitor = monitor
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        self.restore_best_state = restore_best_state
        # Consecutive epochs without improvement.
        self.counter = 0
        # Number of times on_epoch_end has been called.
        self.epoch_count = 0
        # Start at the worst possible value so the first finite observation counts as an improvement.
        self.optimum_value = np.inf if mode == "min" else -np.inf
        # Snapshot of the model's state_dict at the best epoch seen so far (None until one improves).
        self.best_model_state = None

    def on_epoch_end(self, trainer) -> None:
        """Check for early stopping conditions after each epoch and if conditions are met stops training.

        Reads the current epoch's value of ``self.monitor`` from ``trainer.history``. On improvement it
        records a copy of ``trainer.model.state_dict()``. After ``patience`` consecutive epochs without
        improvement it sets ``trainer.early_stop = True``. If ``restore_best_state`` is enabled and a best
        state was recorded, it also loads that state back into ``trainer.model``.

        Args:
            trainer (Trainer): The trainer object. Must expose ``history`` (mapping of metric name to
                per-epoch values), ``model`` and a writable ``early_stop`` attribute.

        Returns:
            None
        """
        self.epoch_count += 1

        # NOTE(review): assumes trainer.history[monitor] already contains this epoch's value and was
        # empty when this callback started counting; confirm against the Trainer implementation.
        monitor_value = trainer.history[self.monitor][self.epoch_count - 1]

        # mode was validated in __init__, so it is either "min" or "max" here.
        if self.mode == "min":
            improved = (monitor_value + self.min_delta) < self.optimum_value
        else:
            improved = monitor_value > (self.optimum_value + self.min_delta)

        if improved:
            self.optimum_value = monitor_value
            self.counter = 0
            # state_dict() holds references to the live parameter/buffer tensors, so it must be copied.
            # Otherwise subsequent optimizer steps overwrite the "best" weights in place.
            self.best_model_state = copy.deepcopy(trainer.model.state_dict())
        else:
            self.counter += 1
            if self.counter >= self.patience:
                trainer.early_stop = True

                # Skip restoring if the metric never improved (e.g. it was NaN every epoch);
                # load_state_dict(None) would raise.
                if self.restore_best_state and self.best_model_state is not None:
                    trainer.model.load_state_dict(self.best_model_state)

        return None