Source code for torchmate.utils.check_model_info

 1import torch
 2
 3
def _count_unique_elements(params) -> int:
    """Sum ``numel()`` over *params*, counting tensors that share memory once."""
    unique = {}
    for p in params:
        ptr = p.data_ptr()
        if ptr:
            # Same device and same start address: treat as one shared tensor.
            key = ("ptr", p.device, ptr)
        else:
            # A null pointer means there is no real storage to compare (for
            # example tensors on the "meta" device, or zero-element tensors).
            # Keying on the pointer alone would collapse every such tensor
            # into a single entry, so fall back to object identity.
            key = ("id", id(p))
        unique[key] = p.numel()
    return sum(unique.values())


def model_info(model: torch.nn.Module) -> dict:
    """
    Analyze a PyTorch model and return a dictionary describing its size and parameters.

    Parameter counts are de-duplicated: parameters that start at the same
    memory address on the same device are counted once. Parameters without
    real storage (null data pointer, e.g. on the "meta" device) are counted
    individually. The size figures are plain sums over every tensor yielded by
    ``model.parameters()`` and ``model.buffers()``.

    Args:
        model (torch.nn.Module): The PyTorch model to be analyzed.

    Returns:
        dict: A dictionary containing the following information:
            total_params (int): Total number of parameters in the model.
            trainable_params (int): Number of parameters that require gradient calculation during training.
            non_trainable_params (int): Number of parameters that do not require gradient calculation during training.
            total_size_mb (float): Total size of the model in megabytes.
            param_size_mb (float): Size of the model parameters in megabytes (excluding buffers).
            buffer_size_mb (float): Size of the model buffers in megabytes.

    Prints:
        This function also prints human-readable information about the model size and parameters
        including total, trainable, and non-trainable parameters for better understanding.
    """
    # Materialize once so the same parameter list feeds every statistic.
    params = list(model.parameters())

    param_size = sum(p.nelement() * p.element_size() for p in params)
    buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())

    bytes_per_mb = 1024**2
    param_size_mb = param_size / bytes_per_mb
    buffer_size_mb = buffer_size / bytes_per_mb
    total_size_mb = (param_size + buffer_size) / bytes_per_mb

    total_params = _count_unique_elements(params)
    trainable_params = _count_unique_elements(p for p in params if p.requires_grad)
    non_trainable_params = total_params - trainable_params

    info = {
        "total_params": total_params,
        "trainable_params": trainable_params,
        "non_trainable_params": non_trainable_params,
        "total_size_mb": total_size_mb,
        "param_size_mb": param_size_mb,
        "buffer_size_mb": buffer_size_mb,
    }

    print(f"Total Parameters = {total_params:,}")
    print(f"Trainable Parameters = {trainable_params:,}")
    print(f"Non-trainable Parameters = {non_trainable_params:,}")
    print("=" * 50)
    print(f"Model size: {total_size_mb:.3f} MB")
    return info