Source code for torchmate.utils.history_plotter

  1import math
  2import re
  3from typing import Dict, Optional, Type
  4
  5import matplotlib.pyplot as plt
  6
  7
class HistoryPlotter:
    """
    A class for plotting training and validation metrics, including loss, learning rate, and any additional \
    metrics provided by the training history.

    Args:
        trainer (Trainer, optional): An instance of ``torchmate.trainer.Trainer`` containing the training history to be plotted.
        history (Dict, optional): A dictionary containing the training history data, with keys such as \
            ``loss``, ``lr``, and other metric names. Validation series, when present, are read from \
            ``val_<name>`` keys (e.g. ``val_loss``).

    Raises:
        ValueError: If both ``trainer`` and ``history`` are None.

    Notes:
        If both ``trainer`` and ``history`` are provided, the training history from ``trainer`` will take precedence.
        A series without a matching ``val_<name>`` entry is plotted with its training curve only.


    **Example usage:**

    .. code-block:: python

        # Using a Trainer object:
        trainer = Trainer(...)
        history = trainer.fit()
        plotter = HistoryPlotter(trainer=trainer)
        plotter.plot_all()


        # Using a history dictionary returned from trainer.fit()
        trainer = Trainer(...)
        history = trainer.fit()
        plotter = HistoryPlotter(history=history)
        plotter.plot_all()

        # Using a history dictionary:
        history = {"loss": [...], "lr": [...], "some_metric": [...]}
        plotter = HistoryPlotter(history=history)
        plotter.plot_all()

    """

    # History keys that are bookkeeping/fixed panels, never treated as user metrics.
    _RESERVED_KEYS = ("Epoch", "loss", "lr")

    def __init__(self, trainer=None, history: Optional[Dict] = None):
        if trainer is None and history is None:
            raise ValueError("Either 'trainer' or 'history' attribute must be provided.")

        if trainer is not None:
            # The trainer's own history takes precedence over an explicit ``history``.
            self.history = trainer.history
            # NOTE(review): assumes each entry of ``trainer.metrics`` exposes ``__name__``
            # matching its key in ``trainer.history`` — confirm against Trainer.
            metrics = [metric.__name__ for metric in (trainer.metrics or [])]
        else:
            self.history = history
            # Every non-reserved, non-validation key is a metric to plot.
            metrics = [
                key
                for key in history
                if key not in self._RESERVED_KEYS and not key.startswith("val_")
            ]
        # Normalise "no metrics" to None on both paths so the plotting methods
        # share a single sentinel (an empty list used to crash plot_metrics).
        self.metrics: Optional[list] = metrics or None

    @staticmethod
    def format_to_space_capitalized(text):
        """
        Convert a string to space-separated capitalized words, handling various formats.

        Underscores become spaces, and a space is inserted before every uppercase letter
        that does not start the string (so ``"f1Score"`` -> ``"F1 Score"``; note that
        each letter of an all-caps run is split, e.g. ``"AUC"`` -> ``"A U C"``).

        Args:
            text: A string in any format (snake_case, camelCase, space-separated, etc.).

        Returns:
            A string with words separated by spaces and capitalized.
        """
        text = text.replace("_", " ")
        # Split camelCase: prefix each non-leading capitalised word with a space.
        text = re.sub(r"(?<!^)([A-Z][a-z]*)", r" \1", text)
        # Collapse the duplicate/leading/trailing whitespace introduced above.
        return " ".join(text.split()).title()

    def _plot_curve(self, axis, key, title, ylabel, *, validation=True, title_size=20, label_size=16):
        """
        Draw the per-epoch curve of ``history[key]`` on ``axis``.

        Args:
            axis: Matplotlib axes to draw on.
            key: History key of the training series.
            title: Axes title.
            ylabel: Y-axis label.
            validation: When True, also draw ``history["val_<key>"]`` if it exists and add a legend.
            title_size: Font size of the title.
            label_size: Font size of the axis labels.
        """
        train_values = self.history[key]
        # A leading None occupies x=0, so epoch 1 is drawn at x=1 (matching the ticks below).
        axis.plot([None, *train_values], "o-")
        if validation:
            legend = ["Training"]
            val_values = self.history.get(f"val_{key}")
            if val_values is not None:
                axis.plot([None, *val_values], "o-")
                legend.append("Validation")
            axis.legend(legend, loc=0)
        axis.set_title(title, fontsize=title_size)
        axis.set_xlabel("Epoch", fontsize=label_size)
        axis.set_ylabel(ylabel, fontsize=label_size)
        axis.set_xticks(list(range(1, len(train_values) + 1)))
        axis.grid(True)

    @staticmethod
    def _hide_unused(axes, used):
        """Hide every axes after the first ``used`` so empty grid cells are not drawn."""
        for axis in axes[used:]:
            axis.set_visible(False)

    def plot_all(self) -> None:
        """
        Plot loss, learning rate and all available metrics.

        Panels are laid out two per row: loss, learning rate, then one panel per metric.
        """
        metrics = self.metrics or []
        num_col = 2
        num_row = math.ceil((len(metrics) + 2) / num_col)
        fig, axes = plt.subplots(num_row, num_col, figsize=(16, num_row * 5))
        axes = axes.ravel()

        self._plot_curve(axes[0], "loss", "Training & Validation Loss", "Loss")
        self._plot_curve(axes[1], "lr", "Learning Rate", "Learning Rate", validation=False)
        for axis, metric in zip(axes[2:], metrics):
            metric_name = self.format_to_space_capitalized(metric)
            self._plot_curve(axis, metric, f"Training & Validation {metric_name}", metric_name)
        # With an odd panel count the last grid cell would otherwise show empty axes.
        self._hide_unused(axes, len(metrics) + 2)

        fig.tight_layout()
        plt.show()
        return None

    def plot_metrics(self) -> None:
        """
        Plot the training and validation metrics.

        Prints a message and returns without plotting when no metrics are available.
        """
        if not self.metrics:
            print("There are no metrics to plot!")
            return None

        num_col = 2
        num_row = math.ceil(len(self.metrics) / num_col)
        fig, axes = plt.subplots(num_row, num_col, figsize=(16, num_row * 5))
        axes = axes.ravel()

        for axis, metric in zip(axes, self.metrics):
            metric_name = self.format_to_space_capitalized(metric)
            self._plot_curve(axis, metric, f"Training & Validation {metric_name}", metric_name)
        self._hide_unused(axes, len(self.metrics))

        fig.tight_layout()
        plt.show()
        return None

    def plot_loss(self) -> None:
        """
        Plot the training and validation loss.
        """
        fig, ax = plt.subplots(1, 1, figsize=(6, 4))
        self._plot_curve(ax, "loss", "Training & Validation Loss", "Loss", title_size=16, label_size=12)
        fig.tight_layout()
        plt.show()
        return None

    def plot_lr(self) -> None:
        """
        Plot the learning rate over epochs.
        """
        fig, ax = plt.subplots(1, 1, figsize=(6, 4))
        self._plot_curve(
            ax, "lr", "Learning Rate", "Learning Rate", validation=False, title_size=16, label_size=12
        )
        fig.tight_layout()
        plt.show()
        return None