Hello all, I am working on self-attention for images as described in the SAGAN
paper. The implementation is as follows -
import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, query_channels, key_channels, value_channels, output_channels):
        super(Attention, self).__init__()
        # Project the query, key and value feature maps to output_channels
        self.queryConv = nn.Conv2d(in_channels=query_channels, out_channels=output_channels,
                                   kernel_size=3, stride=1, padding=1)
        self.keyConv = nn.Conv2d(in_channels=key_channels, out_channels=output_channels,
                                 kernel_size=3, stride=1, padding=1)
        self.valueConv = nn.Conv2d(in_channels=value_channels, out_channels=output_channels,
                                   kernel_size=3, stride=1, padding=1)
        # 1x1 conv applied to the attention output
        self.outputConv = nn.Conv2d(in_channels=output_channels, out_channels=output_channels,
                                    kernel_size=1, stride=1)
        self.softmax = nn.Softmax(dim=-1)
        self.output_channels = output_channels
        self.queryChannels = query_channels
        self.norm = nn.InstanceNorm2d(num_features=output_channels)

    def forward(self, Query, key, value):
        batch, channels, width, height = Query.size()
        batch, channels_key, width_key, height_key = key.size()
        batch, channels_value, width_value, height_value = value.size()
        channels = channels_key = self.output_channels
        # Q: (batch, width*height, output_channels)
        Q = self.queryConv(Query).view(batch, -1, width * height).permute(0, 2, 1)
        # K: (batch, output_channels, width_key*height_key)
        K = self.keyConv(key).view(batch, -1, width_key * height_key)
        # energy: (batch, width*height, width_key*height_key)
        energy = torch.bmm(Q, K)
        attention = self.softmax(energy)
        # V: (batch, output_channels, width*height)
        V = self.valueConv(value).view(batch, -1, width * height)
        out = torch.bmm(V, attention)
        out = out.view(batch, channels, width_key, height_key)
        out = self.outputConv(out)
        out = self.norm(out)
        return out
I run the model twice, first with inputs -
attention = Attention(query_channels=256, key_channels=512, value_channels=512, output_channels=128).cuda()
a = torch.randn((1, 256, 320, 320)).cuda()
b = torch.randn((1, 512, 160, 160)).cuda()
and second with -
attention = Attention(query_channels=2, key_channels=2, value_channels=2, output_channels=1).cuda()
a = torch.randn((1, 2, 320, 320)).cuda()
b = torch.randn((1, 2, 160, 160)).cuda()
And this module is called as follows -
output = attention(a, b, a)
Where the second argument of forward (the key) controls the output's spatial dimensions.
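For reference, here is a minimal sketch of the shapes I expect to reach the torch.bmm call in the first configuration, based on my reading of the forward above (just standalone shape arithmetic, the module itself is never run here):

# Shape bookkeeping for the first configuration, following the code above
batch = 1
wq, hq = 320, 320        # spatial size of Query (tensor a)
wk, hk = 160, 160        # spatial size of key (tensor b)
out_ch = 128             # output_channels

q_shape = (batch, wq * hq, out_ch)        # after queryConv + view + permute
k_shape = (batch, out_ch, wk * hk)        # after keyConv + view
energy_shape = (batch, wq * hq, wk * hk)  # result of torch.bmm(Q, K)
print(q_shape, k_shape, energy_shape)     # (1, 102400, 128) (1, 128, 25600) (1, 102400, 25600)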
However, in both cases I get the same error message:
File "/main.py", line 279, in <module>
output = attention(a, b, b)
File ".local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
/model/modules/Attention.py", line 42, in forward
energy = torch.bmm(Q, K)
RuntimeError: CUDA out of memory. Tried to allocate 9.77 GiB (GPU 0; 7.80 GiB total capacity; 1.47 MiB already allocated; 6.76 GiB free; 2.00 MiB reserved in total by PyTorch)
What I am not able to understand is why that particular matrix multiplication takes 9.77 GiB to run.
The sizes of the tensors being multiplied would be 1*2*320*320*32 = 6553600 bits
and 1*2*160*160*32 = 1638400 bits,
which is 0.8192 MB and 0.2048 MB respectively. Does PyTorch take the rest of the space for tracking the gradients (I mean the ops required to build the computational graph)?
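Here is the small standalone sketch I used for that arithmetic on the second configuration's 2-channel inputs (the tensor_bits helper is just for illustration, not part of the model):

# Size arithmetic for the two input tensors; only reproduces the numbers quoted above
def tensor_bits(*shape, dtype_bits=32):
    n = 1
    for d in shape:
        n *= d
    return n * dtype_bits

print(tensor_bits(1, 2, 320, 320))            # 6553600 bits
print(tensor_bits(1, 2, 160, 160))            # 1638400 bits
print(tensor_bits(1, 2, 320, 320) / 8 / 1e6)  # 0.8192 MB
print(tensor_bits(1, 2, 160, 160) / 8 / 1e6)  # 0.2048 MB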
Is there something wrong with my implementation?
TIA