Unexpectedly Large Memory Usage

I am running into an issue where my model uses an unexpectedly large amount of RAM (around 15 GB per sample when using 1 transformer block and 1 attention head). I have developed a UNet + Transformer architecture where the bottom of the UNet contains the transformer blocks. I want to use 6 transformer blocks with 6-headed multihead attention. The attention size and encoding size are both 16, and there are 1152 words. At the bottom of the UNet, the image has been downsampled to 64x18 with 128 channels at a batch size of 2. Before feeding into the transformer, I join the batch and channel dimensions, flatten the image, and embed each word into a 16-dimensional vector. This means the input to the embedding layer is 256 (batches) x 1 (used to be channels) x 1152 (flattened image) x 1, and the output is 256 x 1 x 1152 x 16 (embedding size). This is then passed to the transformer.

So far, none of these numbers stand out to me as being problematically large, though I feel like I am missing something. Is there anything here that stands out as wasting, or just generally consuming, a lot of memory? Also, is there a convenient way to monitor peak GPU VRAM usage per module? (I sketch what I have in mind at the end of the post.)
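
For concreteness, this is roughly the bookkeeping behind those numbers (a minimal sketch; the sizes come from the figures above, and everything is complex64, i.e. 8 bytes per element):

import torch

# Sizes from the description above: batch 2, 128 channels, 64x18 at the bottleneck,
# 1152 words, embedding size 16, all complex64 (8 bytes per element).
batch, channels, freq, time = 2, 128, 64, 18
words, encoding_size = freq * time, 16

embed_in = torch.zeros(batch * channels, 1, words, 1, dtype=torch.complex64)
embed_out = torch.zeros(batch * channels, 1, words, encoding_size, dtype=torch.complex64)

for name, t in [("embedding input", embed_in), ("embedding output", embed_out)]:
    print(f"{name}: {tuple(t.shape)}, {t.element_size() * t.nelement() / 2**20:.2f} MiB")
# embedding input: (256, 1, 1152, 1), 2.25 MiB
# embedding output: (256, 1, 1152, 16), 36.00 MiB

Per tensor that is only tens of MB, which is why the 15 GB per sample surprises me.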

For the convolutional portions, I am using 16, 32, 64, and 128 channels, and I am starting with relatively large input images (1024 x 302). The images are complex-valued, so I am using the complex versions of many standard layers from complexPyTorch (GitHub - wavefrontshaping/complexPyTorch: A high-level toolbox for using complex valued neural networks in PyTorch). These are my convolutional encoder and decoder blocks for downsampling and upsampling in the UNet:

import torch
import torch.nn as nn
from collections import OrderedDict

# The Complex* layers below are the complex-valued layers from complexPyTorch;
# NaiveComplexLayerNorm and complex_matmul are the complex layer-norm / matmul
# helpers (not shown here), and config holds the hyperparameters referenced throughout.
from complexPyTorch.complexLayers import (ComplexConv2d, ComplexConvTranspose2d,
                                          ComplexLinear, ComplexReLU, ComplexDropout,
                                          ComplexMaxPool2d)

class ConvEncoderBlock(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(ConvEncoderBlock, self).__init__()

        self.conv1 = nn.Sequential(ComplexConv2d(
            in_channels, out_channels, kernel_size=config.encoder_kernel_size, padding=config.encoder_padding),
            ComplexReLU())
        self.conv2 = nn.Sequential(ComplexConv2d(
            out_channels, out_channels, kernel_size=config.encoder_kernel_size, padding=config.encoder_padding),
            ComplexReLU())
        self.downsample = ComplexMaxPool2d(2)

    def forward(self, x):

        x = self.conv1(x)
        skip = self.conv2(x)
        x = self.downsample(skip)
        return x, skip

class DecoderBlock(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(DecoderBlock, self).__init__()

        upsample_channels = in_channels * 2 // 3
        skip_channels = in_channels * 2 // 3

        self.upsample = ComplexConvTranspose2d(
            upsample_channels, upsample_channels // 2, kernel_size=2, stride=2)

        upsample_channels = upsample_channels // 2

        self.conv1 = nn.Sequential(ComplexConv2d(upsample_channels + skip_channels, out_channels, kernel_size=config.decoder_kernel_size, stride=1, padding=config.decoder_padding),
                                   ComplexReLU())
        self.conv2 = nn.Sequential(ComplexConv2d(out_channels, out_channels, kernel_size=config.decoder_kernel_size, stride=1, padding=config.decoder_padding),
                                   ComplexReLU())

    def forward(self, x, skip):

        # print('in decoder block', x.shape, skip.shape)

        # input: [batch, channel, freq, time]
        x = self.upsample(x)  # [batch, channel//2, freq*2, time*2]

        # print('catting', skip.shape, x.shape)
        x = torch.cat((skip, x), dim=1)

        # print('feeding into conv1', x.shape)
        x = self.conv1(x)  # [batch, channel, freq*2, time*2]
        x = self.conv2(x)  # [batch, channel, freq*2, time*2]

        # print('out of convolutionals', x.shape)

        return x

These are my transformer components:

class Embedding(nn.Module):

    def __init__(self, params, in_channels=1):
        super(Embedding, self).__init__()

        self.patch_embeddings = ComplexConv2d(
            in_channels, config.encoding_size, kernel_size=config.patch_size, stride=config.patch_size)
        self.positional_embeddings = nn.Parameter(torch.zeros(
            1, in_channels, config.num_patches, config.encoding_size, 2))

    # shape of x: [batch, channel, f, w] :=> type(torch.complex64)
    def forward(self, x):

        # print('embedding input shape', x.shape)

        # [batch * channels, 1, f, w] :=> type(torch.complex64)
        x = self.patch_embeddings(x)
        # [batch * channels, encoding_size, patches_per_column * w, 1] :=> type(torch.complex64)

        x = x.permute(0, 3, 2, 1)
        # [batch * channels, 1, patches_per_column * w, encoding_size] :=> type(torch.complex64)
        # print('after permute', x.shape)

        x = torch.view_as_real(x)

        x = x + self.positional_embeddings

        x = torch.view_as_complex(x)

        # [batch * channels, 1, patches_per_column * w, encoding_size] :=> type(torch.complex64)
        return x


class AttentionHead(nn.Module):

    def __init__(self, in_channels=2, out_channels=1):
        super(AttentionHead, self).__init__()

        #self.num_heads = config.num_heads

        self.keys = ComplexLinear(
            config.encoding_size, config.attention_size, bias=config.attention_bias)
        self.queries = ComplexLinear(
            config.encoding_size, config.attention_size, bias=config.attention_bias)
        self.values = ComplexLinear(
            config.encoding_size, config.attention_size, bias=config.attention_bias)

        self.complex_map = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x):

        keys = self.keys(x)
        queries = self.queries(x)
        values = self.values(x)

        # queries/keys/values: [batch * channels, 1, words, attention_size] :=> type(torch.complex64)
        scores = complex_matmul(queries, keys.transpose(-1, -2))
        scores /= config.attention_size ** 0.5
        # scores: [batch * channels, 1, words, words] :=> type(torch.complex64)

        # stack real/imag as 2 channels, map them to a single real channel, then softmax
        scores = torch.view_as_real(scores)  # [batch * channels, 1, words, words, 2]
        scores = scores[:, 0, :, :]          # [batch * channels, words, words, 2]
        scores = scores.permute(0, 3, 1, 2)  # [batch * channels, 2, words, words]
        scores = self.complex_map(scores)    # [batch * channels, 1, words, words]
        scores = nn.Softmax(dim=-1)(scores)

        scores = self.dropout(scores)

        # promote the real attention map back to complex so it can multiply the complex values
        scores = torch.complex(scores, torch.zeros_like(scores))

        #print('after attention', complex_matmul(scores, values))

        return complex_matmul(scores, values)


class MSA(nn.Module):
    def __init__(self, params):
        super(MSA, self).__init__()

        self.device = params["device"]

        self.heads = nn.ModuleList([AttentionHead(2, 1)
                                    for _ in range(config.num_heads)])

        self.w = ComplexLinear(config.attention_size * config.num_heads,
                               config.encoding_size, bias=config.attention_bias)

        self.dropout = ComplexDropout(config.dropout_rate)

    def forward(self, x):

        all_heads = self.heads[0](x)
        for head in self.heads[1:]:
            all_heads = torch.cat((all_heads, head(x)), dim=-1)

        x = self.w(all_heads)
        x = self.dropout(x)

        #print('after MSA', x)

        return x


class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()

        self.fc1 = ComplexLinear(config.encoding_size, config.mlp_size)
        self.fc2 = ComplexLinear(config.mlp_size, config.encoding_size)

        self.activation = ComplexReLU()

        self.dropout = ComplexDropout(config.dropout_rate)

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)

        x = self.fc2(x)
        x = self.dropout(x)

        #print('after mlp', x)

        return x


class TransformerBlock(nn.Module):

    def __init__(self, params):
        super(TransformerBlock, self).__init__()

        self.attn_norm = NaiveComplexLayerNorm(
            (params["num_patches"], config.encoding_size), eps=config.norm_eps)
        self.attn = MSA(params)

        self.ffn_norm = NaiveComplexLayerNorm(
            (params["num_patches"], config.encoding_size), eps=config.norm_eps)
        self.ffn = MLP()

    def forward(self, x):

        # print('transformer input', x.shape)

        h = x
        x = self.attn_norm(x)
        x = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h

        # print('after transformer', x)

        return x

This is how I combine the convolutional layers and the transformer layers in an overarching encoder class.

class Encoder(nn.Module):

    def __init__(self, params):
        super(Encoder, self).__init__()

        self.convEncoderBlock1 = ConvEncoderBlock(1, 16)
        self.convEncoderBlock2 = ConvEncoderBlock(16, 32)
        self.convEncoderBlock3 = ConvEncoderBlock(32, 64)
        self.convEncoderBlock4 = ConvEncoderBlock(64, 128)

        self.embedding = Embedding(params=params)

        self.transformers = nn.Sequential(OrderedDict(
            [("Block " + str(i), TransformerBlock(params)) for i in range(config.num_transformers)]))

        self.unembedding = nn.Sequential(
            ComplexConv2d(config.encoding_size, 1, kernel_size=1)
        )

    def forward(self, x):

        # print('encoder input shape', x.shape)

        # Convolutional Layers
        x, skip1 = self.convEncoderBlock1(x)
        x, skip2 = self.convEncoderBlock2(x)
        x, skip3 = self.convEncoderBlock3(x)
        x, skip4 = self.convEncoderBlock4(x)

        batch_size = x.shape[0]
        freq_size = x.shape[2]
        time_size = x.shape[3]

        # print("transformer input: ", x.shape)

        x = x.reshape((batch_size*x.shape[1], 1, x.shape[2], x.shape[3]))
        # print("channels to batches: ", x.shape)

        batch_channels = x.shape[0]


        x = x.permute(0, 1, 3, 2).reshape(batch_channels, 1, freq_size*time_size, 1)
        # print("transformer words as rows: ", x.shape)


        x = self.embedding(x)
        # print("embedding output: ", x.shape)

        x = self.transformers(x)
        # print("transformer output: ", x.shape)


        # Unembedding
        x = x.permute(0, 3, 2, 1)
        # print("unembed input: ", x.shape)

        x = self.unembedding(x)
        # print("unembed output: ", x.shape)

        x = x.reshape(batch_channels, 1, time_size,
                      freq_size).permute(0, 1, 3, 2)
        # print("reshape as spec: ", x.shape)

        x = x.reshape(
            (batch_size, x.shape[0]//batch_size, x.shape[2], x.shape[3]))
        # print("separate batch and channels: ", x.shape)

        return x, [skip1, skip2, skip3, skip4]
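
Regarding the monitoring question at the top: something along these lines, using forward hooks around torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated(), is what I have in mind for per-module peak VRAM, unless there is a cleaner built-in way. It is only a sketch, and model / sample below are placeholders:

import torch
import torch.nn as nn


def attach_peak_memory_hooks(model: nn.Module):
    """Track the peak CUDA memory seen while each direct child module runs.

    Coarse by design: reset_peak_memory_stats() is global, so each number also
    includes memory that was already allocated before the child's forward began.
    """
    peaks = {}

    def make_hooks(name):
        def pre_hook(module, inputs):
            torch.cuda.synchronize()
            torch.cuda.reset_peak_memory_stats()

        def post_hook(module, inputs, outputs):
            torch.cuda.synchronize()
            peaks[name] = max(peaks.get(name, 0), torch.cuda.max_memory_allocated())

        return pre_hook, post_hook

    for name, child in model.named_children():
        pre, post = make_hooks(name)
        child.register_forward_pre_hook(pre)
        child.register_forward_hook(post)

    return peaks


# Usage sketch (model and sample are placeholders):
# peaks = attach_peak_memory_hooks(model)
# _ = model(sample.cuda())
# for name, peak_bytes in sorted(peaks.items(), key=lambda kv: -kv[1]):
#     print(f"{name}: {peak_bytes / 2**20:.1f} MiB")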