Source code for torchmate.modules.squeeze_and_excitation

  1"""A collection of squeeze and excitation (SE) layers that can be integrated into various neural network architectures.
  2
  3Supports three types of SE blocks:
  4- Channel Squeeze and Excitation (CSE)
  5- Spatial Squeeze and Excitation (SSE)
  6- Channel and Spatial Squeeze and Excitation (CSSE)
  7
  8**Credit:**
  9https://github.com/ai-med/squeeze_and_excitation
 10
 11
 12References:
 13- [Hu et al., Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507)
 14- [Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018]\
 15  (https://arxiv.org/abs/1803.02579)
 16
 17
 18"""
 19
 20from enum import Enum
 21
 22import torch
 23
 24
class ChannelSELayer(torch.nn.Module):
    """
    Implements the Channel Squeeze and Excitation (CSE) block as described in:

    - Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507

    Each channel is averaged over its spatial positions, passed through a
    two-layer bottleneck (Linear -> ReLU -> Linear -> Sigmoid) and the
    resulting per-channel gate is multiplied back onto the input.

    Parameters:
        num_channels (int): Number of input channels.
        reduction_ratio (int, optional): Ratio to reduce the number of channels
            by in the squeeze step. Default is ``2``.

    Returns:
        torch.Tensor: Output tensor with the same dimensions as
            the input.

    Reference:
        - Paper: https://arxiv.org/abs/1709.01507
        - Implementation: https://github.com/ai-med/squeeze_and_excitation

    """

    def __init__(self, num_channels, reduction_ratio=2):
        super(ChannelSELayer, self).__init__()
        # NOTE(review): floor division gives 0 hidden units when
        # num_channels < reduction_ratio; kept as-is so existing parameter
        # shapes do not change.
        num_channels_reduced = num_channels // reduction_ratio
        self.reduction_ratio = reduction_ratio
        self.fc1 = torch.nn.Linear(num_channels, num_channels_reduced, bias=True)
        self.fc2 = torch.nn.Linear(num_channels_reduced, num_channels, bias=True)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_tensor):
        """Forward method

        Parameters:
            input_tensor: X, shape = (batch_size, num_channels, H, W)

        Returns:
            torch.Tensor: Output tensor with the same dimensions as the input.
        """
        # Unpacking four values also enforces a 4-D input.
        batch_size, num_channels, _, _ = input_tensor.size()

        # Average along each channel. ``reshape`` (rather than ``view``) is
        # used so inputs whose spatial strides cannot be merged, e.g. a
        # tensor produced by ``transpose(2, 3)``, are accepted as well.
        squeeze_tensor = input_tensor.reshape(batch_size, num_channels, -1).mean(dim=2)

        # channel excitation
        fc_out_1 = self.relu(self.fc1(squeeze_tensor))
        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))

        # Broadcast the (batch, channel) gate over the spatial dimensions.
        gate = fc_out_2.view(batch_size, num_channels, 1, 1)
        output_tensor = torch.mul(input_tensor, gate)
        return output_tensor
76 77
class SpatialSELayer(torch.nn.Module):
    """
    Spatial Squeeze and Excitation (SSE) block, following:

    - Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks,
      MICCAI 2018

    A 1x1 convolution collapses the channels into a single map, a sigmoid
    turns it into a per-position gate, and the gate is multiplied onto the
    input.

    Parameters:
        num_channels (int): Number of input channels.

    Returns:
        torch.Tensor: Output tensor with the same dimensions as
            the input.

    Reference:
        - Paper: https://arxiv.org/abs/1803.02579
        - Implementation: https://github.com/ai-med/squeeze_and_excitation

    """

    def __init__(self, num_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels=num_channels, out_channels=1, kernel_size=1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_tensor, weights=None):
        """Forward method

        Parameters:
            input_tensor: X, shape = (batch_size, num_channels, H, W)
            weights: weights for few shot learning. When given, they are
                averaged over dim 0, reshaped to (1, num_channels, 1, 1) and
                used as a bias-free 1x1 kernel instead of the layer's own
                convolution.

        Returns:
            torch.Tensor: Output tensor with the same dimensions as the input.
        """
        n, c, height, width = input_tensor.size()

        # spatial squeeze
        if weights is None:
            logits = self.conv(input_tensor)
        else:
            kernel = torch.mean(weights, dim=0).view(1, c, 1, 1)
            logits = torch.nn.functional.conv2d(input_tensor, kernel)

        # spatial excitation
        gate = self.sigmoid(logits).view(n, 1, height, width)
        return torch.mul(input_tensor, gate)
130 131
class ChannelSpatialSELayer(torch.nn.Module):
    """
    Concurrent Spatial and Channel Squeeze & Excitation (CSSE) block, following:

    - Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks,
      MICCAI 2018

    The input is passed through a channel SE branch and a spatial SE branch;
    the two results are combined with an element-wise maximum.

    Parameters:
        num_channels (int): Number of input channels.
        reduction_ratio (int, optional): Ratio to reduce the number of channels
            by in the squeeze step. Default is ``2``.

    Returns:
        torch.Tensor: Output tensor with the same dimensions as the input.

    Reference:
        - Paper: https://arxiv.org/abs/1803.02579
        - Implementation: https://github.com/ai-med/squeeze_and_excitation

    """

    def __init__(self, num_channels, reduction_ratio=2):
        super().__init__()
        self.cSE = ChannelSELayer(num_channels, reduction_ratio)
        self.sSE = SpatialSELayer(num_channels)

    def forward(self, input_tensor):
        """Forward method

        Parameters:
            input_tensor: X, shape = (batch_size, num_channels, H, W)

        Returns:
            torch.Tensor: Output tensor with the same dimensions as the input.

        """
        channel_branch = self.cSE(input_tensor)
        spatial_branch = self.sSE(input_tensor)
        # Element-wise maximum of the two recalibrated tensors.
        return torch.max(channel_branch, spatial_branch)
171 172
class SELayer(Enum):
    """Squeeze and Excitation Enum Block

    Enum restricting the types of SE blocks available, so that type checking
    can be added when inserting these blocks into a neural network.

    .. code-block:: python

        if self.se_block_type == se.SELayer.CSE.value:
            self.SELayer = se.ChannelSELayer(params['num_filters'])

        elif self.se_block_type == se.SELayer.SSE.value:
            self.SELayer = se.SpatialSELayer(params['num_filters'])

        elif self.se_block_type == se.SELayer.CSSE.value:
            self.SELayer = se.ChannelSpatialSELayer(params['num_filters'])


    Reference:
        Implementation: https://github.com/ai-med/squeeze_and_excitation
    """

    # Each member's value is its own name as a string.
    NONE = "NONE"
    CSE = "CSE"
    SSE = "SSE"
    CSSE = "CSSE"