1"""A collection of squeeze and excitation (SE) layers that can be integrated into various neural network architectures.
2
3Supports three types of SE blocks:
4- Channel Squeeze and Excitation (CSE)
5- Spatial Squeeze and Excitation (SSE)
6- Channel and Spatial Squeeze and Excitation (CSSE)
7
8**Credit:**
9https://github.com/ai-med/squeeze_and_excitation
10
11
References:
- [Hu et al., Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507)
- [Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018](https://arxiv.org/abs/1803.02579)
16
17
18"""
19
20from enum import Enum
21
22import torch
23
24
class ChannelSELayer(torch.nn.Module):
    """Channel Squeeze and Excitation (cSE) block.

    Described in:

    - Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507

    The layer works in three steps:

    1. Each channel is squeezed to its mean over all non-channel dimensions.
    2. The squeezed values pass through a bottleneck MLP
       (``fc1`` -> ReLU -> ``fc2`` -> sigmoid), giving one gate per channel.
    3. The gates rescale the corresponding channels of the input.

    Parameters:
        num_channels (int): Number of input channels.
        reduction_ratio (int, optional): Ratio to reduce the number of channels
            by in the squeeze step. Default is ``2``.

    Returns:
        torch.Tensor: Output tensor with the same dimensions as
        the input.

    Reference:
        - Paper: https://arxiv.org/abs/1709.01507
        - Implementation: https://github.com/ai-med/squeeze_and_excitation

    """

    def __init__(self, num_channels, reduction_ratio=2):
        super(ChannelSELayer, self).__init__()
        # NOTE: floor division gives 0 hidden units when
        # num_channels < reduction_ratio, which makes the bottleneck degenerate.
        num_channels_reduced = num_channels // reduction_ratio
        self.reduction_ratio = reduction_ratio
        self.fc1 = torch.nn.Linear(num_channels, num_channels_reduced, bias=True)
        self.fc2 = torch.nn.Linear(num_channels_reduced, num_channels, bias=True)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_tensor):
        """Forward method

        Parameters:
            input_tensor: X, shape = (batch_size, num_channels, *spatial),
                e.g. (batch_size, num_channels, H, W). Any number of trailing
                dimensions (including none) is accepted, and the tensor does
                not have to be contiguous.

        Returns:
            torch.Tensor: Output tensor with the same dimensions as the input.
        """
        batch_size, num_channels = input_tensor.shape[:2]

        # Squeeze: average each channel over all trailing dimensions.
        # ``reshape`` (unlike ``view``) copies when strides are incompatible,
        # so non-contiguous inputs such as transposed tensors are accepted.
        squeeze_tensor = input_tensor.reshape(batch_size, num_channels, -1).mean(dim=2)

        # Excitation: bottleneck MLP producing one gate in (0, 1) per channel.
        fc_out_1 = self.relu(self.fc1(squeeze_tensor))
        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))

        # Broadcast the (batch_size, num_channels) gates over every trailing dim.
        gate_shape = (batch_size, num_channels) + (1,) * (input_tensor.dim() - 2)
        return torch.mul(input_tensor, fc_out_2.view(gate_shape))
76
77
class SpatialSELayer(torch.nn.Module):
    """Spatial Squeeze and Excitation (sSE) block.

    Described in:

    - Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully
      Convolutional Networks, MICCAI 2018

    The layer works in two steps:

    1. A 1x1 convolution (``self.conv``) collapses the channels into a single
       map.
    2. The sigmoid of that map multiplies every channel of the input at each
       spatial position.

    Parameters:
        num_channels (int): Number of input channels.

    Returns:
        torch.Tensor: Output tensor with the same dimensions as
        the input.

    Reference:
        - Paper: https://arxiv.org/abs/1803.02579
        - Implementation: https://github.com/ai-med/squeeze_and_excitation

    """

    def __init__(self, num_channels):
        super(SpatialSELayer, self).__init__()
        self.conv = torch.nn.Conv2d(num_channels, 1, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_tensor, weights=None):
        """Forward method

        Parameters:
            input_tensor: X, shape = (batch_size, num_channels, H, W)
            weights: optional replacement kernel (described by the original
                authors as weights for few shot learning). When given, it is
                averaged over dim 0, reshaped to (1, num_channels, 1, 1) and
                applied with a bias-free ``conv2d`` instead of ``self.conv``.

        Returns:
            torch.Tensor: Output tensor with the same dimensions as the input.
        """
        n, c, height, width = input_tensor.size()

        # Spatial squeeze: project all channels onto a single map.
        if weights is None:
            projected = self.conv(input_tensor)
        else:
            kernel = torch.mean(weights, dim=0).view(1, c, 1, 1)
            projected = torch.nn.functional.conv2d(input_tensor, kernel)

        # Spatial excitation: one gate per location, shared by all channels.
        spatial_gate = self.sigmoid(projected).view(n, 1, height, width)
        return torch.mul(input_tensor, spatial_gate)
130
131
class ChannelSpatialSELayer(torch.nn.Module):
    """Concurrent Spatial and Channel Squeeze & Excitation (csSE) block.

    Described in:

    - Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully
      Convolutional Networks, MICCAI 2018

    The input is recalibrated independently by a channel SE block (``cSE``)
    and a spatial SE block (``sSE``). The two results are combined with an
    element-wise maximum.

    Parameters:
        num_channels (int): Number of input channels.
        reduction_ratio (int, optional): Ratio to reduce the number of channels
            by in the squeeze step. Default is ``2``.

    Returns:
        torch.Tensor: Output tensor with the same dimensions as the input.

    Reference:
        - Paper: https://arxiv.org/abs/1803.02579
        - Implementation: https://github.com/ai-med/squeeze_and_excitation

    """

    def __init__(self, num_channels, reduction_ratio=2):
        super().__init__()
        self.cSE = ChannelSELayer(num_channels, reduction_ratio)
        self.sSE = SpatialSELayer(num_channels)

    def forward(self, input_tensor):
        """Forward method

        Parameters:
            input_tensor: X, shape = (batch_size, num_channels, H, W)

        Returns:
            torch.Tensor: Output tensor with the same dimensions as the input.

        """
        channel_recalibrated = self.cSE(input_tensor)
        spatial_recalibrated = self.sSE(input_tensor)
        # Keep, per element, the stronger of the two recalibrations.
        return torch.max(channel_recalibrated, spatial_recalibrated)
171
172
class SELayer(Enum):
    """Squeeze and Excitation Enum Block

    Enum restricting the types of SE blocks available, so that type checking
    can be added when attaching these blocks to a neural network. ``NONE`` is
    a valid value with no corresponding layer class in this module.

    .. code-block:: python

        if self.se_block_type == se.SELayer.CSE.value:
            self.SELayer = se.ChannelSELayer(params['num_filters'])

        elif self.se_block_type == se.SELayer.SSE.value:
            self.SELayer = se.SpatialSELayer(params['num_filters'])

        elif self.se_block_type == se.SELayer.CSSE.value:
            self.SELayer = se.ChannelSpatialSELayer(params['num_filters'])


    Reference:
        Implementation: https://github.com/ai-med/squeeze_and_excitation
    """

    NONE = "NONE"
    CSE = "CSE"
    SSE = "SSE"
    CSSE = "CSSE"