Source code for torchmate.callbacks.model_checkpoint

  1import os
  2
  3import numpy as np
  4import torch
  5import wandb
  6
  7from torchmate.callbacks import Callback
  8
  9
class ModelCheckpoint(Callback):
    """
    Callback to save model checkpoints.

    This callback saves model checkpoints at specified intervals and under certain conditions such as best
    validation loss or metrics. All files are written to a ``ckpts`` subdirectory of ``checkpoint_dir``,
    which is created on demand.

    Parameters:
        checkpoint_dir (str): Directory to save the checkpoints.
        monitor (str, optional): Metric to monitor for saving the best model. Defaults to ``"val_loss"``.
        mode (str, optional): One of ``"min"`` or ``"max"``. The direction to monitor the metric.
            For example, ``"min"`` for loss, ``"max"`` for accuracy. Defaults to ``"min"``.
        min_delta (float, optional): Minimum change in the monitored metric to qualify as an improvement.
            Defaults to ``0.0``.
        save_frequency (int, optional): Frequency of saving checkpoints (epochs). Only used when
            ``save_best_only`` is ``False``. Defaults to ``1``.
        save_best_only (bool, optional): Whether to save only the best model based on the monitored metric.
            The best model is kept in a single ``model_best.pt`` file that is overwritten on each
            improvement. Defaults to ``True``.
        save_last (bool, optional): Whether to save the model checkpoint for the last epoch
            (``model_last.pt``). Defaults to ``True``.
        save_state_dict_only (bool, optional): Whether to save only the model's state dictionary instead of
            the whole model. Defaults to ``True``.

            - If ``True``:
                * Saves a dictionary containing `epoch`, `model_state_dict`, and `optimizer_state_dict`.
                * Requires manually creating model and optimizer instances before loading the dictionary.
                * Only the model's weights and biases are loaded from `model_state_dict`.
            - If ``False``:
                * Saves the entire model object.

    Raises:
        ValueError: If ``mode`` is not ``"min"`` or ``"max"``.

    **Example Usage:**

    .. code-block:: python

        checkpoint_callback = ModelCheckpoint(
            checkpoint_dir="checkpoints",
            monitor="val_loss",
            mode="min",
            save_best_only=False,
            save_frequency=2,
        )
        callbacks = [checkpoint_callback]

    """

    def __init__(
        self,
        checkpoint_dir,
        monitor="val_loss",
        mode="min",
        min_delta=0.0,
        save_frequency=1,
        save_best_only=True,
        save_state_dict_only=True,
        save_last=True,
    ):
        super().__init__()
        # Fail fast: an unknown mode previously surfaced only later as an
        # UnboundLocalError inside on_epoch_end.
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")

        self.checkpoint_dir = checkpoint_dir
        self.monitor = monitor
        self.mode = mode
        self.min_delta = min_delta
        self.save_frequency = save_frequency
        self.save_best_only = save_best_only
        self.save_state_dict_only = save_state_dict_only
        self.save_last = save_last
        # Number of on_epoch_end calls seen so far (1-based after the first call).
        self.epoch_count = 0
        # Best monitored value so far; starts at the worst possible value for the mode.
        self.optimum_value = np.inf if mode == "min" else -np.inf
        # Path of the most recently written best checkpoint, or None if none saved yet.
        self.best_checkpoint_path = None

    def _build_checkpoint(self, trainer):
        """Return the object to pass to ``torch.save`` for the current epoch."""
        if self.save_state_dict_only:
            return {
                "epoch": self.epoch_count,
                "model_state_dict": trainer.model.state_dict(),
                "optimizer_state_dict": trainer.optimizer.state_dict(),
            }
        return trainer.model

    def _is_improvement(self, value):
        """Return True if ``value`` beats ``self.optimum_value`` by more than ``min_delta`` in ``mode``'s direction."""
        if self.mode == "min":
            return (value + self.min_delta) < self.optimum_value
        return value > (self.optimum_value + self.min_delta)

    def on_epoch_end(self, trainer):
        """Save checkpoints according to the configured policy at the end of an epoch.

        Args:
            trainer: The object driving training. This method reads ``trainer.model``,
                ``trainer.optimizer`` (when ``save_state_dict_only``), ``trainer.history[self.monitor]``
                and ``trainer.num_epochs``.
        """
        self.epoch_count += 1

        checkpoint_dir = os.path.join(self.checkpoint_dir, "ckpts")
        # exist_ok avoids the race between an existence check and directory creation.
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint = self._build_checkpoint(trainer)

        # assumes trainer.history[monitor] holds one value per completed epoch — TODO confirm
        monitor_value = trainer.history[self.monitor][self.epoch_count - 1]

        improved = self._is_improvement(monitor_value)
        if improved:
            self.optimum_value = monitor_value

        if self.save_best_only:
            if improved:
                # The best checkpoint always uses the same file name, so torch.save simply
                # overwrites it. Deleting the old file first would lose the previous best
                # if this save failed, and would crash if the file had been removed externally.
                checkpoint_path = os.path.join(checkpoint_dir, "model_best.pt")
                torch.save(checkpoint, checkpoint_path)
                self.best_checkpoint_path = checkpoint_path
        elif self.epoch_count % self.save_frequency == 0:
            checkpoint_name = f"model_epoch_{self.epoch_count}_{self.monitor}_{monitor_value:.5f}.pt"
            checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
            torch.save(checkpoint, checkpoint_path)

        if self.save_last and self.epoch_count == trainer.num_epochs:
            filepath = os.path.join(checkpoint_dir, "model_last.pt")
            torch.save(checkpoint, filepath)