Pytorch class activation maps issue

Hey everyone,

I’m encountering an issue while plotting class activation maps (CAMs) for a PyTorch neural network model. Specifically, I’m getting the following error:

RuntimeError: cannot register a hook on a tensor that doesn't require gradient.

I hope someone can help me clarify why this error occurs and how to resolve it :blush:. The plot_cam_model_list_images() function is utilized within the training loop for supervised classification to visualize CAMs on the test set.

To replicate the error, I’ve created a test file and copied the code to Google Colab: link to Colab notebook.

Here’s a simplified version of the code:

import matplotlib.pyplot as plt
import torch
import torchcam.methods as methods
from torchcam.utils import overlay_mask
from torchvision.transforms import ToPILImage
from torchvision.transforms.functional import to_pil_image
import torchvision.models as models
import torch.nn as nn

# Define ResNet18 model
class ResNet18(torch.nn.Module):
    def __init__(self, input_channels=1, output_channels=2):
        super().__init__()
        self.resnet = models.resnet18(pretrained=False)
        self.resnet.conv1 = nn.Conv2d(input_channels, 64, kernel_size=(
            7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.resnet.fc = nn.Linear(
            in_features=512, out_features=output_channels, bias=True)

    def forward(self, x):
        return self.resnet(x)

def plot_cam_model_list_images(model, list_images):
    """
    Plot images with their corresponding CAM (Class Activation Mapping) overlays for a given model and list of images.
    Plots as a 2*number of images subplot with the first row containing the images and the second row containing the CAM overlays.
    Args:
        model (torch.nn.Module): The neural network model.
        list_images (list): List of images to visualize.
    Returns:
        None
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input_shape = list_images[0].shape

    for param in model.parameters():
        param.requires_grad_(True)

    # Choose CAM method
    cam = methods.GradCAMpp(model=model, input_shape=input_shape)

    fig, ax = plt.subplots(nrows=2, ncols=len(list_images))
    ax = ax.flatten()
    plot_imgs_list = []
    plot_cams_list = []

    for i, image in enumerate(list_images):
        image.requires_grad = True
        image = image.to(device)

        # Forward pass
        out = model(image.unsqueeze(0))

        # Create a CAM extractor
        activation_map = cam(out.squeeze(0).argmax().item(), out)
        input_tensor_rgb = torch.cat([image, image, image], dim=0)

        # Move tensors to CPU
        input_tensor_rgb = input_tensor_rgb.cpu()

        # Overlay CAM on the input image
        cam_plot = overlay_mask(to_pil_image(input_tensor_rgb), to_pil_image(activation_map[0].squeeze(0), mode='F'),
                                alpha=0.5)

        plot_imgs_list.append(ToPILImage()(image))
        plot_cams_list.append(cam_plot)

    reordered_list_img_plot = plot_imgs_list + plot_cams_list
    for i, img in enumerate(reordered_list_img_plot):
        ax[i].imshow(img)
        ax[i].axis(False)
    plt.axis('off')
    plt.show()

# Usage example
model = ResNet18(1, 2)

for param in model.parameters():
    param.requires_grad = True

model.eval()

list_images1 = [torch.rand((1, 512, 512)) for i in range(4)]
plot_cam_model_list_images(model, list_images1)

list_images2 = [torch.rand((1, 512, 512)) for i in range(4)]
plot_cam_model_list_images(model, list_images2)

Upon running the plot_cam_model_list_images() function for the second time, I encounter the aforementioned RuntimeError. The traceback leads to the definition of CAM with the error being triggered during the registration of a hook.
Full traceback:

Traceback (most recent call last):
  File "c:\Users\PythonWorkspace\VSCProjects\ml-incubator\source\deep_learning\main_scripts\test_CAM_script.py", line 93, in <module>
    plot_cam_model_list_images(
  File "c:\Users\PythonWorkspace\VSCProjects\ml-incubator\source\deep_learning\main_scripts\test_CAM_script.py", line 44, in plot_cam_model_list_images
    cam = methods.GradCAMpp(model=model, input_shape=input_shape)
  File "C:\python\python39enviroments\torch\lib\site-packages\torchcam\methods\gradient.py", line 33, in __init__
    super().__init__(model, target_layer, input_shape, **kwargs)
  File "C:\python\python39enviroments\torch\lib\site-packages\torchcam\methods\core.py", line 54, in __init__
    target_name = locate_candidate_layer(model, input_shape)
  File "C:\python\python39enviroments\torch\lib\site-packages\torchcam\methods\_utils.py", line 42, in locate_candidate_layer
    _ = mod(torch.zeros((1, *input_shape), device=next(mod.parameters()).data.device))
  File "C:\python\python39enviroments\torch\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\python\python39enviroments\torch\lib\site-packages\torch\nn\modules\module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "c:\Users\\PythonWorkspace\VSCProjects\ml-incubator\source\deep_learning\main_scripts\test_CAM_script.py", line 22, in forward
    return self.resnet(x)
  File "C:\python\python39enviroments\torch\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\python\python39enviroments\torch\lib\site-packages\torch\nn\modules\module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "C:\python\python39enviroments\torch\lib\site-packages\torchvision\models\resnet.py", line 285, in forward
    return self._forward_impl(x)
  File "C:\python\python39enviroments\torch\lib\site-packages\torchvision\models\resnet.py", line 276, in _forward_impl
    x = self.layer4(x)
  File "C:\python\python39enviroments\torch\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\python\python39enviroments\torch\lib\site-packages\torch\nn\modules\module.py", line 1574, in _call_impl
    hook_result = hook(self, args, result)
  File "C:\python\python39enviroments\torch\lib\site-packages\torchcam\methods\gradient.py", line 49, in _hook_g
    self.hook_handles.append(output.register_hook(partial(self._store_grad, idx=idx)))
  File "C:\python\python39enviroments\torch\lib\site-packages\torch\_tensor.py", line 562, in register_hook
    raise RuntimeError(
RuntimeError: cannot register a hook on a tensor that doesn't require gradient

This issue appears similar to this discussion, but in my case, the model weights are not frozen.My model is a Pytorch ResNet18 changed for a given input and output channel dimensions (grayscale input and 2 classes at the end). In the real life main train function I would load the best checkpoint of the model before CAMs plotting.

Any assistance on resolving this issue would be greatly appreciated.

Thank you!

Best regards,