1from typing import Optional
2
3from torchmate.callbacks import Callback
4
5
[docs]
class GradientAccumulator(Callback):
    """
    Callback that decides, batch by batch, when accumulated gradients should be applied.

    Before every training batch this callback sets ``trainer.update_params``. The trainer
    presumably reads this flag to decide whether to step the optimizer for that batch
    (confirm against the trainer implementation). A parameter update is requested:

    * every ``trainer.accumulation_steps`` batches within an epoch, and
    * on the last batch of each epoch, so that a trailing partial group of batches
      still triggers an update instead of spilling into the next epoch.

    Parameters:
        num_accum_steps (Optional[int]): Number of batches to accumulate before performing
            a parameter update. If ``None``, the value already stored in
            ``trainer.accumulation_steps`` is used unchanged.

    Raises:
        ValueError: If ``num_accum_steps`` is given but is smaller than 1.

    **Example Usage:**

    .. code-block:: python

        gradient_accumulator = GradientAccumulator(num_accum_steps=4)
        callbacks = [gradient_accumulator]

    """

    def __init__(self, num_accum_steps: Optional[int] = None):
        super().__init__()
        if num_accum_steps is not None and num_accum_steps < 1:
            # 0 would cause a ZeroDivisionError in the modulo below; negatives are meaningless.
            raise ValueError(f"num_accum_steps must be a positive integer, got {num_accum_steps!r}")
        # Total number of training batches seen since the experiment began.
        self.batch_count = 0
        self.num_accum_steps = num_accum_steps

    def on_experiment_begin(self, trainer):
        """Reset the batch counter and push the accumulation settings onto ``trainer``."""
        # Reset so that an instance reused for another experiment starts counting afresh.
        self.batch_count = 0
        trainer.update_params = False
        if self.num_accum_steps is not None:
            trainer.accumulation_steps = self.num_accum_steps

    def on_train_batch_begin(self, trainer):
        """Set ``trainer.update_params`` for the batch that is about to run."""
        self.batch_count += 1
        num_batches = self._num_train_batches(trainer)

        step_in_epoch = self.batch_count
        if num_batches:
            # batch_count is never reset between epochs, so map it back to a 1-based
            # position inside the current epoch. This keeps both conditions below aligned
            # with epoch boundaries in every epoch, not just the first one.
            # NOTE(review): assumes each epoch consumes the whole train_dataloader.
            step_in_epoch = (self.batch_count - 1) % num_batches + 1

        is_accum_boundary = step_in_epoch % trainer.accumulation_steps == 0
        is_last_batch_of_epoch = num_batches is not None and step_in_epoch == num_batches
        trainer.update_params = is_accum_boundary or is_last_batch_of_epoch

    @staticmethod
    def _num_train_batches(trainer):
        """Return ``len(trainer.train_dataloader)``, or ``None`` if the loader has no length."""
        try:
            return len(trainer.train_dataloader)
        except TypeError:
            # Presumably a loader over an iterable-style dataset without ``__len__``;
            # fall back to updating on accumulation boundaries only.
            return None