Source code for torchmate.callbacks.gradient_clipper

 1from typing import Optional, Union
 2
 3import torch
 4
 5from torchmate.callbacks import Callback
 6
 7
class GradientClipper(Callback):
    """
    GradientClipper class for clipping gradients during the training process.

    The clipping is applied in :meth:`on_backward_end`, after the gradients
    have been unscaled through ``trainer.scaler``.

    Args:
        method (str): The gradient clipping method to use.
            Supported methods are "clip_by_value" and "clip_by_norm".
        clip_value (Optional[float]): Maximum allowed value of gradients for
            clip_by_value method. Required if method is "clip_by_value".
            Must not be negative.
        max_norm (Optional[float]): Maximum allowed norm of gradients for
            clip_by_norm method. Required if method is "clip_by_norm".
            Must not be negative.

    Raises:
        ValueError: If the provided method is not supported.
        ValueError: If clip_value is not provided when method is "clip_by_value".
        ValueError: If max_norm is not provided when method is "clip_by_norm".
        ValueError: If the threshold used by the selected method is negative.

    **Example Usage:**

    .. code-block:: python

        gradient_clipper = GradientClipper(method="clip_by_value", clip_value=1.0)
        callbacks = [gradient_clipper]

    """

    def __init__(
        self,
        method: str,
        clip_value: Optional[float] = None,
        max_norm: Optional[float] = None,
    ) -> None:
        supported_methods = ("clip_by_value", "clip_by_norm")
        if method not in supported_methods:
            raise ValueError(
                f"Unsupported gradient clipping method '{method}'. Supported methods are: {', '.join(supported_methods)}"
            )

        self.method = method
        self.clip_value = clip_value
        self.max_norm = max_norm

        # Only the threshold belonging to the selected method is validated;
        # the other one is stored but never read by on_backward_end.
        if self.method == "clip_by_value":
            if self.clip_value is None:
                raise ValueError(f"clip_value must be provided when method is '{method}'")
            # A negative bound inverts the clamp range (min > max), which
            # overwrites gradients instead of clipping them.
            if self.clip_value < 0:
                raise ValueError(f"clip_value must be non-negative, got {clip_value}")

        if self.method == "clip_by_norm":
            if self.max_norm is None:
                raise ValueError(f"max_norm must be provided when method is '{method}'")
            # A negative norm yields a negative scaling coefficient, which
            # would flip the sign of every gradient.
            if self.max_norm < 0:
                raise ValueError(f"max_norm must be non-negative, got {max_norm}")

    def on_backward_end(self, trainer) -> None:
        """Clips gradients at the end of the backward pass.

        Args:
            trainer: Object driving the training loop. This method reads its
                ``scaler``, ``optimizer`` and ``model`` attributes.
        """
        # Unscale first so the threshold is compared against true gradient
        # magnitudes rather than loss-scaled ones.
        trainer.scaler.unscale_(trainer.optimizer)

        if self.method == "clip_by_norm":
            torch.nn.utils.clip_grad_norm_(trainer.model.parameters(), max_norm=self.max_norm)

        elif self.method == "clip_by_value":
            torch.nn.utils.clip_grad_value_(trainer.model.parameters(), clip_value=self.clip_value)