Source code for torchmate.trainer.trainer

  1import re
  2from typing import Callable, List, Optional, Type, Union
  3
  4import matplotlib.pyplot as plt
  5import torch
  6
  7import torchmate
  8from torchmate.callbacks import Callback
  9from torchmate.utils import ProgressBar, RunningAverage, colorize_text
 10
 11
class Trainer(torch.nn.Module):
    """Encapsulate training essentials.

    Args:
        model (torch.nn.Module, required): The PyTorch model to be trained.
        train_dataloader (torch.utils.data.DataLoader, required): DataLoader for the training dataset.
        val_dataloader (torch.utils.data.DataLoader, required): DataLoader for the validation dataset.
        loss_fn (torch.nn.Module, required): Loss function for training.
        optimizer (torch.optim.Optimizer, required): Optimizer for updating model parameters.
        num_epochs (int, optional): Number of training epochs (default is 1).
        test_dataloader (torch.utils.data.DataLoader, optional): DataLoader for the test dataset.
        metrics (List[callable], optional): List of metrics functions for evaluation.
        callbacks (List[Callback], optional): List of callback objects for various stages.
        scheduler (torch.optim.lr_scheduler._LRScheduler, optional): Learning rate scheduler.
        schedule_monitor (str, optional): History key passed to ``ReduceLROnPlateau.step``
            (default is "val_loss").
        mixed_precision (bool, optional): Whether to use mixed precision (fp16) training (default is False).
        use_grad_penalty (bool, optional): Whether to use gradient penalty (default is False).
        device (str or torch.device, optional): Device to use for training (default is "cpu").

    Other Attributes:
        - **history (dict):** Training history containing loss, metrics, and learning rates.
        - **early_stop (bool):** Flag for early stopping.
        - **update_params (bool):** Flag for updating model parameters.
        - **accumulation_steps (int):** Number of steps for gradient accumulation during training.

    Important Methods:
        - **fit():** Train and validate the model for the specified number of epochs and return history.
        - **evaluate():** Evaluate the model on the validation dataset and return evaluation history.
        - **predict():** Make predictions using the model on the test dataset.

    **Example usage:**

    .. code-block:: python

        import os

        import numpy as np
        import pandas as pd
        import torch

        from torchmate.trainer import Trainer
        from torchmate.callbacks import CSVLogger, ModelCheckpoint
        from sklearn.model_selection import train_test_split

        # Create a simple neural network model
        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super(SimpleModel, self).__init__()
                self.fc1 = torch.nn.Linear(1, 1)

            def forward(self, x):
                return self.fc1(x)

        # Create synthetic data
        X = torch.tensor(np.random.rand(1000, 1), dtype=torch.float32)
        y = 2 * X + 1 + torch.randn(1000, 1) * 0.1  # Adding some noise

        # Split the data into training and validation sets
        X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

        # Create DataLoader objects for training and validation
        train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
        val_dataset = torch.utils.data.TensorDataset(X_val, y_val)
        val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False)

        # Create Metrics
        class MSE(torch.nn.Module):
            __name__ = 'mse'
            def __init__(self, weight=None, size_average=True):
                super(MSE, self).__init__()
            def forward(self, inputs, targets):
                inputs = inputs.view(-1)
                targets = targets.view(-1)
                mse = torch.mean((inputs - targets) ** 2)
                return mse

        def mae(inputs, targets):
            inputs = inputs.view(-1)
            targets = targets.view(-1)
            mae = torch.mean(torch.abs(inputs - targets))
            return mae

        model = SimpleModel()
        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

        metrics = [MSE(), mae]

        logdir = "logs"
        csv_file = os.path.join(logdir, "logs.csv")
        ckpt_dir = os.path.join(logdir, "model")

        callbacks = [CSVLogger(filename=csv_file),
                     ModelCheckpoint(checkpoint_dir=ckpt_dir)
                     ]

        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        print(device)

        # Create a Trainer instance with the callbacks
        trainer = Trainer(model,
                          train_dataloader,
                          val_dataloader,
                          loss_fn,
                          optimizer,
                          num_epochs=3,
                          scheduler=scheduler,
                          metrics=metrics,
                          callbacks=callbacks,
                          device=device,
                          mixed_precision=True,
                          use_grad_penalty=True
                          )

        # Train the model
        history = trainer.fit()

        print("_"*150)

        print(pd.read_csv(csv_file))

    """

    # Stage names accepted by ``execute_callbacks``. Kept as a list so the
    # ValueError message renders exactly as it always has.
    _VALID_STAGES = [
        "experiment_begin",
        "experiment_end",
        "epoch_begin",
        "epoch_end",
        "train_begin",
        "train_end",
        "train_batch_begin",
        "train_batch_end",
        "val_begin",
        "val_end",
        "val_batch_begin",
        "val_batch_end",
        "predict_begin",
        "predict_end",
        "predict_batch_begin",
        "predict_batch_end",
        "backward_end",
    ]

    def __init__(
        self,
        model: torch.nn.Module,
        train_dataloader: torch.utils.data.DataLoader,
        val_dataloader: torch.utils.data.DataLoader,
        loss_fn: Union[Callable, torch.nn.Module],
        optimizer: torch.optim.Optimizer,
        num_epochs: int = 1,
        test_dataloader: Optional[torch.utils.data.DataLoader] = None,
        metrics: Optional[List[Callable]] = None,
        callbacks: Optional[List["Callback"]] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        schedule_monitor: str = "val_loss",
        mixed_precision: bool = False,
        use_grad_penalty: bool = False,
        device: Union[str, torch.device] = "cpu",
    ):

        super().__init__()
        self.model = model
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.schedule_monitor = schedule_monitor
        self.metrics = metrics
        self.num_epochs = num_epochs
        self.callbacks = callbacks
        self.use_amp = mixed_precision
        self.use_grad_penalty = use_grad_penalty
        self.device = device
        self.early_stop = False
        self.update_params = True
        self.accumulation_steps = 4
        # Compare on the device *type* so that "cpu" and torch.device("cpu")
        # are treated the same way.
        self.scaler = torch.cuda.amp.GradScaler(
            enabled=self.use_amp and self._device_type(self.device) != "cpu"
        )

        # Metrics and the loss need a ``__name__`` for logging and printing.
        # This must happen BEFORE the history dict is built, because the
        # history keys are derived from those names.
        if self.metrics is not None:
            for metric in self.metrics:
                self._ensure_name(metric)
        self._ensure_name(self.loss_fn)

        # Initialize the log history dict
        self.history = self._new_history()

    def fit(self):
        """Train the model and return the training history.

        Returns:
            Dict : A dictionary object encapsulating the training history.
        """
        return self.train_and_evaluate()

    def evaluate(
        self,
        dataloader: Optional[torch.utils.data.DataLoader] = None,
        loss_fn: Union[Callable, torch.nn.Module] = None,
        metrics: Optional[List[Callable]] = None,
        callbacks: Optional[List["Callback"]] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """Evaluate the model on a dataset and return the evaluation history.

        This method provides flexibility for customized evaluation.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): A PyTorch DataLoader containing the validation data.
                If not provided, the ``self.val_dataloader`` attribute will be used. Defaults to ``None``.
            loss_fn (Callable or torch.nn.Module, optional): A custom loss function for evaluation.
                If not provided, the ``self.loss_fn`` attribute will be used. Defaults to ``None``.
            metrics (List[Callable], optional): A list of custom evaluation metrics.
                If not provided, the ``self.metrics`` attribute will be used. Defaults to ``None``.
            callbacks (List[Callback], optional): A list of callback objects for evaluation stages.
                If not provided, the ``self.callbacks`` attribute will be used. Defaults to ``None``.
            device (str or torch.device, optional): The device to use for evaluation (e.g., "cpu" or "cuda").
                If not provided, the ``self.device`` attribute will be used. Defaults to ``None``.

        Returns:
            Dict: A dictionary object containing the evaluation results.
        """
        # ``is None`` rather than truthiness: bool(DataLoader) calls __len__,
        # which would silently discard an empty loader.
        dataloader = self.val_dataloader if dataloader is None else dataloader
        callbacks = self.callbacks if callbacks is None else callbacks
        loss_fn = self.loss_fn if loss_fn is None else loss_fn
        metrics = self.metrics if metrics is None else metrics
        device = self.device if device is None else device

        # Metrics supplied here did not go through __init__, so make sure they
        # carry the name used for the history keys.
        if metrics is not None:
            for metric in metrics:
                self._ensure_name(metric)

        # The model may not have been moved yet if fit() was never called.
        self.model.to(device)
        return self.evaluate_single_epoch(self.model, dataloader, loss_fn, metrics, callbacks, device)

    def predict(
        self,
        test_dataloader: Optional[torch.utils.data.DataLoader] = None,
        callbacks: Optional[List["Callback"]] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """Perform predictions on the provided test data using the trained model.

        Args:
            test_dataloader (DataLoader, optional): A PyTorch DataLoader containing the test data.
                If not provided, the ``self.test_dataloader`` attribute will be used. When provided, it is
                also stored on ``self.test_dataloader``. Defaults to ``None``.
            callbacks (list[Callback], optional): A list of callback objects to be executed at various stages
                of the prediction process. If not provided, the ``self.callbacks`` attribute will be used.
                Defaults to ``None``.
            device (str or torch.device, optional): The device to run the prediction on (e.g., "cpu" or "cuda").
                If not provided, the ``self.device`` attribute will be used. Defaults to ``None``.

        Returns:
            List[torch.Tensor]: One tensor per element along the first dimension of each batch output,
            in dataloader order.

        Raises:
            ValueError: If both ``test_dataloader`` and ``self.test_dataloader`` are None.

        """
        if test_dataloader is None and self.test_dataloader is None:
            raise ValueError(
                "Missing test data: You must provide either a `test_dataloader` argument or set a "
                "`test_dataloader` attribute on the Trainer instance."
            )

        model = self.model
        device = self.device if device is None else device
        callbacks = self.callbacks if callbacks is None else callbacks

        if test_dataloader is not None:
            self.test_dataloader = test_dataloader
        else:
            test_dataloader = self.test_dataloader

        # The model may not have been moved yet if fit() was never called.
        model.to(device)
        model.eval()
        progress_bar = ProgressBar(total=len(test_dataloader), prefix="prediction")
        self.execute_callbacks(self, callbacks, "predict_begin")
        predictions = []
        for batch_ix, batch in enumerate(test_dataloader):
            self.execute_callbacks(self, callbacks, "predict_batch_begin")
            # A batch is either (inputs, ...) or the bare inputs tensor; only
            # unpack in the former case so a tensor batch is not reduced to
            # its first sample.
            X_test = batch[0] if isinstance(batch, (list, tuple)) else batch
            X_test = X_test.to(device)
            with torch.inference_mode():
                y_pred = model(X_test)
            # Extending with a tensor iterates its first dimension.
            predictions.extend(y_pred)
            progress_bar.update(batch_ix + 1)
            self.execute_callbacks(self, callbacks, "predict_batch_end")
        self.execute_callbacks(self, callbacks, "predict_end")
        return predictions

    @staticmethod
    def camel_to_snake_case(text: str) -> str:
        """Convert CamelCase text to snake_case."""
        string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", text)
        string = re.sub("([a-z0-9])([A-Z])", r"\1_\2", string)
        return string.lower()

    @staticmethod
    def _device_type(device) -> str:
        """Return the bare device type ("cpu", "cuda", ...) for a str or torch.device."""
        return torch.device(device).type

    @staticmethod
    def _ensure_name(fn):
        """Attach a snake_case ``__name__`` derived from the class name when ``fn`` has none."""
        if not hasattr(fn, "__name__"):
            fn.__name__ = Trainer.camel_to_snake_case(fn.__class__.__name__)
        return fn

    def _new_history(self):
        """Build an empty history dict with one list per logged quantity."""
        history = {"loss": [], "val_loss": [], "lr": []}
        for metric in self.metrics or []:
            history[f"{metric.__name__}"] = []
            history[f"val_{metric.__name__}"] = []
        return history

    @staticmethod
    def gradient_norm(model, loss):
        """Return the L2 norm of the gradients of ``loss`` w.r.t. the trainable parameters.

        The gradients are created with ``create_graph=True`` so the returned norm can itself
        be back-propagated through (it is added to the loss as a penalty term).

        Args:
            model (torch.nn.Module): Model whose parameters are differentiated.
            loss (torch.Tensor): Scalar loss attached to the autograd graph.

        Returns:
            torch.Tensor: 0-dim tensor holding the gradient norm.
        """
        # Frozen parameters cannot be differentiated; skip them instead of failing.
        params = [param for param in model.parameters() if param.requires_grad]
        squared_norm = loss.new_zeros(())
        if not params:
            return squared_norm
        grads = torch.autograd.grad(outputs=loss, inputs=params, create_graph=True, allow_unused=True)
        for grad in grads:
            # ``allow_unused`` yields None for parameters outside the graph.
            if grad is not None:
                squared_norm = squared_norm + grad.pow(2).sum()
        return squared_norm.sqrt()

    @staticmethod
    def execute_callbacks(trainer, callbacks=None, stage=""):
        """Call ``on_<stage>(trainer)`` on every callback that defines it.

        Args:
            trainer: Object handed to each callback method.
            callbacks (list, optional): Callback objects. ``None`` makes this a no-op.
            stage (str): One of the names in ``Trainer._VALID_STAGES``.

        Raises:
            ValueError: If ``callbacks`` is not None and ``stage`` is not a valid stage name.
        """
        if callbacks is None:
            return None

        valid_stages = Trainer._VALID_STAGES
        if stage not in valid_stages:
            raise ValueError(f"Invalid stage name. Must be one of {valid_stages}")

        method = f"on_{stage}"
        for callback in callbacks:
            if hasattr(callback, method):
                getattr(callback, method)(trainer)
        return None

    def train_and_evaluate(self):
        """Run the full train/validate loop for ``self.num_epochs`` epochs.

        Returns:
            Dict: Per-epoch lists of "loss", "val_loss", "lr" and every metric (train and ``val_``).
        """
        self.model.to(self.device)
        History = self._new_history()

        self.execute_callbacks(self, self.callbacks, "experiment_begin")
        for epoch in range(1, self.num_epochs + 1):
            etxt = f"Epoch {epoch}/{self.num_epochs}"
            etxt = colorize_text(etxt, fore_tuple=(0, 0, 255), bold_text=True)
            print(etxt, end="\n")
            self.execute_callbacks(self, self.callbacks, "epoch_begin")

            history = self.train_single_epoch(
                self.model,
                self.train_dataloader,
                self.optimizer,
                self.loss_fn,
                self.metrics,
                self.callbacks,
                self.device,
            )

            val_history = self.evaluate_single_epoch(
                self.model, self.val_dataloader, self.loss_fn, self.metrics, self.callbacks, self.device
            )

            # update history
            epoch_logs = {**history, **val_history}
            for key, value in epoch_logs.items():
                History[key].append(value)
            self.history = History

            self.execute_callbacks(self, self.callbacks, "epoch_end")
            if self.early_stop:
                break
            # schedule learning rate only when the parameters are updated
            if self.update_params and self.scheduler is not None:
                if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    # Look in the merged logs so training quantities such as
                    # "loss" can be monitored as well as validation ones.
                    self.scheduler.step(epoch_logs[self.schedule_monitor])
                else:
                    self.scheduler.step()
        self.execute_callbacks(self, self.callbacks, "experiment_end")

        return History

    def train_single_epoch(
        self, model, train_dataloader, optimizer, loss_fn, metrics=None, callbacks=None, device=None
    ):
        """Train ``model`` for one pass over ``train_dataloader``.

        Returns:
            Dict: "lr" (first param group), "loss" (running average) and one running average per metric.
        """
        model.train()

        device = self.device if device is None else device
        # torch.autocast wants the bare device type ("cuda"), not "cuda:0" or a torch.device.
        device_type = self._device_type(device)
        on_cpu = device_type == "cpu"
        amp_enabled = self.use_amp and not on_cpu
        amp_dtype = torch.bfloat16 if on_cpu else torch.float16

        progress_bar = ProgressBar(total=len(train_dataloader), prefix="train")
        history = dict()
        loss_avg = RunningAverage()
        running_avg_dict = dict()

        if metrics is not None:
            for metric in metrics:
                running_avg_dict[f"{metric.__name__}_avg"] = RunningAverage()

        self.execute_callbacks(self, callbacks, "train_begin")
        for batch_ix, (X_train, y_train) in enumerate(train_dataloader):
            self.execute_callbacks(self, callbacks, "train_batch_begin")
            X_train = X_train.to(device)
            y_train = y_train.to(device)
            with torch.autocast(device_type=device_type, dtype=amp_dtype, enabled=amp_enabled):
                y_pred = model(X_train)
                batch_loss = loss_fn(y_pred, y_train)
            # NOTE(review): the loss is rescaled only on steps where the
            # parameters are NOT updated; this mirrors the existing behavior —
            # confirm against the callback that toggles ``update_params``.
            if self.accumulation_steps > 1 and not self.update_params:
                batch_loss = batch_loss / self.accumulation_steps
            if self.use_grad_penalty:
                batch_loss = batch_loss + self.gradient_norm(model, batch_loss)

            self.scaler.scale(batch_loss).backward()  # batch_loss.backward()
            if self.update_params:
                self.execute_callbacks(self, callbacks, "backward_end")
                self.scaler.step(optimizer)  # optimizer.step()
                self.scaler.update()
                optimizer.zero_grad()
            # update value + message
            loss_avg.update(batch_loss.item())
            message = f"loss: {round(loss_avg(),5)}"
            if metrics is not None:
                # Metrics are for reporting only; no graph is needed.
                with torch.no_grad():
                    for metric in metrics:
                        running_avg_dict[f"{metric.__name__}_avg"].update(metric(y_pred, y_train).item())
                        metric_value = round(running_avg_dict[f"{metric.__name__}_avg"](), 5)
                        message += f" | {metric.__name__}: {metric_value}"
            self.execute_callbacks(self, callbacks, "train_batch_end")
            progress_bar.update(batch_ix + 1, message)
        self.execute_callbacks(self, callbacks, "train_end")

        # update history
        history["lr"] = optimizer.param_groups[0]["lr"]
        history["loss"] = loss_avg()
        if metrics is not None:
            for metric in metrics:
                history[f"{metric.__name__}"] = running_avg_dict[f"{metric.__name__}_avg"]()

        return history

    def evaluate_single_epoch(self, model, val_dataloader, loss_fn, metrics=None, callbacks=None, device=None):
        """Evaluate ``model`` for one pass over ``val_dataloader`` without tracking gradients.

        Returns:
            Dict: "val_loss" (running average) and one ``val_``-prefixed running average per metric.
        """
        model.eval()
        device = self.device if device is None else device
        progress_bar = ProgressBar(total=len(val_dataloader), prefix="valid")
        val_loss_avg = RunningAverage()
        running_avg_dict = dict()
        history = dict()
        prefix = "val_"

        if metrics is not None:
            for metric in metrics:
                running_avg_dict[f"{prefix}{metric.__name__}_avg"] = RunningAverage()

        self.execute_callbacks(self, callbacks, "val_begin")
        for batch_ix, (X_val, y_val) in enumerate(val_dataloader):
            self.execute_callbacks(self, callbacks, "val_batch_begin")
            X_val = X_val.to(device)
            y_val = y_val.to(device)
            with torch.inference_mode():
                y_pred_val = model(X_val)
                batch_val_loss = loss_fn(y_pred_val, y_val)
            # update value + message
            val_loss_avg.update(batch_val_loss.item())
            message = f"{prefix}loss: {round(val_loss_avg(), 5)}"
            if metrics is not None:
                for metric in metrics:
                    running_avg_dict[f"{prefix}{metric.__name__}_avg"].update(metric(y_pred_val, y_val).item())
                    metric_value = round(running_avg_dict[f"{prefix}{metric.__name__}_avg"](), 5)
                    message += f" | {prefix}{metric.__name__}: {metric_value}"
            self.execute_callbacks(self, callbacks, "val_batch_end")
            progress_bar.update(batch_ix + 1, message)
        self.execute_callbacks(self, callbacks, "val_end")

        # update history
        history[f"{prefix}loss"] = val_loss_avg()
        if metrics is not None:
            for metric in metrics:
                history[f"{prefix}{metric.__name__}"] = running_avg_dict[f"{prefix}{metric.__name__}_avg"]()

        return history