1import math
2import re
3from typing import Dict, Optional, Type
4
5import matplotlib.pyplot as plt
6
7
class HistoryPlotter:
    """
    A class for plotting training and validation metrics, including loss, learning rate, and any
    additional metrics provided by the training history.

    Args:
        trainer (Trainer, optional): An instance of ``torchmate.trainer.Trainer`` whose ``history``
            attribute holds the data to be plotted.
        history (Dict, optional): A dictionary of per-epoch values with keys such as ``loss``,
            ``lr``, and other metric names. The validation counterpart of a series is looked up
            under the same key prefixed with ``val_`` (e.g. ``val_loss``).

    Raises:
        ValueError: If both ``trainer`` and ``history`` are None.

    Notes:
        If both ``trainer`` and ``history`` are provided, the training history from ``trainer``
        takes precedence.

        Validation curves are optional: when ``val_<name>`` is missing from the history, only the
        training curve for ``<name>`` is drawn.

    **Example usage:**

    .. code-block:: python

        # Using a Trainer object:
        trainer = Trainer(...)
        history = trainer.fit()
        plotter = HistoryPlotter(trainer=trainer)
        plotter.plot_all()

        # Using a history dictionary returned from trainer.fit()
        trainer = Trainer(...)
        history = trainer.fit()
        plotter = HistoryPlotter(history=history)
        plotter.plot_all()

        # Using a history dictionary:
        history = {"loss": [...], "lr": [...], "some_metric": [...]}
        plotter = HistoryPlotter(history=history)
        plotter.plot_all()
    """

    # History keys that are never treated as user metrics when metrics are inferred from a dict.
    _RESERVED_KEYS = ("Epoch", "loss", "lr")
    # Prefix under which the validation counterpart of a series is stored.
    _VAL_PREFIX = "val_"

    def __init__(self, trainer=None, history: Optional[Dict] = None) -> None:
        if trainer is None and history is None:
            raise ValueError("Either 'trainer' or 'history' attribute must be provided.")

        if trainer is not None:
            self.history = trainer.history
            # Each trainer metric is identified in the history by its ``__name__``.
            # ``or ()`` tolerates a trainer whose ``metrics`` is None.
            metrics = [metric.__name__ for metric in (trainer.metrics or ())]
        else:
            self.history = history
            metrics = [
                key
                for key in history
                if key not in self._RESERVED_KEYS and not key.startswith(self._VAL_PREFIX)
            ]
        # Both sources report "no metrics" as None, so the plotting methods never try to
        # build an empty subplot grid.
        self.metrics = metrics or None

    @staticmethod
    def format_to_space_capitalized(text: str) -> str:
        """
        Convert an identifier-like string to space-separated, capitalized words.

        Underscores become spaces, camelCase / PascalCase boundaries are split, runs of
        whitespace are collapsed, and the first character of every word is upper-cased.
        Upper-case acronyms stay together instead of being split letter by letter
        (``"MSELoss"`` -> ``"MSE Loss"``, ``"AUROC"`` -> ``"AUROC"``).

        Args:
            text: A string in any format (snake_case, camelCase, space-separated, etc.).

        Returns:
            A string with words separated by single spaces and capitalized.
        """
        text = text.replace("_", " ")
        # Split at lower/digit -> Upper ("camelCase", "F1Score") and at the end of an acronym
        # that is followed by a capitalized word ("MSELoss" -> "MSE Loss").
        text = re.sub(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])", " ", text)
        # Upper-case only the first character so acronyms keep their casing.
        return " ".join(word[:1].upper() + word[1:] for word in text.split())

    @staticmethod
    def _make_grid(num_panels, num_col=2):
        """Create a ``num_col``-wide subplot grid with room for ``num_panels`` axes."""
        num_row = max(1, math.ceil(num_panels / num_col))
        fig, ax = plt.subplots(num_row, num_col, figsize=(16, num_row * 5))
        axes = ax.ravel()
        # An odd panel count leaves empty cells in the last row; hide them instead of
        # showing blank framed axes.
        for unused in axes[num_panels:]:
            unused.set_axis_off()
        return fig, axes

    def _plot_series(self, ax, key, name, *, title_size, label_size, with_validation=True):
        """
        Draw one history series (and optionally its ``val_`` counterpart) on ``ax``.

        Args:
            ax: Matplotlib axes to draw on.
            key: History key of the training series, e.g. ``"loss"``.
            name: Human-readable name used for the title and the y-axis label.
            title_size: Font size of the axes title.
            label_size: Font size of the axis labels.
            with_validation: If True, also draw ``history["val_" + key]`` when that key exists,
                add a legend, and title the panel "Training & Validation <name>" (or
                "Training <name>" when no validation series exists). If False, the panel is
                titled ``name`` and has no legend.
        """
        train = self.history[key]
        # Epochs are numbered from 1 on the x axis.
        ax.plot(range(1, len(train) + 1), train, "o-")

        if with_validation:
            val_key = f"{self._VAL_PREFIX}{key}"
            if val_key in self.history:
                val = self.history[val_key]
                ax.plot(range(1, len(val) + 1), val, "o-")
                ax.legend(["Training", "Validation"], loc=0)
                title = f"Training & Validation {name}"
            else:
                ax.legend(["Training"], loc=0)
                title = f"Training {name}"
        else:
            title = name

        ax.set_title(title, fontsize=title_size)
        ax.set_xlabel("Epoch", fontsize=label_size)
        ax.set_ylabel(name, fontsize=label_size)
        # Epochs are whole numbers; let matplotlib choose how many integer ticks fit rather
        # than forcing one label per epoch, which overlaps on long runs.
        ax.locator_params(axis="x", integer=True)
        ax.grid(True)

    def plot_all(self):
        """
        Plot loss, learning rate and all available metrics in a two-column grid.
        """
        metrics = self.metrics or []
        fig, axes = self._make_grid(len(metrics) + 2)

        self._plot_series(axes[0], "loss", "Loss", title_size=20, label_size=16)
        self._plot_series(
            axes[1], "lr", "Learning Rate", title_size=20, label_size=16, with_validation=False
        )
        for panel, metric in zip(axes[2:], metrics):
            self._plot_series(
                panel,
                metric,
                self.format_to_space_capitalized(metric),
                title_size=20,
                label_size=16,
            )

        fig.tight_layout()
        plt.show()
        return None

    def plot_metrics(self):
        """
        Plot the training and validation metrics in a two-column grid.

        Prints a message and returns without plotting when no metrics are available.
        """
        if not self.metrics:
            print("There are no metrics to plot!")
            return None

        fig, axes = self._make_grid(len(self.metrics))
        for panel, metric in zip(axes, self.metrics):
            self._plot_series(
                panel,
                metric,
                self.format_to_space_capitalized(metric),
                title_size=20,
                label_size=16,
            )

        fig.tight_layout()
        plt.show()
        return None

    def plot_loss(self):
        """
        Plot the training (and, if recorded, validation) loss.
        """
        fig, ax = plt.subplots(1, 1, figsize=(6, 4))
        self._plot_series(ax, "loss", "Loss", title_size=16, label_size=12)
        fig.tight_layout()
        plt.show()
        return None

    def plot_lr(self):
        """
        Plot the learning rate over epochs.
        """
        fig, ax = plt.subplots(1, 1, figsize=(6, 4))
        self._plot_series(
            ax, "lr", "Learning Rate", title_size=16, label_size=12, with_validation=False
        )
        fig.tight_layout()
        plt.show()
        return None