A question regarding memory usage

Hello all, I am working on self-attention for images as described in the SAGAN paper. The implementation is as follows -

import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, query_channels, key_channels, value_channels, output_channels):
        super(Attention, self).__init__()

        # 3x3 convolutions project query, key and value to output_channels feature maps
        self.queryConv = nn.Conv2d(in_channels=query_channels, out_channels=output_channels, kernel_size=3,
                                   stride=1, padding=1)
        self.keyConv = nn.Conv2d(in_channels=key_channels, out_channels=output_channels, kernel_size=3,
                                 stride=1, padding=1)
        self.valueConv = nn.Conv2d(in_channels=value_channels, out_channels=output_channels, kernel_size=3,
                                   stride=1, padding=1)

        self.outputConv = nn.Conv2d(in_channels=output_channels, out_channels=output_channels, kernel_size=1)

        self.softmax = nn.Softmax(dim=-1)
        self.output_channels = output_channels
        self.queryChannels = query_channels

        self.norm = nn.InstanceNorm2d(num_features=output_channels)

    def forward(self, Query, key, value):
        # Tensor.size() is (N, C, H, W), so "width" here holds H and "height" holds W
        batch, channels, width, height = Query.size()
        batch, channels_key, width_key, height_key = key.size()
        batch, channels_value, width_value, height_value = value.size()

        channels = channels_key = self.output_channels
        # Q: (batch, H_q*W_q, C_out), K: (batch, C_out, H_k*W_k)
        Q = self.queryConv(Query).view(batch, -1, width*height).permute(0, 2, 1)
        K = self.keyConv(key).view(batch, -1, width_key*height_key)

        # energy: (batch, H_q*W_q, H_k*W_k), one score per (query position, key position) pair
        energy = torch.bmm(Q, K)
        attention = self.softmax(energy)
        # V is flattened with the query's spatial size, so value needs as many positions as the query
        V = self.valueConv(value).view(batch, -1, width * height)
        out = torch.bmm(V, attention)
        # out: (batch, C_out, H_k*W_k), reshaped to the key's spatial size
        out = out.view(batch, channels, width_key, height_key)

        out = self.outputConv(out)
        out = self.norm(out)
        return out
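To see which tensor dominates memory, here is a plain-Python shape trace of `forward` for the call `attention(a, b, a)`. The helper `attention_shapes` is hypothetical and only does arithmetic (nothing is allocated); it assumes NCHW inputs and a value with the query's spatial size.

```python
def attention_shapes(query_hw, key_hw, batch=1, output_channels=128):
    """Shapes inside Attention.forward, assuming the value has the query's
    spatial size, as in attention(a, b, a). Hypothetical helper, arithmetic only."""
    n_q = query_hw[0] * query_hw[1]   # number of query positions
    n_k = key_hw[0] * key_hw[1]       # number of key positions
    c = output_channels
    return {
        "Q": (batch, n_q, c),                      # queryConv(...).view(...).permute(0, 2, 1)
        "K": (batch, c, n_k),                      # keyConv(...).view(...)
        "energy": (batch, n_q, n_k),               # torch.bmm(Q, K); softmax keeps this shape
        "V": (batch, c, n_q),                      # valueConv(...).view(batch, -1, n_q)
        "out": (batch, c, key_hw[0], key_hw[1]),   # bmm(V, attention) viewed to the key's H x W
    }


for name, shape in attention_shapes((320, 320), (160, 160)).items():
    print(name, shape)
```

Note that `output_channels` only appears in the Q, K, V and output shapes; the energy matrix depends solely on the number of query and key positions.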

I ran the model twice, once with the inputs -

attention = Attention(query_channels=256, key_channels=512, value_channels=256, output_channels=128).cuda()
a = torch.randn((1, 256, 320, 320)).cuda()
b = torch.randn((1, 512, 160, 160)).cuda()

and once with -

attention = Attention(query_channels=2, key_channels=2, value_channels=2, output_channels=1).cuda()
a = torch.randn((1, 2, 320, 320)).cuda()
b = torch.randn((1, 2, 160, 160)).cuda()

And this module is called as follows -

output = attention(a, b, a)

Here the second argument of forward (the key) determines the spatial size of the output.
However, in both cases I get the same error message -

  File "/main.py", line 279, in <module>
    output = attention(a, b, b)
  File ".local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
    /model/modules/Attention.py", line 42, in forward
    energy = torch.bmm(Q, K)
RuntimeError: CUDA out of memory. Tried to allocate 9.77 GiB (GPU 0; 7.80 GiB total capacity; 1.47 MiB already allocated; 6.76 GiB free; 2.00 MiB reserved in total by PyTorch)

What I am not able to understand is why that particular matrix multiplication needs 9.77 GiB.
The tensors being multiplied should be 1*2*320*320*32 = 6553600 bits and 1*2*160*160*32 = 1638400 bits (32 bits per float32 element), i.e. 0.8192 MB and 0.2048 MB respectively. Does PyTorch use the rest of the space for tracking gradients (I mean the ops required to build the computational graph)?
Is there something wrong with my implementation?
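The size estimate above counts the operands of `torch.bmm(Q, K)`, but the allocation in the error is for its output. A minimal sketch of that arithmetic, assuming float32 and the `attention(a, b, a)` call: the operands Q and K (after the 1-channel projections of the second configuration) are tiny, while the energy matrix has one float per (query position, key position) pair. The function names are hypothetical.

```python
GIB = 2 ** 30
MIB = 2 ** 20
FLOAT32_BYTES = 4


def energy_bytes(query_hw, key_hw, batch=1, bytes_per_elem=FLOAT32_BYTES):
    """Bytes for the (batch, H_q*W_q, H_k*W_k) output of torch.bmm(Q, K).
    Channel counts do not appear: they are contracted away by the product."""
    n_q = query_hw[0] * query_hw[1]
    n_k = key_hw[0] * key_hw[1]
    return batch * n_q * n_k * bytes_per_elem


def operand_bytes(channels, hw, batch=1, bytes_per_elem=FLOAT32_BYTES):
    """Bytes for one projected operand (Q or K) with the given channel count."""
    return batch * channels * hw[0] * hw[1] * bytes_per_elem


# Second configuration: Q and K have output_channels=1
q_bytes = operand_bytes(1, (320, 320))
k_bytes = operand_bytes(1, (160, 160))
e_bytes = energy_bytes((320, 320), (160, 160))
print(f"Q: {q_bytes / MIB:.2f} MiB, K: {k_bytes / MIB:.2f} MiB")
print(f"energy: {e_bytes / GIB:.2f} GiB")
```

Because `energy_bytes` takes no channel argument, the first configuration (128 output channels) needs exactly the same 102400 x 25600 float32 matrix, which is consistent with both runs failing on the same 9.77 GiB request.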

Your code is unfortunately not executable and yields shape mismatch errors:

# first code snippet:
    out = out.view(batch, channels, width_key, height_key)
RuntimeError: shape '[1, 128, 160, 160]' is invalid for input of size 819200

# second code snippet:
    V = self.valueConv(value).view(batch, -1, width * height)
RuntimeError: shape '[1, -1, 102400]' is invalid for input of size 25600

I assume that you are calling it via out = attention(a, b, b) based on the error message.
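The two shape errors above can be reproduced by following the element counts along the value path of `forward` when the value is `b` instead of `a`. This is a hypothetical arithmetic helper, not PyTorch code; it only mimics the two `.view()` calls and formats the message the way the errors above read.

```python
def trace_value_path(output_channels, value_hw, query_hw, key_hw, batch=1):
    """Mimic the two .view() calls on the value path of Attention.forward.
    Returns the final output shape, or raises ValueError with the element
    count that the failing view would see. Hypothetical helper."""
    n_q = query_hw[0] * query_hw[1]
    n_k = key_hw[0] * key_hw[1]
    v_elems = batch * output_channels * value_hw[0] * value_hw[1]  # valueConv(value)
    # V = ... .view(batch, -1, width * height), where width * height comes from the query
    if v_elems % (batch * n_q) != 0:
        raise ValueError(f"shape '[{batch}, -1, {n_q}]' is invalid for input of size {v_elems}")
    v_rows = v_elems // (batch * n_q)
    out_elems = batch * v_rows * n_k  # torch.bmm(V, attention) -> (batch, v_rows, n_k)
    target = (batch, output_channels, key_hw[0], key_hw[1])
    if out_elems != batch * output_channels * n_k:
        raise ValueError(f"shape '{list(target)}' is invalid for input of size {out_elems}")
    return target


# attention(a, b, b) with output_channels=128 and then 1
for c_out in (128, 1):
    try:
        trace_value_path(c_out, value_hw=(160, 160), query_hw=(320, 320), key_hw=(160, 160))
    except ValueError as err:
        print(err)

# attention(a, b, a): the value has the query's 320 x 320 positions, so both views succeed
print(trace_value_path(128, value_hw=(320, 320), query_hw=(320, 320), key_hw=(160, 160)))
```

With the value set to `a`, the view succeeds and the output takes the key's 160 x 160 spatial size.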

Thanks for replying @ptrblck. I am sorry for not putting this in the post.
It is called via out = attention(a, b, a): the output has the spatial shape of the key, and the value and the query are the same tensor. I will add this to the original post too.

Edit - Initially I had assumed PyTorch might be storing intermediate tensors for gradient calculation. But even inside a no_grad region I observe the same behavior: the same 9.77 GiB allocation is requested (on an 8 GB GPU). Can anyone reproduce this?
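`torch.no_grad()` only stops autograd from keeping tensors for the backward pass; the 9.77 GiB request is for the output of `torch.bmm(Q, K)` itself, which is materialized in full either way. Below is a sketch of one common workaround (not something proposed in this thread): process the query positions in chunks. Softmax runs over the key dimension, so each block of query rows is independent, and `bmm(V, attention)` sums over query positions, so per-chunk products can be accumulated. `chunked_attention_out` and `chunk_size` are hypothetical names, and the layouts assumed are those of `Attention.forward` above.

```python
import torch


def chunked_attention_out(Q, K, V, chunk_size=4096):
    """Compute torch.bmm(V, softmax(torch.bmm(Q, K), dim=-1)) without the full energy matrix.

    Q: (batch, n_q, c), K: (batch, c, n_k), V: (batch, c_v, n_q).
    Only a (batch, chunk_size, n_k) slice of the energy exists at a time.
    """
    batch, n_q, _ = Q.shape
    n_k = K.shape[-1]
    out = Q.new_zeros((batch, V.shape[1], n_k))
    for start in range(0, n_q, chunk_size):
        end = min(start + chunk_size, n_q)
        # Softmax over keys: rows for this block of queries do not depend on other rows
        attn = torch.softmax(torch.bmm(Q[:, start:end, :], K), dim=-1)  # (batch, chunk, n_k)
        # The original bmm contracts over query positions, so partial sums add up
        out += torch.bmm(V[:, :, start:end], attn)                      # (batch, c_v, n_k)
    return out


# Small CPU check against the full-matrix formulation
torch.manual_seed(0)
Q = torch.randn(1, 64, 8)   # (batch, n_q, c)
K = torch.randn(1, 8, 16)   # (batch, c, n_k)
V = torch.randn(1, 8, 64)   # (batch, c_v, n_q)
full = torch.bmm(V, torch.softmax(torch.bmm(Q, K), dim=-1))
print(torch.allclose(full, chunked_attention_out(Q, K, V, chunk_size=10), atol=1e-5))
```

With 25600 key positions, a chunk of 4096 queries needs 4096 x 25600 x 4 bytes = 400 MiB of energy instead of 9.77 GiB. This mainly helps under `no_grad`: with autograd enabled, each chunk's attention is still saved for backward, so training would additionally need something like `torch.utils.checkpoint`.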