1import wandb
2
3from torchmate.callbacks import Callback
4
5
[docs]
class WandbLogger(Callback):
    """Callback to log training metrics to Weights & Biases (Wandb).

    At the end of every epoch, this callback takes the entry for the current
    epoch from each series in ``trainer.history`` and sends all of them to
    Wandb in a single ``wandb.log`` call.

    A Wandb run must already be active (``wandb.init()`` must have been
    called) before the experiment begins; otherwise ``on_experiment_begin``
    raises ``RuntimeError``.

    **Example Usage:**

    .. code-block:: python

        wandb.init(project="my-project")
        wandb_logger = WandbLogger()
        callbacks = [wandb_logger]

    """

    def __init__(self):
        super().__init__()
        # Position in each ``trainer.history`` series of the epoch that the
        # next ``on_epoch_end`` call will log; advanced by one per call.
        self.current_epoch = 0

    def on_experiment_begin(self, trainer):
        """Verify that a Wandb run is active before training starts.

        Args:
            trainer: The trainer running the experiment (unused here).

        Raises:
            RuntimeError: If no Wandb run is active (``wandb.run`` is None).
        """
        # Explicit raise instead of ``assert``: asserts are stripped when
        # Python runs with ``-O``, which would silently drop this check.
        if wandb.run is None:
            raise RuntimeError("wandb is not initialized")

    def on_epoch_end(self, trainer):
        """Log the current epoch's value of every series in ``trainer.history``.

        Args:
            trainer: Object exposing a ``history`` mapping of metric name to
                an indexable series of values.

        Returns:
            None
        """
        # NOTE(review): assumes every history series holds one entry per
        # completed epoch, indexed from 0 — confirm against the Trainer.
        epoch_logs = {
            key: values[self.current_epoch]
            for key, values in trainer.history.items()
        }
        wandb.log(epoch_logs)
        self.current_epoch += 1
        return None