Source code for torchmate.modules.refined_self_attention_sagan

 1import torch
 2
 3
class RefinedSelfAttentionSAGAN(torch.nn.Module):
    """Implement a Refined Self-Attention module for Generative Adversarial Networks (SAGANs).

    This module has a more efficient time and space complexity of O(n) compared to the original SAGAN
    self-attention's O(n^2) complexity, making it suitable for reducing computational overhead. The saving comes
    from multiplying key and value first (a small ``[c', c]`` matrix per sample) and only then multiplying by the
    query, so no ``[n, n]`` attention map is ever materialised (``n = H * W``).

    The attention branch is added to the input through a learnable scalar ``gamma`` that is initialised to ``0``,
    so a freshly constructed module is an identity mapping and the attention contribution is learned gradually.

    Parameters:
        in_channels (int): Number of input channels.
        kernel_size (int, optional): Size of the convolutional kernels. Defaults to ``3``.
        dilation (int, optional): Dilation factor for the convolutional layers. Defaults to ``1``.
        padding (str, optional): Padding type for the convolutional layers. Defaults to ``"same"``.
        bias (bool, optional): Whether to use bias in the convolutional layers. Defaults to ``False``.
        scale (int, optional): Scaling factor for the number of channels in the query and key projections
            (they get ``in_channels // scale`` channels). Defaults to ``8``.

    Shape:
        - Input: ``[bs, in_channels, H, W]``
        - Output: same shape as the input. The convolutions must preserve the spatial size (as the default
          ``padding="same"`` does), otherwise the attention branch cannot be reshaped back onto the input.

    Reference:
        - Paper: Zheng et al., Less Memory, Faster Speed: Refining Self-Attention Module for Image Reconstruction.
          arxiv: https://arxiv.org/abs/1905.08008
        - Implementation: This is an original implementation by the author of Torchmate.

    """

    def __init__(self, in_channels, kernel_size=3, dilation=1, padding="same", bias=False, scale=8):
        super().__init__()
        # Settings shared by the three projections; only the number of output channels differs.
        conv_kwargs = dict(
            in_channels=in_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            bias=bias,
        )
        reduced_channels = in_channels // scale
        # NOTE: the attributes are created in the order query, key, value, gamma so that parameter
        # initialisation consumes the RNG in the same order as before.
        self.query = torch.nn.Conv2d(out_channels=reduced_channels, **conv_kwargs)
        self.key = torch.nn.Conv2d(out_channels=reduced_channels, **conv_kwargs)
        self.value = torch.nn.Conv2d(out_channels=in_channels, **conv_kwargs)
        # Learnable gate of the attention branch; starts at 0 so the module is initially an identity.
        self.gamma = torch.nn.Parameter(torch.tensor(0.0))

    @staticmethod
    def hw_flatten(x):
        """Collapse every dimension after the channel dimension into one.

        Parameters:
            x (torch.Tensor): Tensor with at least two dimensions, e.g. ``[bs, c, H, W]``.

        Returns:
            torch.Tensor: Tensor of shape ``[bs, c, n]`` where ``n`` is the product of the collapsed dimensions.
        """
        # ``reshape`` returns a view whenever ``view`` would have worked and silently copies otherwise,
        # so inputs whose trailing dimensions are not stride-mergeable no longer raise.
        return x.reshape(x.shape[0], x.shape[1], -1)

    def forward(self, x):
        """Apply refined self-attention and add the gated result to the input.

        Parameters:
            x (torch.Tensor): Input feature map of shape ``[bs, in_channels, H, W]``.

        Returns:
            torch.Tensor: ``x + gamma * attention``, with the same shape as ``x``.
        """
        q = self.hw_flatten(self.query(x))  # [bs, c', n] ; n = HxW
        k = self.hw_flatten(self.key(x))  # [bs, c', n] ; bs = batch size
        v = self.hw_flatten(self.value(x))  # [bs, c, n]

        # Equal to H*W of ``x`` whenever the final ``view_as(x)`` can succeed.
        n = q.shape[2]

        # Multiply key and value first: the intermediate is only [bs, c', c] instead of [bs, n, n].
        kv = torch.bmm(k, v.transpose(1, 2)) / n  # [bs, c', c]
        attention = torch.bmm(kv.transpose(1, 2), q)  # [bs, c, n]
        attention = attention.view_as(x)  # [bs, c, H, W]

        # Gate the attention branch with the learnable ``gamma`` (previously the parameter was defined
        # but never used, so it received no gradient and the branch was always added at full strength).
        return x + self.gamma * attention