Visualizing the attention map of self-attention integrated in a CNN

Hello, I am trying to visualize the attention map after the last layer of my model.
The model is a custom CNN with self-attention integrated into the residual blocks.
I have searched everywhere for this and found nothing.
Could anyone kindly help? Thanks in advance.

import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureSqueezingLayer(nn.Module):
    def __init__(self, bit_depth):
        super(FeatureSqueezingLayer, self).__init__()
        self.bit_depth = bit_depth

    def forward(self, x):
        # quantize activations to 2 ** bit_depth discrete levels
        quantized_x = torch.floor(x * (2 ** self.bit_depth)) / (2 ** self.bit_depth)
        return quantized_x

class SimplifiedSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SimplifiedSelfAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.feature_squeeze = FeatureSqueezingLayer(bit_depth=8)

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        query = self.query(x).view(batch_size, -1, height * width).permute(0, 2, 1)
        key = self.key(x).view(batch_size, -1, height * width)
        energy = torch.bmm(query, key)
        # attention has shape (batch_size, height * width, height * width)
        attention = F.softmax(energy, dim=-1)
        value = self.value(x).view(batch_size, -1, height * width)
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, height, width)
        # learnable residual connection back to the block input
        out = self.gamma * out + x

        out = self.feature_squeeze(out)

        return out

class ResidualBlockWithTransformerAttention(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlockWithTransformerAttention, self).__init__()
        self.conv1 = nn.Sequential(
                        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
                        nn.BatchNorm2d(out_channels),
                        nn.SELU())
        self.conv2 = nn.Sequential(
                        nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                        nn.BatchNorm2d(out_channels))
        self.attention = SimplifiedSelfAttention(out_channels)  # Using simplified self-attention
        self.downsample = downsample
        
        self.selu = nn.SELU()
        self.out_channels = out_channels
        
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.attention(out) 
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        
        out = self.selu(out)
        return out


class ResNetWithAttention(nn.Module):
    def __init__(self, block, layers, num_classes=2):
        super(ResNetWithAttention, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Sequential(
                        nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
                        nn.BatchNorm2d(64),
                        nn.SELU())
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer0 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer1 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer2 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer3 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
        
    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(planes),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
    
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

model = ResNetWithAttention(ResidualBlockWithTransformerAttention, [3, 4, 6, 3], num_classes=2)

Could you point out what exactly you want to visualize?
E.g. if you want to visualize a forward activation, you could return it from the forward method of your model, or you could use forward hooks to store the (detached) activation. Once stored, you could use matplotlib to visualize the tensor, but you might need to visualize each slice separately if the tensor is not in a valid image shape.
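
To make that concrete, here is a minimal sketch of the hook-based approach, assuming a 224x224 input and matplotlib for plotting. Note that a forward hook on SimplifiedSelfAttention can only capture the block's output activation, because the softmax is a functional call, not a module; to get the (H*W x H*W) attention weights themselves you would store them on the module inside forward. The attn_map attribute, the save_activation helper, and the dummy input are my additions for illustration, not part of your code.

import torch
import matplotlib.pyplot as plt

model = ResNetWithAttention(ResidualBlockWithTransformerAttention, [3, 4, 6, 3], num_classes=2)
model.eval()

# 1) Forward hook: capture the output activation of the last attention block
activations = {}

def save_activation(name):
    def hook(module, inp, out):
        activations[name] = out.detach()
    return hook

last_attn = model.layer3[-1].attention            # last SimplifiedSelfAttention module
handle = last_attn.register_forward_hook(save_activation("layer3_attn"))

x = torch.randn(1, 3, 224, 224)                   # dummy input; replace with a real image tensor
with torch.no_grad():
    model(x)
handle.remove()

act = activations["layer3_attn"][0]               # (C, H, W)
plt.imshow(act.mean(0).cpu().numpy(), cmap="viridis")  # average over channels -> (H, W)
plt.title("Mean activation after last attention block")
plt.colorbar()
plt.show()

# 2) Attention weights: add `self.attn_map = attention.detach()` inside
#    SimplifiedSelfAttention.forward right after the softmax, then:
# attn = last_attn.attn_map[0]                    # (H*W, H*W)
# h = w = int(attn.size(-1) ** 0.5)               # assumes a square feature map
# query_idx = (h // 2) * w + (w // 2)             # attention paid by the center position
# plt.imshow(attn[query_idx].reshape(h, w).cpu().numpy(), cmap="viridis")
# plt.show()

With a 224x224 input the feature map at layer3 is 7x7, so the reshaped attention row in option 2 is a 7x7 map showing which spatial positions the chosen query position attends to; you can upsample it to the input resolution if you want to overlay it on the image.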