1import os
2
3import numpy as np
4import torch
5import wandb
6
7from torchmate.callbacks import Callback
8
9
[docs]
class WandbModelCheckpoint(Callback):
    """
    Callback to save model checkpoints using Weights and Biases (wandb).

    At the end of every epoch this callback writes checkpoints to
    ``<checkpoint_dir>/ckpts`` with ``torch.save`` and registers them with
    ``wandb.save`` so that they are uploaded to the active wandb run. Depending on
    the configuration it keeps only the best checkpoint (according to a monitored
    metric), a checkpoint every ``save_frequency`` epochs, and/or the checkpoint of
    the final epoch.

    Parameters:
        checkpoint_dir (str): Base directory for the checkpoints. Files are written
            to the ``ckpts`` sub-directory, which is created if it does not exist.
        monitor (str, optional): Key of ``trainer.history`` to monitor for saving the
            best model. Defaults to ``"val_loss"``.
        mode (str, optional): One of ``"min"`` or ``"max"``. The direction to monitor
            the metric. For example, ``"min"`` for loss, ``"max"`` for accuracy.
            Defaults to ``"min"``.
        min_delta (float, optional): Minimum change in the monitored metric to qualify
            as an improvement. Defaults to ``0``.
        save_frequency (int, optional): Frequency of saving checkpoints (epochs).
            Only used when ``save_best_only`` is ``False``. Defaults to ``1``.
        save_best_only (bool, optional): Whether to save only the best model based on
            the monitored metric (as ``model_best.pt``). Defaults to ``True``.
        save_state_dict_only (bool, optional): Whether to save only the model's state
            dictionary instead of the whole model. Defaults to ``True``.

            - If ``True``:
                * Saves a dictionary containing `epoch`, `model_state_dict`, and
                  `optimizer_state_dict`.
                * Requires manually creating model and optimizer instances before
                  loading the dictionary.
                * Only the model's weights and biases are loaded from
                  `model_state_dict`.
            - If ``False``:
                * Saves the entire model object.
        save_last (bool, optional): Whether to save the model checkpoint for the last
            epoch (as ``model_last.pt``). Defaults to ``True``.

    Raises:
        ValueError: If ``mode`` is neither ``"min"`` nor ``"max"``.

    **Example Usage:**

    .. code-block:: python

        wandb_checkpoint_callback = WandbModelCheckpoint(
            checkpoint_dir="checkpoints",
            monitor="val_loss",
            mode="min",
            save_best_only=False,
            save_frequency=2,
        )
        callbacks = [wandb_checkpoint_callback]

    """

    def __init__(
        self,
        checkpoint_dir,
        monitor="val_loss",
        mode="min",
        min_delta=0,
        save_frequency=1,
        save_best_only=True,
        save_state_dict_only=True,
        save_last=True,
    ):
        super().__init__()

        # Fail fast: an unknown mode would otherwise only surface as an
        # unbound-variable error at the end of the first epoch.
        if mode not in ("min", "max"):
            raise ValueError(f'mode must be "min" or "max", got {mode!r}')

        self.checkpoint_dir = checkpoint_dir
        self.monitor = monitor
        self.mode = mode
        self.min_delta = min_delta
        self.save_frequency = save_frequency
        self.save_best_only = save_best_only
        self.save_state_dict_only = save_state_dict_only
        self.save_last = save_last
        # Number of epochs seen so far (1-based after the first ``on_epoch_end``).
        self.epoch_count = 0
        # Best monitored value seen so far; starts at the worst possible value
        # so that the first epoch always counts as an improvement.
        self.optimum_value = np.inf if mode == "min" else -np.inf
        # Path of the most recently written ``model_best.pt`` (None until one exists).
        self.best_checkpoint_path = None

    def on_experiment_begin(self, trainer):
        """Check that a wandb run is active before training starts.

        Raises:
            AssertionError: If ``wandb.run`` is ``None``.
        """
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with ``-O``; the exception type is unchanged.
        if wandb.run is None:
            raise AssertionError("wandb is not initialized")

    def on_epoch_end(self, trainer):
        """Save and upload the checkpoints that are due for the epoch just finished.

        Reads the monitored value for this epoch from
        ``trainer.history[self.monitor]``, updates the best value seen so far and then

        - writes ``model_best.pt`` if ``save_best_only`` is set and the metric improved;
        - otherwise, if ``save_best_only`` is not set, writes a per-epoch checkpoint
          every ``save_frequency`` epochs;
        - writes ``model_last.pt`` if ``save_last`` is set and this is epoch
          ``trainer.num_epochs``.
        """
        self.epoch_count += 1

        checkpoint_dir = os.path.join(self.checkpoint_dir, "ckpts")
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(checkpoint_dir, exist_ok=True)

        if self.save_state_dict_only:
            payload = {
                "epoch": self.epoch_count,
                "model_state_dict": trainer.model.state_dict(),
                "optimizer_state_dict": trainer.optimizer.state_dict(),
            }
        else:
            payload = trainer.model

        # History is indexed by 0-based epoch, epoch_count is 1-based.
        monitor_value = trainer.history[self.monitor][self.epoch_count - 1]

        improved = self._is_improvement(monitor_value)
        if improved:
            self.optimum_value = monitor_value

        if self.save_best_only:
            if improved:
                if self.best_checkpoint_path:
                    # Drop the previous best checkpoint. A file that is already
                    # gone is harmless because the same path is rewritten below.
                    try:
                        os.remove(self.best_checkpoint_path)
                    except FileNotFoundError:
                        pass
                checkpoint_path = os.path.join(checkpoint_dir, "model_best.pt")
                self._save_and_upload(payload, checkpoint_path)
                self.best_checkpoint_path = checkpoint_path
        elif self.epoch_count % self.save_frequency == 0:
            checkpoint_name = f"model_epoch_{self.epoch_count}_{self.monitor}_{monitor_value:.5f}.pt"
            self._save_and_upload(payload, os.path.join(checkpoint_dir, checkpoint_name))

        if self.save_last and self.epoch_count == trainer.num_epochs:
            self._save_and_upload(payload, os.path.join(checkpoint_dir, "model_last.pt"))

    def _is_improvement(self, monitor_value):
        """Return True if ``monitor_value`` beats the best value so far by more than ``min_delta``."""
        if self.mode == "min":
            return (monitor_value + self.min_delta) < self.optimum_value
        if self.mode == "max":
            return monitor_value > (self.optimum_value + self.min_delta)
        # Only reachable if ``mode`` was reassigned after construction.
        raise ValueError(f'mode must be "min" or "max", got {self.mode!r}')

    def _save_and_upload(self, obj, path):
        """Serialize ``obj`` to ``path`` and register the file with the active wandb run."""
        torch.save(obj, path)
        # base_path is passed so the upload keeps its path relative to
        # checkpoint_dir; see the wandb.save documentation for the "live" policy.
        wandb.save(path, base_path=self.checkpoint_dir, policy="live")
123 torch.save(model, filepath)
124 wandb.save(filepath, base_path=self.checkpoint_dir, policy="live")