import torch
2
3
class RefinedSelfAttentionSAGAN(torch.nn.Module):
    """Refined self-attention module for Self-Attention GANs (SAGANs).

    The original SAGAN block builds an ``n x n`` attention map (``n = H * W``). This module drops the softmax
    and re-associates the matrix products as ``(v k^T) q / n``, so the largest intermediate is ``c x c'``
    instead of ``n x n``. Time and memory therefore scale as ``O(n)`` in the number of spatial positions
    rather than ``O(n^2)``, which reduces the computational overhead on large feature maps.

    The output is ``x + gamma * attention``. ``gamma`` is a learnable scalar initialised to ``0``, so the
    module starts out as the identity mapping and learns how much attention to mix in.

    Parameters:
        in_channels (int): Number of input channels.
        kernel_size (int, optional): Size of the convolutional kernels. Defaults to ``3``.
        dilation (int, optional): Dilation factor for the convolutional layers. Defaults to ``1``.
        padding (str or int, optional): Padding for the convolutional layers. Defaults to ``"same"``.
            It must preserve the spatial size (H, W), because the attention output is added back onto the input.
        bias (bool, optional): Whether to use bias in the convolutional layers. Defaults to ``False``.
        scale (int, optional): Reduction factor for the number of channels in the query and key projections
            (they use ``in_channels // scale`` channels). Defaults to ``8``.

    Raises:
        ValueError: If ``in_channels // scale`` is less than 1, i.e. the query/key projections would have
            no channels and the attention branch would be identically zero.

    Reference:
        - Paper: Zheng et al., Less Memory, Faster Speed: Refining Self-Attention Module for Image Reconstruction.
          arxiv: https://arxiv.org/abs/1905.08008
        - Implementation: This is an original implementation by the author of Torchmate.

    """

    def __init__(self, in_channels, kernel_size=3, dilation=1, padding="same", bias=False, scale=8):
        super().__init__()
        reduced_channels = in_channels // scale
        if reduced_channels < 1:
            raise ValueError(
                f"in_channels // scale must be >= 1, got in_channels={in_channels}, scale={scale}"
            )
        conv_kwargs = dict(kernel_size=kernel_size, dilation=dilation, padding=padding, bias=bias)
        # Query and key are projected to a reduced channel space c' = in_channels // scale; value keeps all c channels.
        self.query = torch.nn.Conv2d(in_channels, reduced_channels, **conv_kwargs)
        self.key = torch.nn.Conv2d(in_channels, reduced_channels, **conv_kwargs)
        self.value = torch.nn.Conv2d(in_channels, in_channels, **conv_kwargs)
        # Learnable residual weight. Zero init makes the block start as the identity mapping.
        self.gamma = torch.nn.Parameter(torch.tensor(0.0))

    @staticmethod
    def hw_flatten(x):
        """Flatten the two spatial dimensions: ``[bs, c, H, W] -> [bs, c, H * W]``.

        ``reshape`` is used rather than ``view`` so that non-contiguous inputs are handled as well
        (it still returns a view whenever possible).
        """
        return x.reshape(x.shape[0], x.shape[1], -1)

    def forward(self, x):
        """Apply refined self-attention with a learnable residual connection.

        Args:
            x (torch.Tensor): Batched feature map of shape ``[bs, in_channels, H, W]``.

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``, equal to ``x + gamma * attention``.
        """
        q = self.hw_flatten(self.query(x))  # [bs, c', n] ; n = H * W, bs = batch size
        k = self.hw_flatten(self.key(x))  # [bs, c', n]
        v = self.hw_flatten(self.value(x))  # [bs, c, n]
        n = q.shape[-1]
        # Compute (v k^T) q instead of v (k^T q): the intermediate is [bs, c, c'] rather than [bs, n, n],
        # which is what makes the cost linear in the number of positions.
        vk = torch.bmm(v, k.transpose(1, 2)) / n  # [bs, c, c']
        attention = torch.bmm(vk, q)  # [bs, c, n]
        # The projections must preserve H and W (e.g. padding="same") for this reshape to x's shape to succeed.
        attention = attention.view_as(x)
        return x + self.gamma * attention