Callbacks
Callbacks: Hooks into the Neural Network Training Process
Callbacks let you intercept events at various stages of neural network training and run custom actions in response. With them you can tailor the training process to your needs, improve learning performance, and gain insight into how training is progressing.
Why Use Callbacks?
Callbacks are useful for:
Monitoring Progress: Track metrics like loss and accuracy, visualize them dynamically, and even halt training early if certain criteria are met.
Saving Checkpoints: Regularly store snapshots of the model's state at different points in time, allowing you to resume training from a saved point or experiment with different hyperparameters.
Early Stopping: Prevent overfitting by automatically stopping training when validation performance starts to decline.
Logging: Record training details, metrics, and other information for analysis and evaluation.
Custom Actions: Implement specialized techniques or integrate with external services during training.
Callback
- class torchmate.callbacks.Callback
Bases: object
Base class for creating callback objects in an experimental or training framework.
Callbacks are used to customize and extend the behavior of an experiment, training loop, or optimization process by hooking into various stages of the execution.
Callback Methods:
on_experiment_begin(self, trainer: Trainer) -> None
    Called at the beginning of an experiment.
on_experiment_end(self, trainer: Trainer) -> None
    Called at the end of an experiment.
on_epoch_begin(self, trainer: Trainer) -> None
    Called at the beginning of an epoch.
on_epoch_end(self, trainer: Trainer) -> None
    Called at the end of an epoch.
on_(train|val|predict)_begin(self, trainer: Trainer) -> None
    Called at the beginning of fit/evaluate/predict.
on_(train|val|predict)_end(self, trainer: Trainer) -> None
    Called at the end of fit/evaluate/predict.
on_(train|val|predict)_batch_begin(self, trainer: Trainer) -> None
    Called right before processing a batch during training/validating/predicting.
on_(train|val|predict)_batch_end(self, trainer: Trainer) -> None
    Called at the end of training/validating/predicting a batch.
- Parameters:
trainer (Trainer) - An instance of torchmate.trainer.Trainer.
Note
This base class provides empty implementations for all callback methods, allowing derived callback classes to selectively override only the methods that need to be customized.
- Example Usage:
Below is an example of a custom callback class that inherits from Callback and overrides specific methods to customize behavior during training:
    from torchmate.callbacks import Callback


    class CustomCallback(Callback):
        def __init__(self):
            self.current_epoch = 0

        def on_epoch_begin(self, trainer):
            # Count epochs locally so both epoch hooks report the same number.
            self.current_epoch += 1
            print(f"Epoch {self.current_epoch} begins!")

        def on_epoch_end(self, trainer):
            print(f"Epoch {self.current_epoch} has finished!")

        def on_experiment_end(self, trainer):
            print("Experiment finished!")
            print(f"History: {trainer.history}")


    # Create an instance of the custom callback and use it during training
    custom_callback = CustomCallback()
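To make the hook order concrete, here is a minimal, self-contained sketch of how a training loop might dispatch these hooks. The `Callback` stand-in, `Recorder`, and `run_toy_loop` are hypothetical names defined only for this illustration; torchmate's actual loop and its exact hook ordering may differ.

```python
class Callback:
    """Stand-in base class: every hook is a no-op, as the Note above describes."""

    def on_experiment_begin(self, trainer): pass
    def on_experiment_end(self, trainer): pass
    def on_epoch_begin(self, trainer): pass
    def on_epoch_end(self, trainer): pass
    def on_train_batch_begin(self, trainer): pass
    def on_train_batch_end(self, trainer): pass


class Recorder(Callback):
    """Overrides only two hooks; the rest fall back to the no-op defaults."""

    def __init__(self):
        self.events = []

    def on_epoch_begin(self, trainer):
        self.events.append("epoch_begin")

    def on_train_batch_end(self, trainer):
        self.events.append("train_batch_end")


def run_toy_loop(callbacks, num_epochs=2, num_batches=3):
    """Hypothetical loop: calls each hook on every callback, in list order."""
    trainer = object()  # placeholder; a real loop would pass the Trainer instance

    def fire(hook_name):
        for cb in callbacks:
            getattr(cb, hook_name)(trainer)

    fire("on_experiment_begin")
    for _ in range(num_epochs):
        fire("on_epoch_begin")
        for _ in range(num_batches):
            fire("on_train_batch_begin")
            # forward pass, backward pass, and optimizer step would go here
            fire("on_train_batch_end")
        fire("on_epoch_end")
    fire("on_experiment_end")


recorder = Recorder()
run_toy_loop([recorder])
print(recorder.events)
```

Because the base class supplies a no-op for every hook, the loop can call all of them unconditionally and each callback only pays for the hooks it overrides.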
CSVLogger
- class torchmate.callbacks.CSVLogger(filename: str, separator: str = ',', append: bool = False)
Bases: Callback
Logs training history to a CSV file.
- Parameters:
filename (str) - The path to the CSV file to log to.
separator (str, optional) - The field separator to use in the CSV file. Defaults to ",".
append (bool, optional) - Whether to append to an existing file or create a new one. Defaults to False.
Example Usage:
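    from torchmate.callbacks import CSVLogger

    # Write the training history to logs/log.csv, using the default "," separator.
    csv_logger = CSVLogger(filename="logs/log.csv")
    callbacks = [csv_logger]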
- on_experiment_begin(trainer) -> None
Initialize the logging process at the beginning of the experiment.
- Parameters:
trainer (Trainer) - The trainer object.
- Returns:
None
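The effect of `separator` and `append` can be illustrated with Python's standard `csv` module. This is only a sketch of the general behavior: `write_history` and the `epoch`/`loss` columns are hypothetical, and the columns CSVLogger actually writes are not documented here.

```python
import csv
import os
import tempfile


def write_history(filename, rows, separator=",", append=False):
    """Write one CSV row per epoch; `rows` is a list of dicts with identical keys."""
    file_exists = os.path.exists(filename) and os.path.getsize(filename) > 0
    mode = "a" if append else "w"
    with open(filename, mode, newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0]), delimiter=separator)
        # Write the header unless we are appending to a file that already has one.
        if not (append and file_exists):
            writer.writeheader()
        writer.writerows(rows)


log_path = os.path.join(tempfile.mkdtemp(), "log.csv")
write_history(log_path, [{"epoch": 1, "loss": 0.9}], separator=";")
write_history(log_path, [{"epoch": 2, "loss": 0.7}], separator=";", append=True)

with open(log_path, newline="") as f:
    lines = f.read().splitlines()
print(lines)
```

With `append=False` the file is truncated on every run, so appending is the option to reach for when resuming an interrupted experiment and keeping the earlier rows.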
EarlyStopper
- class torchmate.callbacks.EarlyStopper(monitor='val_loss', patience=3, mode='min', min_delta=0, restore_best_state=False)
Bases: Callback
Callback to implement early stopping based on a monitored metric.
This callback stops the training early if the monitored metric does not improve for a specified number of epochs (patience).
- Parameters:
monitor (str) - Metric to monitor for early stopping. Defaults to "val_loss".
patience (int) - Number of epochs with no improvement to wait before stopping. Defaults to 3.
mode (str) - One of "min" or "max". The direction to monitor the metric. For example, "min" for loss, "max" for accuracy. Defaults to "min".
min_delta (float) - Minimum change in the monitored metric to qualify as an improvement. Defaults to 0.
restore_best_state (bool) - Whether to restore the best model state when early stopping. Defaults to False.
Example Usage:
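    from torchmate.callbacks import EarlyStopper

    # Stop once val_loss has failed to improve for 3 consecutive epochs.
    early_stopper = EarlyStopper(
        monitor="val_loss",
        patience=3,
        mode="min",
    )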
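How `patience`, `mode`, and `min_delta` interact is easiest to see in a small standalone sketch. `PatienceTracker` below is a hypothetical stand-in written for illustration, not torchmate's implementation; it assumes an epoch counts as an improvement only when the metric beats the best value so far by more than `min_delta`.

```python
class PatienceTracker:
    """Illustrative stand-in for the patience logic; not torchmate's code."""

    def __init__(self, patience=3, mode="min", min_delta=0.0):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        self.best = None
        self.wait = 0  # consecutive epochs without improvement

    def _improved(self, value):
        if self.best is None:
            return True
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def update(self, value):
        """Record one epoch's metric; return True when training should stop."""
        if self._improved(value):
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


tracker = PatienceTracker(patience=3, mode="min")
val_losses = [0.90, 0.80, 0.85, 0.82, 0.81, 0.70]
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if tracker.update(loss):
        stopped_at = epoch
        break
print(stopped_at, tracker.best)  # 5 0.8
```

Training stops at epoch 5, so the better loss at epoch 6 is never reached. That is the trade-off a small `patience` makes, and restoring the best state (epoch 2 here) is what `restore_best_state` is described as doing.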
GradientAccumulator
- class torchmate.callbacks.GradientAccumulator(num_accum_steps: int | None = None)
Bases: Callback
Callback to accumulate gradients for gradient accumulation.
This callback accumulates gradients over a specified number of batches and performs an update when the specified number of batches is reached.
- Parameters:
num_accum_steps (int, optional) - Number of accumulation steps before performing a parameter update. Defaults to None.
Example Usage:
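    from torchmate.callbacks import GradientAccumulator

    # Perform one parameter update for every 4 batches.
    gradient_accumulator = GradientAccumulator(num_accum_steps=4)
    callbacks = [gradient_accumulator]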
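The idea behind gradient accumulation can be shown without any framework. The sketch below fits `y = w * x` with hand-written gradients. `mse_grad` and `train_with_accumulation` are hypothetical helpers, and dividing each batch gradient by `num_accum_steps` is a common convention assumed here rather than something this page documents.

```python
def mse_grad(w, batch):
    """Gradient of the mean squared error of y = w * x with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def train_with_accumulation(batches, num_accum_steps, lr=0.01, w=0.0):
    """Apply one parameter update per `num_accum_steps` batches."""
    accumulated = 0.0
    num_updates = 0
    for step, batch in enumerate(batches, start=1):
        # Scale each batch gradient so the accumulated sum is an average.
        accumulated += mse_grad(w, batch) / num_accum_steps
        if step % num_accum_steps == 0:
            w -= lr * accumulated  # the optimizer step
            accumulated = 0.0      # the equivalent of zeroing the gradients
            num_updates += 1
    # Batches left over after the last full group are ignored in this sketch.
    return w, num_updates


# Eight batches of two (x, y) pairs each, drawn from y = 3x.
batches = [[(float(i), 3.0 * i), (float(i + 1), 3.0 * (i + 1))] for i in range(1, 9)]
w, num_updates = train_with_accumulation(batches, num_accum_steps=4)
print(num_updates)  # 2
```

With equally sized batches, the averaged accumulated gradient equals the gradient of one batch four times as large, which is why accumulation is a common way to emulate a larger batch size when memory is limited.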
GradientClipper
- class torchmate.callbacks.GradientClipper(method: str, clip_value: float | None = None, max_norm: float | None = None)
Bases: Callback
GradientClipper class for clipping gradients during the training process.
- Parameters:
method (str) - The gradient clipping method to use. Supported methods are "clip_by_value" and "clip_by_norm".
clip_value (Optional[float]) - Maximum allowed value of gradients for the clip_by_value method. Required if method is "clip_by_value".
max_norm (Optional[float]) - Maximum allowed norm of gradients for the clip_by_norm method. Required if method is "clip_by_norm".
- Raises:
ValueError - If the provided method is not supported.
ValueError - If clip_value is not provided when method is "clip_by_value".
ValueError - If max_norm is not provided when method is "clip_by_norm".
Example Usage:
    from torchmate.callbacks import GradientClipper

    # Clamp every gradient component into [-1.0, 1.0].
    gradient_clipper = GradientClipper(method="clip_by_value", clip_value=1.0)
    callbacks = [gradient_clipper]
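The two methods behave differently, as this framework-free sketch on a flat list of floats shows. `clip_gradients` is a hypothetical helper that mirrors the documented parameters and errors. PyTorch offers `torch.nn.utils.clip_grad_value_` and `torch.nn.utils.clip_grad_norm_` for the same purpose, but whether GradientClipper delegates to them is an assumption.

```python
import math


def clip_gradients(grads, method, clip_value=None, max_norm=None):
    """Return a clipped copy of `grads` (a flat list of floats)."""
    if method == "clip_by_value":
        if clip_value is None:
            raise ValueError("clip_value is required for 'clip_by_value'")
        # Clamp every component into [-clip_value, clip_value].
        return [max(-clip_value, min(clip_value, g)) for g in grads]
    if method == "clip_by_norm":
        if max_norm is None:
            raise ValueError("max_norm is required for 'clip_by_norm'")
        # Rescale the whole vector if its L2 norm exceeds max_norm.
        total_norm = math.sqrt(sum(g * g for g in grads))
        scale = min(1.0, max_norm / total_norm) if total_norm > 0 else 1.0
        return [g * scale for g in grads]
    raise ValueError(f"Unsupported method: {method!r}")


grads = [3.0, -4.0]
by_value = clip_gradients(grads, "clip_by_value", clip_value=1.0)
by_norm = clip_gradients(grads, "clip_by_norm", max_norm=1.0)
print(by_value)  # [1.0, -1.0]
print(by_norm)   # roughly [0.6, -0.8]; the direction is preserved
```

Clipping by value can change the direction of the gradient vector, since each component is clamped independently, whereas clipping by norm only shortens it.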
ModelCheckpoint
- class torchmate.callbacks.ModelCheckpoint(checkpoint_dir, monitor='val_loss', mode='min', min_delta=0.0, save_frequency=1, save_best_only=True, save_state_dict_only=True, save_last=True)
Bases: Callback
Callback to save model checkpoints.
This callback saves model checkpoints at specified intervals and under certain conditions such as best validation loss or metrics.
- Parameters:
checkpoint_dir (str) - Directory to save the checkpoints.
monitor (str, optional) - Metric to monitor for saving the best model. Defaults to "val_loss".
mode (str, optional) - One of "min" or "max". The direction to monitor the metric. For example, "min" for loss, "max" for accuracy. Defaults to "min".
min_delta (float, optional) - Minimum change in the monitored metric to qualify as an improvement. Defaults to 0.0.
save_frequency (int, optional) - Frequency of saving checkpoints (epochs). Defaults to 1.
save_best_only (bool, optional) - Whether to save only the best model based on the monitored metric. Defaults to True.
save_last (bool, optional) - Whether to save the model checkpoint for the last epoch. Defaults to True.
save_state_dict_only (bool, optional) - Whether to save only the model's state dictionary instead of the whole model. Defaults to True.
  - If True: Saves a dictionary containing epoch, model_state_dict, and optimizer_state_dict. Requires manually creating model and optimizer instances before loading the dictionary. Only the model's weights and biases are loaded from model_state_dict.
  - If False: Saves the entire model object.
Example Usage:
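    from torchmate.callbacks import ModelCheckpoint

    # Save a checkpoint every 2 epochs, whether or not val_loss improved.
    checkpoint_callback = ModelCheckpoint(
        checkpoint_dir="checkpoints",
        monitor="val_loss",
        mode="min",
        save_best_only=False,
        save_frequency=2,
    )
    callbacks = [checkpoint_callback]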
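One plausible reading of how `save_frequency` and `save_best_only` combine is sketched below. `should_save` is a hypothetical helper, and the rule it encodes (save only on scheduled epochs, and additionally require an improvement when `save_best_only` is set) is an assumption for illustration; the library's actual rule may differ.

```python
def should_save(epoch, value, best, mode="min", min_delta=0.0,
                save_frequency=1, save_best_only=True):
    """Return (save, new_best) for one epoch under the assumed rules."""
    if best is None:
        improved = True
    elif mode == "min":
        improved = value < best - min_delta
    else:
        improved = value > best + min_delta
    new_best = value if improved else best

    on_schedule = epoch % save_frequency == 0
    save = on_schedule and (improved or not save_best_only)
    return save, new_best


def saved_epochs_for(val_losses, **options):
    """Run the decision over a whole history and list the epochs that save."""
    best, saved = None, []
    for epoch, loss in enumerate(val_losses, start=1):
        save, best = should_save(epoch, loss, best, **options)
        if save:
            saved.append(epoch)
    return saved


val_losses = [0.9, 0.8, 0.85, 0.7]
every_two = saved_epochs_for(val_losses, save_frequency=2, save_best_only=False)
best_only = saved_epochs_for(val_losses, save_frequency=1, save_best_only=True)
print(every_two)  # [2, 4]
print(best_only)  # [1, 2, 4]
```

The first configuration matches the example above: with `save_best_only=False` the metric is ignored and the schedule alone decides when to save.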
WandbLogger
- class torchmate.callbacks.WandbLogger
Bases: Callback
Callback to log training metrics and visualizations to Weights & Biases (Wandb).
This callback logs training metrics, visualizations, and other experiment-related data to Wandb.
Example Usage:
    from torchmate.callbacks import WandbLogger

    wandb_logger = WandbLogger()
    callbacks = [wandb_logger]
WandbModelCheckpoint
- class torchmate.callbacks.WandbModelCheckpoint(checkpoint_dir, monitor='val_loss', mode='min', min_delta=0, save_frequency=1, save_best_only=True, save_state_dict_only=True, save_last=True)
Bases: Callback
Callback to save model checkpoints using Weights and Biases (wandb).
This callback saves model checkpoints to the wandb cloud at specified intervals and under certain conditions such as best validation loss or metrics.
- Parameters:
checkpoint_dir (str) - Directory to save the checkpoints.
monitor (str, optional) - Metric to monitor for saving the best model. Defaults to "val_loss".
mode (str, optional) - One of "min" or "max". The direction to monitor the metric. For example, "min" for loss, "max" for accuracy. Defaults to "min".
min_delta (float, optional) - Minimum change in the monitored metric to qualify as an improvement. Defaults to 0.0.
save_frequency (int, optional) - Frequency of saving checkpoints (epochs). Defaults to 1.
save_best_only (bool, optional) - Whether to save only the best model based on the monitored metric. Defaults to True.
save_last (bool, optional) - Whether to save the model checkpoint for the last epoch. Defaults to True.
save_state_dict_only (bool, optional) - Whether to save only the model's state dictionary instead of the whole model. Defaults to True.
  - If True: Saves a dictionary containing epoch, model_state_dict, and optimizer_state_dict. Requires manually creating model and optimizer instances before loading the dictionary. Only the model's weights and biases are loaded from model_state_dict.
  - If False: Saves the entire model object.
Example Usage:
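    from torchmate.callbacks import WandbModelCheckpoint

    # Save a checkpoint every 2 epochs, whether or not val_loss improved.
    wandb_checkpoint_callback = WandbModelCheckpoint(
        checkpoint_dir="checkpoints",
        monitor="val_loss",
        mode="min",
        save_best_only=False,
        save_frequency=2,
    )
    callbacks = [wandb_checkpoint_callback]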