Gradients don't flow back during guided backprop

The following is my custom CNN:

CNN(
  (seq_model): Sequential(
    (Conv2d_1): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=same)
    (BN_1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu_1): ReLU()
    (Maxpool2d_1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Conv2d_2): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=same)
    (BN_2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu_2): ReLU()
    (Maxpool2d_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Conv2d_3): Conv2d(128, 256, kernel_size=(5, 5), stride=(1, 1), padding=same)
    (BN_3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu_3): ReLU()
    (Maxpool2d_3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Conv2d_4): Conv2d(256, 512, kernel_size=(5, 5), stride=(1, 1), padding=same)
    (BN_4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu_4): ReLU()
    (Maxpool2d_4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Conv2d_5): Conv2d(512, 1024, kernel_size=(5, 5), stride=(1, 1), padding=same)
    (BN_5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu_5): ReLU()
    (Maxpool2d_5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Flatten): Flatten(start_dim=1, end_dim=-1)
    (fc1): Linear(in_features=16384, out_features=128, bias=True)
    (Dropout): Dropout(p=0, inplace=False)
    (fcReLU): ReLU()
    (head): Linear(in_features=128, out_features=10, bias=True)
  )
)

I am experimenting with guided backprop and this is my code:

import torch
import torch.nn as nn
import wandb

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Guided_backprop:
    def __init__(self, model, utils_agent):
        self.model = model
        self.utils_agent = utils_agent     # holds invTransf for un-normalizing images
        self.image_reconstruction = None   # will store R0, the gradient w.r.t. the input
        self.activation_maps = []          # stores the forward ReLU outputs f1, f2, ...
        for p in self.model.parameters():
            p.requires_grad = True
        self.model.eval()
        self.register_hooks()

    def register_hooks(self):
        def first_layer_hook_fn(module, grad_in, grad_out):
            # intended to capture the gradient w.r.t. the input image (R0)
            self.image_reconstruction = grad_in[0]

        def forward_hook_fn(module, input, output):
            # record each ReLU output on the way forward
            self.activation_maps.append(output)

        def backward_hook_fn(module, grad_in, grad_out):
            # pop the ReLU output recorded during the forward pass and
            # turn it into a binary mask: 1 where the activation was
            # positive, 0 elsewhere
            activation = self.activation_maps.pop()
            mask = (activation > 0).type_as(activation)

            # grad_out[0] holds the gradient w.r.t. this layer's output;
            # guided backprop keeps only its positive entries, gated by
            # the forward mask
            new_grad_in = mask * torch.clamp(grad_out[0], min=0.0)
            return (new_grad_in,)

        modules = list(self.model.seq_model.named_children())

        # traverse the modules and register forward & backward hooks
        # on every ReLU
        for name, module in modules:
            if isinstance(module, nn.ReLU):
                module.register_forward_hook(forward_hook_fn)
                module.register_backward_hook(backward_hook_fn)

        # register a backward hook on the first conv layer
        first_layer = modules[0][1]
        first_layer.register_backward_hook(first_layer_hook_fn)

    def visualize(self, datapoint):
        def normalize(image):
            # rescale to std 0.1 around a 0.5 mean and clip to [0, 1]
            # so the reconstruction displays as an image
            norm = (image - image.mean()) / image.std()
            norm = norm * 0.1 + 0.5
            return norm.clip(0, 1)

        input_image, _ = datapoint
        target_class = None
        # move to the device first so the tensor we differentiate stays a leaf
        input_image = input_image.unsqueeze(0).to(device).requires_grad_()
        model_output = self.model(input_image)
        self.model.zero_grad()
        pred_class = model_output.argmax().item()

        # one-hot gradient seed for the chosen (or else the predicted) class
        grad_target_map = torch.zeros(model_output.shape, dtype=torch.float, device=device)
        grad_target_map[0][target_class if target_class is not None else pred_class] = 1

        model_output.backward(gradient=grad_target_map)
        input_image = input_image.squeeze(0)
        result = self.image_reconstruction.detach()[0].permute(1, 2, 0)
        print("Img reconst", self.image_reconstruction.shape)
        result = normalize(result)
        gbp_result = wandb.Image(result.cpu().numpy(), caption='Guided BP Image')
        orig_img = wandb.Image(self.utils_agent.invTransf(input_image).cpu(), caption='Original Image')

        wandb.log({'Orig_img': orig_img, 'GBP_Result': gbp_result})
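
For context, this is roughly how I drive the class; the project name, checkpoint path, train_set, and utils_agent are placeholders from my own setup:

# rough usage sketch -- 'checkpoint.pt', train_set, and utils_agent are
# placeholders from my own setup
wandb.init(project='gbp-debug')

model = CNN().to(device)
model.load_state_dict(torch.load('checkpoint.pt', map_location=device))

gbp = Guided_backprop(model, utils_agent)
gbp.visualize(train_set[0])  # a datapoint is an (image, label) pair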

My image reconstruction is of size:

torch.Size([1, 64, 128, 128])

which matches the output of my first conv layer rather than its input, so the gradient I capture in first_layer_hook_fn never makes it back to the input image (whose size is torch.Size([1, 3, 128, 128])).
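
To narrow it down, here is a minimal probe (a toy conv mirroring Conv2d_1, hooked with the same deprecated register_backward_hook my class uses) that just prints the shape of every grad_in entry:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=5, padding='same')

def probe_hook(module, grad_in, grad_out):
    # print every grad_in entry's shape to see which one, if any,
    # corresponds to the module's input (I expect [1, 3, 128, 128])
    print('grad_in :', [None if g is None else tuple(g.shape) for g in grad_in])
    print('grad_out:', tuple(grad_out[0].shape))

conv.register_backward_hook(probe_hook)

x = torch.randn(1, 3, 128, 128, requires_grad=True)
conv(x).sum().backward()
print('x.grad  :', tuple(x.grad.shape))  # the shape I actually want back

Where am I going wrong?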