Source code for torchmate.callbacks.wandb_model_checkpoint

  1import os
  2
  3import numpy as np
  4import torch
  5import wandb
  6
  7from torchmate.callbacks import Callback
  8
  9
class WandbModelCheckpoint(Callback):
    """
    Callback to save model checkpoints using Weights and Biases (wandb).

    This callback saves model checkpoints in wandb cloud at specified intervals and under certain conditions such as best
    validation loss or metrics. Checkpoint files are written locally to ``<checkpoint_dir>/ckpts`` and registered with
    ``wandb.save`` (``policy="live"``) relative to ``checkpoint_dir``.

    Parameters:
        checkpoint_dir (str): Directory to save the checkpoints.
        monitor (str, optional): Metric to monitor for saving the best model. Must be a key of ``trainer.history``.
            Defaults to ``"val_loss"``.
        mode (str, optional): One of ``"min"`` or ``"max"``. The direction to monitor the metric.
            For example, ``"min"`` for loss, ``"max"`` for accuracy. Defaults to ``"min"``.
        min_delta (float, optional): Minimum change in the monitored metric to qualify as an improvement. Defaults to ``0``.
        save_frequency (int, optional): Frequency of saving checkpoints (epochs). Only used when ``save_best_only`` is
            ``False``. Defaults to ``1``.
        save_best_only (bool, optional): Whether to save only the best model based on the monitored metric. The best
            model is always written to ``model_best.pt`` (overwritten on each improvement). Defaults to ``True``.
        save_last (bool, optional): Whether to save the model checkpoint for the last epoch (``model_last.pt``, written
            when the epoch count reaches ``trainer.num_epochs``). Defaults to ``True``.
        save_state_dict_only (bool, optional): Whether to save only the model's state dictionary instead of the whole model.
            Defaults to ``True``.

            - If ``True``:
                * Saves a dictionary containing `epoch`, `model_state_dict`, and `optimizer_state_dict`.
                * Requires manually creating model and optimizer instances before loading the dictionary.
                * Only the model's weights and biases are loaded from `model_state_dict`.
            - If ``False``:
                * Saves the entire model object.

    Raises:
        ValueError: If ``mode`` is not ``"min"`` or ``"max"``.

    **Example Usage:**

    .. code-block:: python

        wandb_checkpoint_callback = WandbModelCheckpoint(
            checkpoint_dir="checkpoints",
            monitor="val_loss",
            mode="min",
            save_best_only = False,
            save_frequency=2,
        )
        callbacks = [wandb_checkpoint_callback]

    """

    def __init__(
        self,
        checkpoint_dir,
        monitor="val_loss",
        mode="min",
        min_delta=0,
        save_frequency=1,
        save_best_only=True,
        save_state_dict_only=True,
        save_last=True,
    ):

        super().__init__()
        # Fail fast: any other value would leave the improvement test undefined at epoch end.
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.checkpoint_dir = checkpoint_dir
        self.monitor = monitor
        self.mode = mode
        self.min_delta = min_delta
        self.save_frequency = save_frequency
        self.save_best_only = save_best_only
        self.save_state_dict_only = save_state_dict_only
        self.save_last = save_last
        # Number of epochs seen so far; also used to index trainer.history.
        self.epoch_count = 0
        # Best monitored value so far; starts at the worst possible value for the chosen direction.
        self.optimum_value = np.inf if mode == "min" else -np.inf
        # Path of the most recently saved best checkpoint (None until the first improvement).
        self.best_checkpoint_path = None

    def on_experiment_begin(self, trainer):
        """Check that a wandb run is active before training starts.

        Raises:
            RuntimeError: If ``wandb.init()`` has not been called.
        """
        # An explicit raise (not ``assert``) so the check survives ``python -O``.
        if wandb.run is None:
            raise RuntimeError("wandb is not initialized")

    def on_epoch_end(self, trainer):
        """Save and sync checkpoints to wandb according to the configured policy.

        Reads ``trainer.model``, ``trainer.optimizer``, ``trainer.history`` and ``trainer.num_epochs``.

        Raises:
            KeyError: If ``self.monitor`` is not a key of ``trainer.history``.
        """
        self.epoch_count += 1

        checkpoint_dir = os.path.join(self.checkpoint_dir, "ckpts")
        os.makedirs(checkpoint_dir, exist_ok=True)

        if self.save_state_dict_only:
            model = {
                "epoch": self.epoch_count,
                "model_state_dict": trainer.model.state_dict(),
                "optimizer_state_dict": trainer.optimizer.state_dict(),
            }
        else:
            model = trainer.model

        try:
            monitor_history = trainer.history[self.monitor]
        except KeyError:
            raise KeyError(
                f"Monitored metric {self.monitor!r} not found in trainer.history"
            ) from None
        # Assumes one history entry is appended per epoch, so entry (epoch_count - 1) is the current epoch.
        monitor_value = monitor_history[self.epoch_count - 1]

        improved = self._is_improvement(monitor_value)
        if improved:
            self.optimum_value = monitor_value

        if self.save_best_only:
            if improved:
                # The filename is constant, so torch.save simply overwrites the previous best in place.
                checkpoint_path = os.path.join(checkpoint_dir, "model_best.pt")
                self._save_and_sync(model, checkpoint_path)
                self.best_checkpoint_path = checkpoint_path
        elif self.epoch_count % self.save_frequency == 0:
            checkpoint_name = f"model_epoch_{self.epoch_count}_{self.monitor}_{monitor_value:.5f}.pt"
            self._save_and_sync(model, os.path.join(checkpoint_dir, checkpoint_name))

        if self.save_last and self.epoch_count == trainer.num_epochs:
            self._save_and_sync(model, os.path.join(checkpoint_dir, "model_last.pt"))

    def _is_improvement(self, value):
        """Return True if ``value`` beats ``self.optimum_value`` by more than ``min_delta`` in ``self.mode``."""
        if self.mode == "min":
            return (value + self.min_delta) < self.optimum_value
        return value > (self.optimum_value + self.min_delta)

    def _save_and_sync(self, obj, path):
        """Write ``obj`` to ``path`` with torch.save and register the file with wandb for live syncing."""
        torch.save(obj, path)
        wandb.save(path, base_path=self.checkpoint_dir, policy="live")