import re
from typing import Callable, Dict, List, Optional, Type, Union

import matplotlib.pyplot as plt
import torch

import torchmate
from torchmate.callbacks import Callback
from torchmate.utils import ProgressBar, RunningAverage, colorize_text
10
11
class Trainer(torch.nn.Module):
    """Encapsulate training essentials.

    Args:
        model (torch.nn.Module, required): The PyTorch model to be trained.
        train_dataloader (torch.utils.data.DataLoader, required): DataLoader for the training dataset.
        val_dataloader (torch.utils.data.DataLoader, required): DataLoader for the validation dataset.
        loss_fn (torch.nn.Module, required): Loss function for training.
        optimizer (torch.optim.Optimizer, required): Optimizer for updating model parameters.
        num_epochs (int, optional): Number of training epochs (default is 1).
        test_dataloader (torch.utils.data.DataLoader, optional): DataLoader for the test dataset.
        metrics (List[callable], optional): List of metrics functions for evaluation. A metric without a
            ``__name__`` attribute is given the snake_case form of its class name.
        callbacks (List[Callback], optional): List of callback objects for various stages.
        scheduler (torch.optim.lr_scheduler._LRScheduler, optional): Learning rate scheduler.
        schedule_monitor (str, optional): History key passed to ``ReduceLROnPlateau.step``
            (default is "val_loss").
        mixed_precision (bool, optional): Whether to use mixed precision (fp16) training (default is False).
            It is only enabled when ``device`` is not a CPU device.
        use_grad_penalty (bool, optional): Whether to add the gradient norm to the loss (default is False).
        device (str or torch.device, optional): Device to use for training (default is "cpu").

    Other Attributes:
        - **history (dict):** Training history containing loss, metrics, and learning rates.
        - **early_stop (bool):** Flag for early stopping; checked at the end of every epoch.
        - **update_params (bool):** Flag for updating model parameters after the backward pass.
        - **accumulation_steps (int):** Number of steps for gradient accumulation during training.
        - **scaler (torch.cuda.amp.GradScaler):** Gradient scaler used for the backward pass.

    Important Methods:
        - **fit():** Train and validate the model for the specified number of epochs and return history.
        - **evaluate():** Evaluate the model on the validation dataset and return evaluation history.
        - **predict():** Make predictions using the model on the test dataset.

    **Example usage:**

    .. code-block:: python

        import os

        import numpy as np
        import pandas as pd
        import torch
        from sklearn.model_selection import train_test_split

        from torchmate.callbacks import CSVLogger, ModelCheckpoint
        from torchmate.trainer import Trainer

        # Create a simple neural network model
        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super(SimpleModel, self).__init__()
                self.fc1 = torch.nn.Linear(1, 1)

            def forward(self, x):
                return self.fc1(x)

        # Create synthetic data
        X = torch.tensor(np.random.rand(1000, 1), dtype=torch.float32)
        y = 2 * X + 1 + torch.randn(1000, 1) * 0.1  # Adding some noise

        # Split the data into training and validation sets
        X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

        # Create DataLoader objects for training and validation
        train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
        val_dataset = torch.utils.data.TensorDataset(X_val, y_val)
        val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False)

        # Create Metrics
        class MSE(torch.nn.Module):
            __name__ = 'mse'

            def forward(self, inputs, targets):
                inputs = inputs.view(-1)
                targets = targets.view(-1)
                return torch.mean((inputs - targets) ** 2)

        def mae(inputs, targets):
            inputs = inputs.view(-1)
            targets = targets.view(-1)
            return torch.mean(torch.abs(inputs - targets))

        model = SimpleModel()
        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

        metrics = [MSE(), mae]

        logdir = "logs"
        csv_file = os.path.join(logdir, "logs.csv")
        ckpt_dir = os.path.join(logdir, "model")

        callbacks = [CSVLogger(filename=csv_file),
                     ModelCheckpoint(checkpoint_dir=ckpt_dir)
                     ]

        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        print(device)

        # Create a Trainer instance with the callbacks
        trainer = Trainer(model,
                          train_dataloader,
                          val_dataloader,
                          loss_fn,
                          optimizer,
                          num_epochs=3,
                          scheduler=scheduler,
                          metrics=metrics,
                          callbacks=callbacks,
                          device=device,
                          mixed_precision=True,
                          use_grad_penalty=True
                          )

        # Train the model
        history = trainer.fit()

        print("_" * 150)

        print(pd.read_csv(csv_file))

    """

    # Stage names accepted by ``execute_callbacks``. A callback takes part in a stage by
    # defining a method called ``on_<stage>``.
    _VALID_STAGES = (
        "experiment_begin",
        "experiment_end",
        "epoch_begin",
        "epoch_end",
        "train_begin",
        "train_end",
        "train_batch_begin",
        "train_batch_end",
        "val_begin",
        "val_end",
        "val_batch_begin",
        "val_batch_end",
        "predict_begin",
        "predict_end",
        "predict_batch_begin",
        "predict_batch_end",
        "backward_end",
    )

    def __init__(
        self,
        model: torch.nn.Module,
        train_dataloader: torch.utils.data.DataLoader,
        val_dataloader: torch.utils.data.DataLoader,
        loss_fn: Union[Callable, torch.nn.Module],
        optimizer: torch.optim.Optimizer,
        num_epochs: int = 1,
        test_dataloader: Optional[torch.utils.data.DataLoader] = None,
        metrics: Optional[List[Callable]] = None,
        callbacks: Optional[List["Callback"]] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        schedule_monitor: str = "val_loss",
        mixed_precision: bool = False,
        use_grad_penalty: bool = False,
        device: Union[str, torch.device] = "cpu",
    ):
        super().__init__()
        self.model = model
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.schedule_monitor = schedule_monitor
        self.metrics = metrics
        self.num_epochs = num_epochs
        self.callbacks = callbacks
        self.use_amp = mixed_precision
        self.use_grad_penalty = use_grad_penalty
        self.device = device
        self.early_stop = False
        self.update_params = True
        self.accumulation_steps = 4
        # Compare the device *type* so that "cuda:0" and torch.device objects are handled too.
        amp_enabled = self.use_amp and self._device_type(device) != "cpu"
        self.scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)

        # Give every metric and the loss a ``__name__`` BEFORE the history keys are derived
        # from it; the name is used for logging and printing.
        for fn in [*(self.metrics or []), self.loss_fn]:
            if not hasattr(fn, "__name__"):
                fn.__name__ = self.camel_to_snake_case(type(fn).__name__)

        self.history = self._empty_history()

    @staticmethod
    def _device_type(device: Union[str, torch.device]) -> str:
        """Return the bare device type (e.g. ``"cuda"`` for ``"cuda:0"``)."""
        return torch.device(device).type

    def _empty_history(self) -> Dict[str, list]:
        """Build a history dict with one empty list per logged quantity."""
        history = {"loss": [], "val_loss": [], "lr": []}
        for metric in self.metrics or []:
            history[f"{metric.__name__}"] = []
            history[f"val_{metric.__name__}"] = []
        return history

    @staticmethod
    def _track_metrics(metrics, averages, y_pred, y_true, prefix: str = "") -> str:
        """Feed one batch to every metric's running average and return the progress-bar text for them."""
        text = ""
        for metric in metrics or []:
            name = f"{prefix}{metric.__name__}"
            # Metrics are only reported, never back-propagated, so no graph is needed.
            with torch.no_grad():
                averages[name].update(metric(y_pred, y_true).item())
            text += f" | {name}: {round(averages[name](), 5)}"
        return text

    def fit(self) -> Dict[str, list]:
        """Train the model and returns the training history.

        Returns:
            Dict : A dictionary object encapsulating the training history.
        """
        return self.train_and_evaluate()

    def evaluate(
        self,
        dataloader: Optional[torch.utils.data.DataLoader] = None,
        loss_fn: Union[Callable, torch.nn.Module] = None,
        metrics: Optional[List[Callable]] = None,
        callbacks: Optional[List["Callback"]] = None,
        device: Optional[Union[str, torch.device]] = None,
    ) -> Dict[str, float]:
        """Evaluate the model on a dataset and return the evaluation history.

        This method provides flexibility for customized evaluation.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): A PyTorch DataLoader containing the validation data.
                If not provided, the ``self.val_dataloader`` attribute will be used. Defaults to ``None``.
            loss_fn (Callable or torch.nn.Module, optional): A custom loss function for evaluation.
                If not provided, the ``self.loss_fn`` attribute will be used. Defaults to ``None``.
            metrics (List[Callable], optional): A list of custom evaluation metrics.
                If not provided, the ``self.metrics`` attribute will be used. Defaults to ``None``.
            callbacks (List[Callback], optional): A list of callback objects for evaluation stages.
                If not provided, the ``self.callbacks`` attribute will be used. Defaults to ``None``.
            device (str or torch.device, optional): The device to use for evaluation (e.g., "cpu" or "cuda").
                If not provided, the ``self.device`` attribute will be used. Defaults to ``None``.

        Returns:
            Dict: A dictionary object containing the evaluation results.
        """
        # Explicit ``None`` checks: the truthiness of a DataLoader calls ``len()``, which is
        # not defined for every dataset type.
        dataloader = self.val_dataloader if dataloader is None else dataloader
        callbacks = self.callbacks if callbacks is None else callbacks
        loss_fn = self.loss_fn if loss_fn is None else loss_fn
        metrics = self.metrics if metrics is None else metrics
        device = self.device if device is None else device

        # The batches are moved to ``device``, so the model has to live there as well.
        self.model.to(device)
        return self.evaluate_single_epoch(self.model, dataloader, loss_fn, metrics, callbacks, device)

    def predict(
        self,
        test_dataloader: Optional[torch.utils.data.DataLoader] = None,
        callbacks: Optional[List["Callback"]] = None,
        device: Optional[Union[str, torch.device]] = None,
    ) -> list:
        """Perform predictions on the provided test data using the trained model.

        Args:
            test_dataloader (DataLoader, optional): A PyTorch DataLoader containing the test data.
                If not provided, the ``self.test_dataloader`` attribute will be used. When provided it
                also replaces ``self.test_dataloader``. Defaults to ``None``.
            callbacks (list[Callback], optional): A list of callback objects to be executed at various stages
                of the prediction process. If not provided, the ``self.callbacks`` attribute will be used.
                Defaults to ``None``.
            device (str or torch.device, optional): The device to run the prediction on (e.g., "cpu" or "cuda").
                If not provided, the ``self.device`` attribute will be used. Defaults to ``None``.

        Returns:
            list: The model outputs of all batches, flattened along the first dimension
            (for tensor outputs: one tensor per sample).

        Raises:
            ValueError: If both ``test_dataloader`` and ``self.test_dataloader`` are None.

        """
        if test_dataloader is None:
            test_dataloader = self.test_dataloader
        if test_dataloader is None:
            raise ValueError(
                "Missing test data: You must provide either a `test_dataloader` argument or set a "
                "`test_dataloader` attribute on the Trainer instance."
            )
        self.test_dataloader = test_dataloader

        device = self.device if device is None else device
        callbacks = self.callbacks if callbacks is None else callbacks

        model = self.model
        model.to(device)
        model.eval()
        progress_bar = ProgressBar(total=len(test_dataloader), prefix="prediction")
        self.execute_callbacks(self, callbacks, "predict_begin")
        predictions = []
        for batch_ix, batch in enumerate(test_dataloader):
            self.execute_callbacks(self, callbacks, "predict_batch_begin")
            # A loader may yield ``(inputs, ...)`` tuples/lists or the bare input tensor.
            inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
            inputs = inputs.to(device)
            with torch.inference_mode():
                outputs = model(inputs)
            predictions.extend(outputs)
            progress_bar.update(batch_ix + 1)
            self.execute_callbacks(self, callbacks, "predict_batch_end")
        self.execute_callbacks(self, callbacks, "predict_end")
        return predictions

    @staticmethod
    def camel_to_snake_case(text: str) -> str:
        """Convert CamelCase text to snake_case."""
        string = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", text)
        string = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", string)
        return string.lower()

    @staticmethod
    def gradient_norm(model, loss):
        """Return the L2 norm of the gradients of ``loss`` w.r.t. the trainable parameters of ``model``.

        The gradients are created with ``create_graph=True`` so the returned norm can itself be
        back-propagated (i.e. used as a penalty term). Returns a zero scalar when the model has no
        trainable parameters.
        """
        # Frozen parameters cannot be differentiated against; skip them.
        trainable = [param for param in model.parameters() if param.requires_grad]
        if not trainable:
            return loss.new_zeros(())
        grads = torch.autograd.grad(outputs=loss, inputs=trainable, create_graph=True)
        return torch.sqrt(sum(grad.pow(2).sum() for grad in grads))

    @staticmethod
    def execute_callbacks(trainer, callbacks: Optional[List["Callback"]] = None, stage: str = "") -> None:
        """Call ``on_<stage>(trainer)`` on every callback that defines it.

        Args:
            trainer: Object handed to each callback method.
            callbacks: Callbacks to run; nothing happens when ``None``.
            stage: One of the names in ``Trainer._VALID_STAGES``.

        Raises:
            ValueError: If ``callbacks`` is given and ``stage`` is not a valid stage name.
        """
        if callbacks is None:
            return None

        if stage not in Trainer._VALID_STAGES:
            raise ValueError(f"Invalid stage name. Must be one of {list(Trainer._VALID_STAGES)}")

        hook_name = f"on_{stage}"
        for callback in callbacks:
            hook = getattr(callback, hook_name, None)
            if hook is not None:
                hook(trainer)
        return None

    def train_and_evaluate(self) -> Dict[str, list]:
        """Run the full train/validate loop for ``num_epochs`` epochs.

        After every epoch the per-epoch results are appended to the history, the ``epoch_end``
        callbacks run, ``early_stop`` is checked, and the scheduler is stepped (only while
        ``update_params`` is set).

        Returns:
            Dict: The training history, also available as ``self.history``.

        Raises:
            KeyError: If a ``ReduceLROnPlateau`` scheduler is used and ``schedule_monitor`` is not a logged key.
        """
        self.model.to(self.device)
        run_history = self._empty_history()
        self.history = run_history

        self.execute_callbacks(self, self.callbacks, "experiment_begin")
        for epoch in range(1, self.num_epochs + 1):
            etxt = f"Epoch {epoch}/{self.num_epochs}"
            etxt = colorize_text(etxt, fore_tuple=(0, 0, 255), bold_text=True)
            print(etxt, end="\n")
            self.execute_callbacks(self, self.callbacks, "epoch_begin")

            train_logs = self.train_single_epoch(
                self.model,
                self.train_dataloader,
                self.optimizer,
                self.loss_fn,
                self.metrics,
                self.callbacks,
                self.device,
            )
            val_logs = self.evaluate_single_epoch(
                self.model, self.val_dataloader, self.loss_fn, self.metrics, self.callbacks, self.device
            )

            # Training keys are written first and validation keys second.
            epoch_logs = {**train_logs, **val_logs}
            for key, value in epoch_logs.items():
                run_history[key].append(value)

            self.execute_callbacks(self, self.callbacks, "epoch_end")
            if self.early_stop:
                break
            # Schedule the learning rate only when the parameters are updated.
            if self.update_params and self.scheduler is not None:
                if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    # The monitored value may be a training or a validation quantity.
                    if self.schedule_monitor not in epoch_logs:
                        raise KeyError(
                            f"schedule_monitor '{self.schedule_monitor}' is not logged; "
                            f"available keys: {list(epoch_logs)}"
                        )
                    self.scheduler.step(epoch_logs[self.schedule_monitor])
                else:
                    self.scheduler.step()
        self.execute_callbacks(self, self.callbacks, "experiment_end")

        return run_history

    def train_single_epoch(
        self, model, train_dataloader, optimizer, loss_fn, metrics=None, callbacks=None, device=None
    ) -> Dict[str, float]:
        """Train ``model`` for one pass over ``train_dataloader``.

        Args:
            model: Model to train; switched to train mode.
            train_dataloader: Iterable yielding ``(inputs, targets)`` batches.
            optimizer: Optimizer stepped through ``self.scaler`` while ``self.update_params`` is set.
            loss_fn: Callable ``loss_fn(predictions, targets)``.
            metrics: Optional metric callables, each needing a ``__name__``.
            callbacks: Optional callbacks for the ``train_*`` and ``backward_end`` stages.
            device: Device the batches are moved to; ``self.device`` when ``None``.

        Returns:
            Dict: ``"lr"`` (first parameter group), the running-average ``"loss"`` and one entry per metric.
        """
        device = self.device if device is None else device
        # ``torch.autocast`` needs the bare device type ("cuda"), not "cuda:0" or a torch.device.
        device_type = self._device_type(device)
        on_cpu = device_type == "cpu"

        model.train()

        progress_bar = ProgressBar(total=len(train_dataloader), prefix="train")
        loss_avg = RunningAverage()
        metric_avgs = {metric.__name__: RunningAverage() for metric in metrics or []}

        self.execute_callbacks(self, callbacks, "train_begin")
        for batch_ix, (X_train, y_train) in enumerate(train_dataloader):
            self.execute_callbacks(self, callbacks, "train_batch_begin")
            X_train = X_train.to(device)
            y_train = y_train.to(device)
            with torch.autocast(
                device_type=device_type,
                dtype=torch.bfloat16 if on_cpu else torch.float16,
                enabled=self.use_amp and not on_cpu,
            ):
                y_pred = model(X_train)
                batch_loss = loss_fn(y_pred, y_train)
                # The loss is divided only on batches whose update is withheld; presumably a
                # gradient-accumulation callback toggles ``update_params`` - verify against it.
                if self.accumulation_steps > 1 and not self.update_params:
                    batch_loss = batch_loss / self.accumulation_steps
                if self.use_grad_penalty:
                    # NOTE(review): PyTorch's AMP examples build the penalty from scaled gradients
                    # that are unscaled afterwards; this unscaled form may underflow in fp16 - confirm.
                    batch_loss = batch_loss + self.gradient_norm(model, batch_loss)

            self.scaler.scale(batch_loss).backward()
            if self.update_params:
                self.execute_callbacks(self, callbacks, "backward_end")
                self.scaler.step(optimizer)
                self.scaler.update()
                optimizer.zero_grad()

            # update value + message
            loss_avg.update(batch_loss.item())
            message = f"loss: {round(loss_avg(),5)}"
            message += self._track_metrics(metrics, metric_avgs, y_pred, y_train)

            self.execute_callbacks(self, callbacks, "train_batch_end")
            progress_bar.update(batch_ix + 1, message)
        self.execute_callbacks(self, callbacks, "train_end")

        # update history
        history = {"lr": optimizer.param_groups[0]["lr"], "loss": loss_avg()}
        for name, average in metric_avgs.items():
            history[name] = average()
        return history

    def evaluate_single_epoch(
        self, model, val_dataloader, loss_fn, metrics=None, callbacks=None, device=None
    ) -> Dict[str, float]:
        """Evaluate ``model`` for one pass over ``val_dataloader`` without updating parameters.

        Args:
            model: Model to evaluate; switched to eval mode.
            val_dataloader: Iterable yielding ``(inputs, targets)`` batches.
            loss_fn: Callable ``loss_fn(predictions, targets)``.
            metrics: Optional metric callables, each needing a ``__name__``.
            callbacks: Optional callbacks for the ``val_*`` stages.
            device: Device the batches are moved to; ``self.device`` when ``None``.

        Returns:
            Dict: The running-average ``"val_loss"`` and one ``"val_<metric>"`` entry per metric.
        """
        device = self.device if device is None else device
        prefix = "val_"

        model.eval()
        progress_bar = ProgressBar(total=len(val_dataloader), prefix="valid")
        val_loss_avg = RunningAverage()
        metric_avgs = {f"{prefix}{metric.__name__}": RunningAverage() for metric in metrics or []}

        self.execute_callbacks(self, callbacks, "val_begin")
        for batch_ix, (X_val, y_val) in enumerate(val_dataloader):
            self.execute_callbacks(self, callbacks, "val_batch_begin")
            X_val = X_val.to(device)
            y_val = y_val.to(device)
            with torch.inference_mode():
                y_pred_val = model(X_val)
                batch_val_loss = loss_fn(y_pred_val, y_val)

            # update value + message
            val_loss_avg.update(batch_val_loss.item())
            message = f"{prefix}loss: {round(val_loss_avg(), 5)}"
            message += self._track_metrics(metrics, metric_avgs, y_pred_val, y_val, prefix=prefix)

            self.execute_callbacks(self, callbacks, "val_batch_end")
            progress_bar.update(batch_ix + 1, message)
        self.execute_callbacks(self, callbacks, "val_end")

        # update history
        history = {f"{prefix}loss": val_loss_avg()}
        for name, average in metric_avgs.items():
            history[name] = average()
        return history