1from typing import Optional, Union
2
3import torch
4
5from torchmate.callbacks import Callback
6
7
[docs]
class GradientClipper(Callback):
    """
    GradientClipper class for clipping gradients during the training process.

    The clipping is applied in ``on_backward_end``, i.e. after gradients have
    been computed and before they are consumed by the optimizer.

    Args:
        method (str): The gradient clipping method to use.
            Supported methods are "clip_by_value" and "clip_by_norm".
        clip_value (Optional[float]): Maximum allowed value of gradients for
            clip_by_value method. Required if method is "clip_by_value".
            Must not be negative.
        max_norm (Optional[float]): Maximum allowed norm of gradients for
            clip_by_norm method. Required if method is "clip_by_norm".
            Must not be negative.

    Raises:
        ValueError: If the provided method is not supported.
        ValueError: If clip_value is not provided when method is "clip_by_value".
        ValueError: If max_norm is not provided when method is "clip_by_norm".
        ValueError: If the threshold used by the selected method is negative.

    **Example Usage:**

    .. code-block:: python

        gradient_clipper = GradientClipper(method="clip_by_value", clip_value=1.0)
        callbacks = [gradient_clipper]

    """

    def __init__(
        self,
        method: str,
        clip_value: Optional[float] = None,
        max_norm: Optional[float] = None,
    ) -> None:

        supported_methods = ["clip_by_value", "clip_by_norm"]
        if method not in supported_methods:
            raise ValueError(
                f"Unsupported gradient clipping method '{method}'. Supported methods are: {', '.join(supported_methods)}"
            )

        self.method = method
        self.clip_value = clip_value
        self.max_norm = max_norm

        # Only the threshold that the selected method actually uses is
        # validated; the other one is stored untouched and never read.
        if self.method == "clip_by_value":
            if self.clip_value is None:
                raise ValueError(f"clip_value must be provided when method is '{method}'")
            # clip_grad_value_ clips into [-clip_value, clip_value]; a negative
            # bound makes that range empty, so reject it up front instead of
            # letting training run with meaningless gradients.
            if self.clip_value < 0:
                raise ValueError(f"clip_value must be non-negative, got {clip_value}")

        if self.method == "clip_by_norm":
            if self.max_norm is None:
                raise ValueError(f"max_norm must be provided when method is '{method}'")
            # A norm is never negative, so a negative max_norm is not a valid bound.
            if self.max_norm < 0:
                raise ValueError(f"max_norm must be non-negative, got {max_norm}")

    def on_backward_end(self, trainer) -> None:
        """Clips gradients at the end of the backward pass.

        Args:
            trainer: The object driving training. This method reads
                ``trainer.model`` and ``trainer.optimizer`` and, when present,
                ``trainer.scaler``.
        """
        # Gradients must be unscaled before clipping so the threshold is
        # compared against their real magnitudes. A trainer without a scaler
        # (attribute missing or None) has nothing to unscale, so skip the call.
        # NOTE(review): assumes trainer.scaler is a torch GradScaler-like
        # object when set -- confirm against the Trainer implementation.
        scaler = getattr(trainer, "scaler", None)
        if scaler is not None:
            scaler.unscale_(trainer.optimizer)

        if self.method == "clip_by_norm":
            torch.nn.utils.clip_grad_norm_(trainer.model.parameters(), max_norm=self.max_norm)

        elif self.method == "clip_by_value":
            torch.nn.utils.clip_grad_value_(trainer.model.parameters(), clip_value=self.clip_value)