1import os
2
3import numpy as np
4import torch
5import wandb
6
7from torchmate.callbacks import Callback
8
9
[docs]
class ModelCheckpoint(Callback):
    """
    Callback to save model checkpoints.

    This callback saves model checkpoints at specified intervals and under certain conditions such as best
    validation loss or metrics. All files are written to ``<checkpoint_dir>/ckpts``:

    - ``model_best.pt``: the best model so far (only when ``save_best_only`` is ``True``).
    - ``model_epoch_<epoch>_<monitor>_<value>.pt``: periodic checkpoints (only when ``save_best_only`` is ``False``).
    - ``model_last.pt``: the model at the final epoch (only when ``save_last`` is ``True``).

    Parameters:
        checkpoint_dir (str): Directory to save the checkpoints.
        monitor (str, optional): Metric to monitor for saving the best model. It is used as a key into
            ``trainer.history``. Defaults to ``"val_loss"``.
        mode (str, optional): One of ``"min"`` or ``"max"``. The direction to monitor the metric.
            For example, ``"min"`` for loss, ``"max"`` for accuracy. Defaults to ``"min"``.
        min_delta (float, optional): Minimum change in the monitored metric to qualify as an improvement.
            Defaults to ``0.0``.
        save_frequency (int, optional): Frequency of saving checkpoints (epochs). Only used when
            ``save_best_only`` is ``False``. Defaults to ``1``.
        save_best_only (bool, optional): Whether to save only the best model based on the monitored metric.
            Defaults to ``True``.
        save_state_dict_only (bool, optional): Whether to save only the model's state dictionary instead of
            the whole model. Defaults to ``True``.

            - If ``True``:
                * Saves a dictionary containing `epoch`, `model_state_dict`, and `optimizer_state_dict`.
                * Requires manually creating model and optimizer instances before loading the dictionary.
                * Only the model's weights and biases are loaded from `model_state_dict`.
            - If ``False``:
                * Saves the entire model object.
        save_last (bool, optional): Whether to save the model checkpoint for the last epoch, i.e. when the
            number of completed epochs equals ``trainer.num_epochs``. Defaults to ``True``.

    Raises:
        ValueError: If ``mode`` is not ``"min"`` or ``"max"``, or if ``save_frequency`` is smaller than ``1``.

    **Example Usage:**

    .. code-block:: python

        checkpoint_callback = ModelCheckpoint(
            checkpoint_dir="checkpoints",
            monitor="val_loss",
            mode="min",
            save_best_only=False,
            save_frequency=2,
        )
        callbacks = [checkpoint_callback]

    """

    def __init__(
        self,
        checkpoint_dir,
        monitor="val_loss",
        mode="min",
        min_delta=0.0,
        save_frequency=1,
        save_best_only=True,
        save_state_dict_only=True,
        save_last=True,
    ):
        # Fail fast on a bad configuration instead of crashing at the end of the first epoch.
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        if save_frequency < 1:
            raise ValueError(f"save_frequency must be a positive integer, got {save_frequency!r}")

        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        self.monitor = monitor
        self.mode = mode
        self.min_delta = min_delta
        self.save_frequency = save_frequency
        self.save_best_only = save_best_only
        self.save_state_dict_only = save_state_dict_only
        self.save_last = save_last
        # Number of epochs this callback has seen so far.
        self.epoch_count = 0
        # Best monitored value seen so far; starts at the worst possible value for the chosen mode.
        self.optimum_value = np.inf if mode == "min" else -np.inf
        # Path of the most recently written best checkpoint, or None if none was written yet.
        self.best_checkpoint_path = None

    def _is_improvement(self, monitor_value):
        """Return True if ``monitor_value`` beats the current optimum by more than ``min_delta``."""
        if self.mode == "min":
            return (monitor_value + self.min_delta) < self.optimum_value
        if self.mode == "max":
            return monitor_value > (self.optimum_value + self.min_delta)
        # Only reachable if ``mode`` was reassigned to an invalid value after construction.
        raise ValueError(f"mode must be 'min' or 'max', got {self.mode!r}")

    def on_epoch_end(self, trainer):
        """
        Save checkpoints for the epoch that just finished.

        Parameters:
            trainer: Object providing ``model``, ``optimizer``, ``history`` and ``num_epochs``.
                ``history[self.monitor]`` is indexed with the zero-based number of the finished epoch.
        """
        self.epoch_count += 1

        checkpoint_dir = os.path.join(self.checkpoint_dir, "ckpts")
        # exist_ok avoids the check-then-create race of a separate os.path.exists test.
        os.makedirs(checkpoint_dir, exist_ok=True)

        if self.save_state_dict_only:
            checkpoint = {
                "epoch": self.epoch_count,
                "model_state_dict": trainer.model.state_dict(),
                "optimizer_state_dict": trainer.optimizer.state_dict(),
            }
        else:
            checkpoint = trainer.model

        monitor_value = trainer.history[self.monitor][self.epoch_count - 1]

        # Evaluate against the previous optimum before it is updated.
        improved = self._is_improvement(monitor_value)
        if improved:
            self.optimum_value = monitor_value

        if self.save_best_only:
            if improved:
                # The best checkpoint always lives at the same path, so saving overwrites the
                # previous one; no explicit delete is needed (and a delete could fail if the
                # file was moved or removed between epochs).
                checkpoint_path = os.path.join(checkpoint_dir, "model_best.pt")
                torch.save(checkpoint, checkpoint_path)
                self.best_checkpoint_path = checkpoint_path
        elif self.epoch_count % self.save_frequency == 0:
            checkpoint_name = f"model_epoch_{self.epoch_count}_{self.monitor}_{monitor_value:.5f}.pt"
            torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_name))

        if self.save_last and self.epoch_count == trainer.num_epochs:
            torch.save(checkpoint, os.path.join(checkpoint_dir, "model_last.pt"))