Visualising neural nets: a simple deconvolution implementation

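I’m trying to implement part of the idea in the paper: projecting the activations of a layer back to pixel space by reversing the network (max-unpooling, ReLU, then a transposed convolution with the learned filters).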

The problem with this approach is that I have to implement each deconvolution by hand, and I don’t see how to automate it further.

When I run the deconvolution, the signal does project back to image space, but the resulting image only looks pixelated.

This is the network I am currently using:

import torch
from torch import nn


class Backbone(nn.Module):
    """Five-stage convolutional backbone with a deconvnet-style probe.

    ``forward`` maps a 3-channel image batch to a 256-channel feature map.
    ``deconv1`` projects the activations of the first pooling stage
    (``maxpool1``) back to input space via max-unpooling, ReLU and a
    transposed convolution with ``conv1``'s weights.
    """

    def __init__(self):
        super().__init__()
        # kernel_size=2 with padding=2 grows each spatial dim by 3
        # (H + 2*2 - 2 + 1 = H + 3).
        self.conv1 = nn.Conv2d(3, 16, kernel_size=2, stride=1, padding=2)
        self.batch1 = nn.BatchNorm2d(16)
        # return_indices=True so the pooling "switches" can be reused by
        # MaxUnpool2d in deconv1.
        self.maxpool1 = nn.MaxPool2d(kernel_size=(3, 3), return_indices=True)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=2, stride=1, padding=2)
        self.batch2 = nn.BatchNorm2d(32)

        self.conv3 = nn.Conv2d(32, 64, kernel_size=2, stride=1, padding=2)
        self.batch3 = nn.BatchNorm2d(64)

        self.conv4 = nn.Conv2d(64, 128, kernel_size=2, stride=1, padding=2)
        self.batch4 = nn.BatchNorm2d(128)
        self.maxpool = nn.MaxPool2d(kernel_size=(3, 3))

        self.conv5 = nn.Conv2d(128, 256, kernel_size=2, stride=1, padding=2)
        self.batch5 = nn.BatchNorm2d(256)
        self.relu = nn.ReLU()

        # Inverse of maxpool1 (same kernel; stride defaults to kernel_size
        # in both MaxPool2d and MaxUnpool2d).
        self.maxupool = nn.MaxUnpool2d(kernel_size=(3, 3))

    def deconv1(
        self,
        image_tensor: torch.Tensor,
    ) -> torch.Tensor:
        """Project ``maxpool1``'s activations back to the input's spatial size.

        Runs a full forward pass to capture ``maxpool1``'s output, indices
        and input size, then applies unpool -> ReLU -> transposed conv with
        ``conv1``'s weights (bias is not used).

        Args:
            image_tensor: input batch accepted by ``forward``; presumably
                (N, 3, H, W) with values in [0, 1] given the x255 scaling
                below -- TODO confirm.

        Returns:
            Tensor with ``conv1``'s input channel count and ``image_tensor``'s
            spatial size, scaled by 255 and clamped to [0, 255]. It does not
            require grad.

        Note:
            If the module is in training mode, the internal forward pass
            updates the BatchNorm running statistics; call ``.eval()`` first
            to avoid that side effect.
        """
        captured = {}

        def capture(module, inputs, output):
            # maxpool1 has return_indices=True, so output is (pooled, indices).
            pooled, pool_indices = output
            captured["pooled"] = pooled.detach()
            captured["indices"] = pool_indices.detach()
            captured["size"] = inputs[0].size()

        # Always remove the hook: otherwise every call would stack another
        # hook that keeps firing (and retaining tensors) on later forwards.
        handle = self.maxpool1.register_forward_hook(capture)
        try:
            # No autograd graph is needed for a visualisation pass.
            with torch.no_grad():
                self.forward(image_tensor)
        finally:
            handle.remove()

        with torch.no_grad():
            x = self.maxupool(
                captured["pooled"],
                captured["indices"],
                output_size=captured["size"],
            )
            x = self.relu(x)
            # Take stride/padding/dilation from conv1 itself so the inverse
            # stays in sync if conv1's configuration changes.
            x = nn.functional.conv_transpose2d(
                x,
                self.conv1.weight,
                stride=self.conv1.stride,
                padding=self.conv1.padding,
                dilation=self.conv1.dilation,
            )
            return x.mul(255.0).clamp(0.0, 255.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run conv -> BatchNorm -> ReLU over five stages, pooling after 1 and 3."""
        x = self.conv1(x)
        x = self.batch1(x)
        x = self.relu(x)
        # Indices are only needed by deconv1 (captured via a hook there).
        x, _ = self.maxpool1(x)
        x = self.conv2(x)
        x = self.batch2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.batch3(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv4(x)
        x = self.batch4(x)
        x = self.relu(x)
        x = self.conv5(x)
        x = self.batch5(x)
        x = self.relu(x)
        return x

Is there a better way to frame this deconvolution problem so that it becomes more automatic?