Callbacks

Callbacks: Hooks into the Neural Network Training Process

In neural network training, callbacks let you hook into specific events of the training loop, such as the start or end of an experiment, an epoch, or a batch, and run custom code at each one. This lets you adapt the training process to your needs (for example, by logging metrics, saving checkpoints, or stopping early) without modifying the training loop itself.

Why Use Callbacks?

Common uses for callbacks include:

  • Monitoring Progress: Track metrics like loss and accuracy, visualize them dynamically, and even halt training early if certain criteria are met.

  • Saving Checkpoints: Regularly store snapshots of the model’s state at different points in time, allowing you to resume training from a saved point or experiment with different hyperparameters.

  • Early Stopping: Reduce overfitting by automatically stopping training once a monitored validation metric stops improving.

  • Logging: Record training details, metrics, and other information for analysis and evaluation.

  • Custom Actions: Implement specialized techniques or integrate with external services during training.

Callback

class torchmate.callbacks.Callback

Bases: object

Base class for creating callback objects in an experimental or training framework.

Callbacks are used to customize and extend the behavior of an experiment, training loop, or optimization process by hooking into various stages of the execution.

Callback Methods:

on_experiment_begin(self, trainer: Trainer) -> None

Called at the beginning of an experiment.

on_experiment_end(self, trainer: Trainer) -> None

Called at the end of an experiment.

on_epoch_begin(self, trainer: Trainer) -> None

Called at the beginning of an epoch.

on_epoch_end(self, trainer: Trainer) -> None

Called at the end of an epoch.

on_(train|val|predict)_begin(self, trainer: Trainer) -> None

Called at the beginning of fit/evaluate/predict.

on_(train|val|predict)_end(self, trainer: Trainer) -> None

Called at the end of fit/evaluate/predict.

on_(train|val|predict)_batch_begin(self, trainer: Trainer) -> None

Called right before processing a batch during training/validating/predicting.

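on_(train|val|predict)_batch_end(self, trainer: Trainer) -> None

Called at the end of training/validating/predicting a batch.

on_backward_end(self, trainer: Trainer) -> None

Called at the end of the backward pass during training.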

Parameters:

trainer (Trainer) – An instance of torchmate.trainer.Trainer.

Note

This base class provides empty implementations for all callback methods, allowing derived callback classes to selectively override only the methods that need to be customized.

Example Usage:

Below is an example of a custom callback class that inherits from Callback and overrides specific methods to customize behavior during training:

from torchmate.callbacks import Callback


class CustomCallback(Callback):
    def __init__(self):
        self.current_epoch = 0

    def on_epoch_begin(self, trainer):
        self.current_epoch += 1
        print(f"Epoch {self.current_epoch} begins!")

    def on_epoch_end(self, trainer):
        print(f"Epoch {self.current_epoch} has finished!")

    def on_experiment_end(self, trainer):
        print("Experiment finished!")
        print(f"History: {trainer.history}")


# Create an instance of the custom callback and pass it to the trainer in the callbacks list
custom_callback = CustomCallback()
callbacks = [custom_callback]

on_experiment_begin(trainer) -> None
on_experiment_end(trainer) -> None
on_epoch_begin(trainer) -> None
on_epoch_end(trainer) -> None
on_train_begin(trainer) -> None
on_train_end(trainer) -> None
on_train_batch_begin(trainer) -> None
on_train_batch_end(trainer) -> None
on_val_begin(trainer) -> None
on_val_end(trainer) -> None
on_val_batch_begin(trainer) -> None
on_val_batch_end(trainer) -> None
on_predict_begin(trainer) -> None
on_predict_end(trainer) -> None
on_predict_batch_begin(trainer) -> None
on_predict_batch_end(trainer) -> None
on_backward_end(trainer) -> None
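The hooks are easiest to understand in the context of a loop. Below is a minimal, plain-Python sketch of how a trainer might dispatch them; it is not torchmate's Trainer. The stand-in Callback base class, MiniTrainer, its fit method, EventRecorder, and the exact hook order are all hypothetical, and the val/predict hooks and on_backward_end are left out for brevity. It also shows why the empty base-class implementations matter: EventRecorder overrides only two hooks, yet the trainer can safely call every hook on it.

```python
class Callback:
    """Stand-in for torchmate.callbacks.Callback: every hook is a no-op."""

    def on_experiment_begin(self, trainer):
        pass

    def on_experiment_end(self, trainer):
        pass

    def on_epoch_begin(self, trainer):
        pass

    def on_epoch_end(self, trainer):
        pass

    def on_train_batch_begin(self, trainer):
        pass

    def on_train_batch_end(self, trainer):
        pass


class MiniTrainer:
    """Hypothetical training loop; the hook order is an assumption, not torchmate's."""

    def __init__(self, callbacks, num_epochs, num_batches):
        self.callbacks = callbacks
        self.num_epochs = num_epochs
        self.num_batches = num_batches

    def _fire(self, hook_name):
        # Call the named hook on every registered callback, in list order.
        for callback in self.callbacks:
            getattr(callback, hook_name)(self)

    def fit(self):
        self._fire("on_experiment_begin")
        for _ in range(self.num_epochs):
            self._fire("on_epoch_begin")
            for _ in range(self.num_batches):
                self._fire("on_train_batch_begin")
                # A real trainer would run the forward pass, backward pass, and optimizer step here.
                self._fire("on_train_batch_end")
            self._fire("on_epoch_end")
        self._fire("on_experiment_end")


class EventRecorder(Callback):
    """Overrides two hooks; all others fall back to the inherited no-ops."""

    def __init__(self):
        self.events = []

    def on_epoch_begin(self, trainer):
        self.events.append("epoch_begin")

    def on_train_batch_end(self, trainer):
        self.events.append("batch_end")


recorder = EventRecorder()
MiniTrainer([recorder], num_epochs=2, num_batches=2).fit()
print(recorder.events)
# ['epoch_begin', 'batch_end', 'batch_end', 'epoch_begin', 'batch_end', 'batch_end']
```

Because the trainer calls every hook on every callback, a callback only needs to define the methods it cares about.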

CSVLogger

class torchmate.callbacks.CSVLogger(filename: str, separator: str = ',', append: bool = False)

Bases: Callback

Logs training history to a CSV file.

Parameters:
  • filename (str) – The path to the CSV file to log to.

  • separator (str, optional) – The field separator to use in the CSV file. Defaults to ",".

  • append (bool, optional) – Whether to append to an existing file or create a new one. Defaults to False.

Example Usage:

from torchmate.callbacks import CSVLogger

csv_logger = CSVLogger(filename="logs/log.csv")
callbacks = [csv_logger]

on_experiment_begin(trainer) -> None

Initialize the logging process at the beginning of the experiment.

Parameters:

trainer (Trainer) – The trainer object.

Returns:

None

on_epoch_end(trainer) -> None

Log training metrics after each epoch.

Parameters:

trainer (Trainer) – The trainer object.

Returns:

None

on_experiment_end(trainer) -> None

Close the CSV file at the end of the experiment.

Parameters:

trainer (Trainer) – The trainer object.

Returns:

None
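To make the three hooks concrete, here is a minimal sketch of a CSV logger built on Python's standard csv module. It is not torchmate's implementation: the class name SketchCSVLogger, its own epoch counter, the rule of skipping the header when appending to a non-empty file, and the assumption that trainer.history maps each metric name to a list of per-epoch values are all hypothetical. A SimpleNamespace stands in for the trainer.

```python
import csv
import os
import tempfile
from types import SimpleNamespace


class SketchCSVLogger:
    """Hypothetical CSV logger mirroring CSVLogger's hooks; not torchmate's source."""

    def __init__(self, filename, separator=",", append=False):
        self.filename = filename
        self.separator = separator
        self.append = append
        self._file = None
        self._writer = None
        self._write_header = True
        self._epoch = 0

    def on_experiment_begin(self, trainer):
        # Skip the header only when appending to a file that already has content.
        has_content = os.path.exists(self.filename) and os.path.getsize(self.filename) > 0
        self._write_header = not (self.append and has_content)
        self._file = open(self.filename, "a" if self.append else "w", newline="")

    def on_epoch_end(self, trainer):
        # Assumption: trainer.history maps metric name -> list of per-epoch values.
        self._epoch += 1
        row = {"epoch": self._epoch}
        row.update({name: values[-1] for name, values in trainer.history.items()})
        if self._writer is None:
            self._writer = csv.DictWriter(self._file, fieldnames=list(row), delimiter=self.separator)
            if self._write_header:
                self._writer.writeheader()
        self._writer.writerow(row)
        self._file.flush()  # make each epoch's row visible on disk immediately

    def on_experiment_end(self, trainer):
        self._file.close()


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "log.csv")
    trainer = SimpleNamespace(history={"loss": [], "val_loss": []})
    logger = SketchCSVLogger(path)
    logger.on_experiment_begin(trainer)
    for loss, val_loss in [(0.9, 1.0), (0.5, 0.7)]:
        trainer.history["loss"].append(loss)
        trainer.history["val_loss"].append(val_loss)
        logger.on_epoch_end(trainer)
    logger.on_experiment_end(trainer)
    with open(path, newline="") as f:
        print(list(csv.reader(f)))
# [['epoch', 'loss', 'val_loss'], ['1', '0.9', '1.0'], ['2', '0.5', '0.7']]
```

Flushing after every row means a crashed run still leaves the completed epochs on disk.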

EarlyStopper

class torchmate.callbacks.EarlyStopper(monitor='val_loss', patience=3, mode='min', min_delta=0, restore_best_state=False)

Bases: Callback

Callback to implement early stopping based on a monitored metric.

This callback stops the training early if the monitored metric does not improve for a specified number of epochs (patience).

Parameters:
  • monitor (str) – Metric to monitor for early stopping. Defaults to "val_loss".

  • patience (int) – Number of epochs with no improvement to wait before stopping. Defaults to 3.

  • mode (str) – One of "min" or "max". The direction to monitor the metric. For example, "min" for loss, "max" for accuracy. Defaults to "min".

  • min_delta (float) – Minimum change in the monitored metric to qualify as an improvement. Defaults to 0.

  • restore_best_state (bool) – Whether to restore the best model state when early stopping. Defaults to False.

Example Usage:

from torchmate.callbacks import EarlyStopper

early_stopper = EarlyStopper(
    monitor="val_loss",
    patience=3,
    mode="min",
)
callbacks = [early_stopper]

on_epoch_end(trainer) -> None

Checks the early-stopping condition after each epoch and stops training if it is met.

Parameters:

trainer (Trainer) – The trainer object.

Returns:

None
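The stopping rule described above can be sketched in plain Python. This is an illustration under assumptions, not torchmate's code: the class name PatienceTracker and its update method are hypothetical, and it assumes an epoch counts as an improvement only when the metric beats the best value so far by more than min_delta, a common convention that the source does not confirm. restore_best_state is not modeled; a real callback would also snapshot the model weights whenever best is updated.

```python
class PatienceTracker:
    """Hypothetical patience-based stopping rule; not torchmate's EarlyStopper source."""

    def __init__(self, patience=3, mode="min", min_delta=0.0):
        if mode not in ("min", "max"):
            raise ValueError('mode must be "min" or "max"')
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        self.best = None
        self.wait = 0  # consecutive epochs without improvement

    def _is_improvement(self, value):
        if self.best is None:
            return True  # the first epoch always sets the baseline
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def update(self, value):
        """Record one epoch's metric and return True if training should stop."""
        if self._is_improvement(value):
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


tracker = PatienceTracker(patience=3, mode="min")
val_losses = [1.0, 0.8, 0.79, 0.81, 0.80, 0.85]
for epoch, val_loss in enumerate(val_losses, start=1):
    if tracker.update(val_loss):
        print(f"Stopping after epoch {epoch}; best val_loss = {tracker.best}")
        break
# Stopping after epoch 6; best val_loss = 0.79
```

With min_delta=0.05, the drop from 0.8 to 0.79 would no longer count as an improvement, so the same sequence would stop after epoch 5.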

GradientAccumulator

class torchmate.callbacks.GradientAccumulator(num_accum_steps: int | None = None)

Bases: Callback

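Callback that implements gradient accumulation.

This callback accumulates gradients over num_accum_steps batches and performs a single parameter update once that many batches have been processed. The effective batch size is num_accum_steps times the per-step batch size, without the memory cost of actually loading a batch that large.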

Parameters:

num_accum_steps (int) – Number of accumulation steps before performing a parameter update.

Example Usage:

from torchmate.callbacks import GradientAccumulator

gradient_accumulator = GradientAccumulator(num_accum_steps=4)
callbacks = [gradient_accumulator]

on_experiment_begin(trainer)
on_train_batch_begin(trainer)
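The arithmetic behind gradient accumulation can be checked without a framework. The sketch below fits a single weight w with squared error in plain Python: it scales each micro-batch's mean gradient by 1/num_accum_steps, sums those, and applies one update every num_accum_steps batches, which for equal-sized micro-batches matches a single step on the combined batch. The helper names are hypothetical, and the sketch says nothing about how torchmate's GradientAccumulator wires this into on_experiment_begin and on_train_batch_begin.

```python
import math


def grad(w, x, y):
    # Derivative of the squared error (w * x - y) ** 2 with respect to w.
    return 2 * x * (w * x - y)


def mean_grad(w, xs, ys):
    return sum(grad(w, x, y) for x, y in zip(xs, ys)) / len(xs)


def train_with_accumulation(w, batches, lr, num_accum_steps):
    """Step once every num_accum_steps micro-batches; a trailing partial group is dropped here."""
    accumulated = 0.0
    for step, (xs, ys) in enumerate(batches, start=1):
        # Mirrors calling backward() on loss / num_accum_steps, which adds into the stored gradient.
        accumulated += mean_grad(w, xs, ys) / num_accum_steps
        if step % num_accum_steps == 0:
            w -= lr * accumulated  # like optimizer.step()
            accumulated = 0.0      # like optimizer.zero_grad()
    return w


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]
micro_batches = [(xs[i:i + 2], ys[i:i + 2]) for i in range(0, len(xs), 2)]

w_accumulated = train_with_accumulation(0.0, micro_batches, lr=0.01, num_accum_steps=4)
w_full_batch = 0.0 - 0.01 * mean_grad(0.0, xs, ys)
print(math.isclose(w_accumulated, w_full_batch))  # True
```

The 1/num_accum_steps scaling is what keeps the update comparable to a large-batch step; without it, the accumulated gradient would be num_accum_steps times too large.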

GradientClipper

class torchmate.callbacks.GradientClipper(method: str, clip_value: float | None = None, max_norm: float | None = None)

Bases: Callback

Callback that clips gradients during training.

Parameters:
  • method (str) – The gradient clipping method to use. Supported methods are "clip_by_value" and "clip_by_norm".

  • clip_value (Optional[float]) – Maximum allowed value of gradients for the clip_by_value method. Required if method is "clip_by_value".

  • max_norm (Optional[float]) – Maximum allowed norm of gradients for the clip_by_norm method. Required if method is "clip_by_norm".

Raises:
  • ValueError – If the provided method is not supported.

  • ValueError – If clip_value is not provided when method is "clip_by_value".

  • ValueError – If max_norm is not provided when method is "clip_by_norm".

Example Usage:

from torchmate.callbacks import GradientClipper

gradient_clipper = GradientClipper(method="clip_by_value", clip_value=1.0)
callbacks = [gradient_clipper]

on_backward_end(trainer) -> None

Clips gradients at the end of the backward pass.
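The two supported methods differ in what they bound. The plain-Python sketch below treats all gradients as one flat list: clip_by_value clamps each element independently, while clip_by_norm rescales every element by the same factor so the global L2 norm does not exceed max_norm, which preserves the gradient's direction. The function names mirror the method strings but are hypothetical helpers. On real tensors, PyTorch provides torch.nn.utils.clip_grad_value_ and torch.nn.utils.clip_grad_norm_; the source does not say whether GradientClipper uses them.

```python
import math


def clip_by_value(grads, clip_value):
    """Clamp each gradient element into [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in grads]


def clip_by_norm(grads, max_norm, eps=1e-6):
    """Scale all elements by one factor so the global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    # eps guards against division by zero; the factor is capped at 1 so small gradients are untouched.
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads]


grads = [3.0, -4.0, 0.5]
print(clip_by_value(grads, clip_value=1.0))  # [1.0, -1.0, 0.5]
clipped = clip_by_norm(grads, max_norm=1.0)
print(math.sqrt(sum(g * g for g in clipped)))  # slightly below 1.0
```

Note how clip_by_value changes the ratio between the first two elements (3:4 becomes 1:1), while clip_by_norm keeps it.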

ModelCheckpoint

class torchmate.callbacks.ModelCheckpoint(checkpoint_dir, monitor='val_loss', mode='min', min_delta=0.0, save_frequency=1, save_best_only=True, save_state_dict_only=True, save_last=True)

Bases: Callback

Callback to save model checkpoints.

This callback saves model checkpoints at a specified epoch interval and, optionally, only when the monitored metric (for example, validation loss) improves.

Parameters:
  • checkpoint_dir (str) – Directory to save the checkpoints.

  • monitor (str, optional) – Metric to monitor for saving the best model. Defaults to "val_loss".

  • mode (str, optional) – One of "min" or "max". The direction to monitor the metric. For example, "min" for loss, "max" for accuracy. Defaults to "min".

  • min_delta (float, optional) – Minimum change in the monitored metric to qualify as an improvement. Defaults to 0.0.

  • save_frequency (int, optional) – Frequency of saving checkpoints (epochs). Defaults to 1.

  • save_best_only (bool, optional) – Whether to save only the best model based on the monitored metric. Defaults to True.

  • save_last (bool, optional) – Whether to save the model checkpoint for the last epoch. Defaults to True.

  • save_state_dict_only (bool, optional) –

    Whether to save only the model’s state dictionary instead of the whole model. Defaults to True.

    • If True:
      • Saves a dictionary containing epoch, model_state_dict, and optimizer_state_dict.

      • Requires manually creating model and optimizer instances before loading the dictionary.

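      • Loading model_state_dict restores the model’s learned state (parameters such as weights and biases, plus buffers); the architecture itself comes from the model instance you create.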

    • If False:
      • Saves the entire model object.

Example Usage:

from torchmate.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    checkpoint_dir="checkpoints",
    monitor="val_loss",
    mode="min",
    save_best_only=False,
    save_frequency=2,
)
callbacks = [checkpoint_callback]

on_epoch_end(trainer)
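How save_frequency, save_best_only, mode, and min_delta combine is easiest to see as a small decision function. The sketch below is hypothetical: the class name SavePolicy is invented, and it assumes that epochs which are not a multiple of save_frequency are skipped entirely and that an improvement means beating the best value so far by more than min_delta. The source confirms neither rule, and save_last and the actual file writing are left out.

```python
class SavePolicy:
    """Hypothetical checkpoint-saving rule; not torchmate's ModelCheckpoint source."""

    def __init__(self, mode="min", min_delta=0.0, save_frequency=1, save_best_only=True):
        self.mode = mode
        self.min_delta = min_delta
        self.save_frequency = save_frequency
        self.save_best_only = save_best_only
        self.best = None

    def _is_improvement(self, value):
        if self.best is None:
            return True
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def should_save(self, epoch, value):
        """Return True if a checkpoint would be written after this 1-based epoch."""
        # Assumption: epochs off the save_frequency grid are ignored entirely.
        if epoch % self.save_frequency != 0:
            return False
        improved = self._is_improvement(value)
        if improved:
            self.best = value
        # With save_best_only=False, every on-grid epoch is saved regardless of the metric.
        return improved or not self.save_best_only


# Mirrors the example configuration: save_best_only=False, save_frequency=2.
policy = SavePolicy(mode="min", save_frequency=2, save_best_only=False)
print([epoch for epoch in range(1, 7) if policy.should_save(epoch, 1.0)])  # [2, 4, 6]
```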

WandbLogger

class torchmate.callbacks.WandbLogger

Bases: Callback

Callback to log training metrics and visualizations to Weights & Biases (Wandb).

This callback logs training metrics, visualizations, and other experiment-related data to Wandb.

Example Usage:

from torchmate.callbacks import WandbLogger

wandb_logger = WandbLogger()
callbacks = [wandb_logger]

on_experiment_begin(trainer)
on_epoch_end(trainer)

WandbModelCheckpoint

class torchmate.callbacks.WandbModelCheckpoint(checkpoint_dir, monitor='val_loss', mode='min', min_delta=0, save_frequency=1, save_best_only=True, save_state_dict_only=True, save_last=True)

Bases: Callback

Callback to save model checkpoints using Weights and Biases (wandb).

This callback saves model checkpoints to the Weights & Biases cloud at a specified epoch interval and, optionally, only when the monitored metric (for example, validation loss) improves.

Parameters:
  • checkpoint_dir (str) – Directory to save the checkpoints.

  • monitor (str, optional) – Metric to monitor for saving the best model. Defaults to "val_loss".

  • mode (str, optional) – One of "min" or "max". The direction to monitor the metric. For example, "min" for loss, "max" for accuracy. Defaults to "min".

  • min_delta (float, optional) – Minimum change in the monitored metric to qualify as an improvement. Defaults to 0.0.

  • save_frequency (int, optional) – Frequency of saving checkpoints (epochs). Defaults to 1.

  • save_best_only (bool, optional) – Whether to save only the best model based on the monitored metric. Defaults to True.

  • save_last (bool, optional) – Whether to save the model checkpoint for the last epoch. Defaults to True.

  • save_state_dict_only (bool, optional) –

    Whether to save only the model’s state dictionary instead of the whole model. Defaults to True.

    • If True:
      • Saves a dictionary containing epoch, model_state_dict, and optimizer_state_dict.

      • Requires manually creating model and optimizer instances before loading the dictionary.

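      • Loading model_state_dict restores the model’s learned state (parameters such as weights and biases, plus buffers); the architecture itself comes from the model instance you create.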

    • If False:
      • Saves the entire model object.

Example Usage:

from torchmate.callbacks import WandbModelCheckpoint

wandb_checkpoint_callback = WandbModelCheckpoint(
    checkpoint_dir="checkpoints",
    monitor="val_loss",
    mode="min",
    save_best_only=False,
    save_frequency=2,
)
callbacks = [wandb_checkpoint_callback]

on_experiment_begin(trainer)
on_epoch_end(trainer)