Source code for torchmate.callbacks.gradient_accumulator

 1from typing import Optional
 2
 3from torchmate.callbacks import Callback
 4
 5
class GradientAccumulator(Callback):
    """
    Callback to accumulate gradients for gradient accumulation.

    This callback accumulates gradients over a specified number of batches and performs an update
    when the specified number of batches is reached. On every training batch it sets
    ``trainer.update_params`` to ``True`` or ``False``; the trainer presumably reads this flag to
    decide whether to apply the accumulated gradients (confirm against the trainer).

    The last batch of every epoch always triggers an update, so leftover gradients never carry
    over into the next epoch.

    Parameters:
        num_accum_steps (int, optional): Number of accumulation steps before performing a
            parameter update. If ``None``, the value already held by
            ``trainer.accumulation_steps`` is used unchanged.

    Raises:
        ValueError: If ``num_accum_steps`` is given and is smaller than 1.

    **Example Usage:**

    .. code-block:: python

        gradient_accumulator = GradientAccumulator(num_accum_steps=4)
        callbacks = [gradient_accumulator]

    """

    def __init__(self, num_accum_steps: Optional[int] = None):
        super().__init__()
        # Fail fast: a zero step count would otherwise surface later as a
        # ZeroDivisionError inside on_train_batch_begin.
        if num_accum_steps is not None and num_accum_steps < 1:
            raise ValueError(f"num_accum_steps must be a positive integer, got {num_accum_steps!r}")
        # Running count of training batches seen since the experiment began (spans epochs).
        self.batch_count = 0
        self.num_accum_steps = num_accum_steps

    def on_experiment_begin(self, trainer):
        """Reset the batch counter and configure accumulation on ``trainer``.

        Args:
            trainer: The trainer object. Its ``update_params`` attribute is set to ``False``.
                Its ``accumulation_steps`` attribute is overwritten only when
                ``num_accum_steps`` was provided.
        """
        # A reused callback must not inherit batch positions from a previous run.
        self.batch_count = 0
        trainer.update_params = False
        if self.num_accum_steps is not None:
            trainer.accumulation_steps = self.num_accum_steps

    def on_train_batch_begin(self, trainer):
        """Decide whether the upcoming training batch should trigger a parameter update.

        Sets ``trainer.update_params`` to ``True`` on every ``trainer.accumulation_steps``-th
        batch of the current epoch and on the final batch of the epoch. Otherwise it sets the
        flag to ``False``.

        Args:
            trainer: The trainer object. ``accumulation_steps`` and ``train_dataloader`` are read
                from it; ``len(trainer.train_dataloader)`` must be defined.
        """
        self.batch_count += 1
        num_batches = len(trainer.train_dataloader)
        # batch_count keeps growing across epochs. Convert it to a 1-based position inside
        # the current epoch so that both the accumulation boundaries and the epoch-end
        # flush line up with every epoch, not just the first one.
        position = (self.batch_count - 1) % num_batches + 1
        is_last_batch = position == num_batches
        trainer.update_params = (position % trainer.accumulation_steps == 0) or is_last_batch