1import torch
2
3
[docs]
def _storage_key(tensor: torch.Tensor):
    """Return a hashable key identifying the memory that backs ``tensor``.

    Tensors that share an allocation (same device and same ``data_ptr()``) map to
    the same key, so they are counted only once. ``data_ptr()`` can be 0 for tensors
    without a real allocation (e.g. tensors on the ``"meta"`` device). Some tensor
    kinds may refuse to expose a pointer at all. Such tensors are keyed by object
    identity instead; otherwise they would all collapse into a single entry.
    """
    try:
        ptr = tensor.data_ptr()
    except (RuntimeError, NotImplementedError):
        ptr = 0
    if ptr == 0:
        return ("id", id(tensor))
    # Include the device defensively: raw pointers from different devices are not
    # guaranteed to come from one address space.
    return ("ptr", tensor.device, ptr)


def _unique_tensors(tensors):
    """Return ``tensors`` as a list with storage-sharing duplicates removed (first seen wins)."""
    seen = set()
    unique = []
    for tensor in tensors:
        key = _storage_key(tensor)
        if key not in seen:
            seen.add(key)
            unique.append(tensor)
    return unique


def model_info(model: torch.nn.Module) -> dict:
    """
    Analyze a PyTorch model and return a dictionary describing its size and parameters.

    Parameters that share the same memory are counted only once, and so are buffers
    that share memory. Parameter counts and byte sizes are computed over the same
    deduplicated tensors, so they stay consistent with each other. Sizes are reported
    in MB where 1 MB = 1024**2 bytes.

    Args:
        model (torch.nn.Module): The PyTorch model to be analyzed.

    Returns:
        dict: A dictionary containing the following information:
            total_params (int): Total number of parameters in the model.
            trainable_params (int): Number of parameters with ``requires_grad=True``.
            non_trainable_params (int): Number of parameters with ``requires_grad=False``.
            total_size_mb (float): Total size of parameters plus buffers in megabytes.
            param_size_mb (float): Size of the model parameters in megabytes (excluding buffers).
            buffer_size_mb (float): Size of the model buffers in megabytes.

    Prints:
        Human-readable totals for total, trainable and non-trainable parameters,
        followed by the total model size in MB.
    """
    params = _unique_tensors(model.parameters())
    buffers = _unique_tensors(model.buffers())

    # Bytes = number of elements * bytes per element (dtype-dependent).
    param_size = sum(p.nelement() * p.element_size() for p in params)
    buffer_size = sum(b.nelement() * b.element_size() for b in buffers)

    param_size_mb = param_size / 1024**2
    buffer_size_mb = buffer_size / 1024**2
    total_size_mb = (param_size + buffer_size) / 1024**2

    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)
    non_trainable_params = total_params - trainable_params

    info = {
        "total_params": total_params,
        "trainable_params": trainable_params,
        "non_trainable_params": non_trainable_params,
        "total_size_mb": total_size_mb,
        "param_size_mb": param_size_mb,
        "buffer_size_mb": buffer_size_mb,
    }
    print(f"Total Parameters = {format(total_params,',')}")
    print(f"Trainable Parameters = {format(trainable_params,',')}")
    print(f"Non-trainable Parameters = {format(non_trainable_params,',')}")
    print("=" * 50)
    print("Model size: {:.3f} MB".format(total_size_mb))
    return info