Hey everyone,
I’m encountering an issue while plotting class activation maps (CAMs) for a PyTorch neural network model. Specifically, I’m getting the following error:
RuntimeError: cannot register a hook on a tensor that doesn't require gradient.
I hope someone can help me clarify why this error occurs and how to resolve it . The plot_cam_model_list_images()
function is utilized within the training loop for supervised classification to visualize CAMs on the test set.
To replicate the error, I’ve created a test file and copied the code to Google Colab: link to Colab notebook.
Here’s a simplified version of the code:
import matplotlib.pyplot as plt
import torch
import torchcam.methods as methods
from torchcam.utils import overlay_mask
from torchvision.transforms import ToPILImage
from torchvision.transforms.functional import to_pil_image
import torchvision.models as models
import torch.nn as nn
# Define ResNet18 model
class ResNet18(torch.nn.Module):
def __init__(self, input_channels=1, output_channels=2):
super().__init__()
self.resnet = models.resnet18(pretrained=False)
self.resnet.conv1 = nn.Conv2d(input_channels, 64, kernel_size=(
7, 7), stride=(2, 2), padding=(3, 3), bias=False)
self.resnet.fc = nn.Linear(
in_features=512, out_features=output_channels, bias=True)
def forward(self, x):
return self.resnet(x)
def plot_cam_model_list_images(model, list_images):
"""
Plot images with their corresponding CAM (Class Activation Mapping) overlays for a given model and list of images.
Plots as a 2*number of images subplot with the first row containing the images and the second row containing the CAM overlays.
Args:
model (torch.nn.Module): The neural network model.
list_images (list): List of images to visualize.
Returns:
None
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_shape = list_images[0].shape
for param in model.parameters():
param.requires_grad_(True)
# Choose CAM method
cam = methods.GradCAMpp(model=model, input_shape=input_shape)
fig, ax = plt.subplots(nrows=2, ncols=len(list_images))
ax = ax.flatten()
plot_imgs_list = []
plot_cams_list = []
for i, image in enumerate(list_images):
image.requires_grad = True
image = image.to(device)
# Forward pass
out = model(image.unsqueeze(0))
# Create a CAM extractor
activation_map = cam(out.squeeze(0).argmax().item(), out)
input_tensor_rgb = torch.cat([image, image, image], dim=0)
# Move tensors to CPU
input_tensor_rgb = input_tensor_rgb.cpu()
# Overlay CAM on the input image
cam_plot = overlay_mask(to_pil_image(input_tensor_rgb), to_pil_image(activation_map[0].squeeze(0), mode='F'),
alpha=0.5)
plot_imgs_list.append(ToPILImage()(image))
plot_cams_list.append(cam_plot)
reordered_list_img_plot = plot_imgs_list + plot_cams_list
for i, img in enumerate(reordered_list_img_plot):
ax[i].imshow(img)
ax[i].axis(False)
plt.axis('off')
plt.show()
# Usage example
model = ResNet18(1, 2)
for param in model.parameters():
param.requires_grad = True
model.eval()
list_images1 = [torch.rand((1, 512, 512)) for i in range(4)]
plot_cam_model_list_images(model, list_images1)
list_images2 = [torch.rand((1, 512, 512)) for i in range(4)]
plot_cam_model_list_images(model, list_images2)
Upon running the plot_cam_model_list_images()
function for the second time, I encounter the aforementioned RuntimeError
. The traceback leads to the definition of CAM with the error being triggered during the registration of a hook.
Full traceback:
Traceback (most recent call last):
File "c:\Users\PythonWorkspace\VSCProjects\ml-incubator\source\deep_learning\main_scripts\test_CAM_script.py", line 93, in <module>
plot_cam_model_list_images(
File "c:\Users\PythonWorkspace\VSCProjects\ml-incubator\source\deep_learning\main_scripts\test_CAM_script.py", line 44, in plot_cam_model_list_images
cam = methods.GradCAMpp(model=model, input_shape=input_shape)
File "C:\python\python39enviroments\torch\lib\site-packages\torchcam\methods\gradient.py", line 33, in __init__
super().__init__(model, target_layer, input_shape, **kwargs)
File "C:\python\python39enviroments\torch\lib\site-packages\torchcam\methods\core.py", line 54, in __init__
target_name = locate_candidate_layer(model, input_shape)
File "C:\python\python39enviroments\torch\lib\site-packages\torchcam\methods\_utils.py", line 42, in locate_candidate_layer
_ = mod(torch.zeros((1, *input_shape), device=next(mod.parameters()).data.device))
File "C:\python\python39enviroments\torch\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\python\python39enviroments\torch\lib\site-packages\torch\nn\modules\module.py", line 1561, in _call_impl
result = forward_call(*args, **kwargs)
File "c:\Users\\PythonWorkspace\VSCProjects\ml-incubator\source\deep_learning\main_scripts\test_CAM_script.py", line 22, in forward
return self.resnet(x)
File "C:\python\python39enviroments\torch\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\python\python39enviroments\torch\lib\site-packages\torch\nn\modules\module.py", line 1561, in _call_impl
result = forward_call(*args, **kwargs)
File "C:\python\python39enviroments\torch\lib\site-packages\torchvision\models\resnet.py", line 285, in forward
return self._forward_impl(x)
File "C:\python\python39enviroments\torch\lib\site-packages\torchvision\models\resnet.py", line 276, in _forward_impl
x = self.layer4(x)
File "C:\python\python39enviroments\torch\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\python\python39enviroments\torch\lib\site-packages\torch\nn\modules\module.py", line 1574, in _call_impl
hook_result = hook(self, args, result)
File "C:\python\python39enviroments\torch\lib\site-packages\torchcam\methods\gradient.py", line 49, in _hook_g
self.hook_handles.append(output.register_hook(partial(self._store_grad, idx=idx)))
File "C:\python\python39enviroments\torch\lib\site-packages\torch\_tensor.py", line 562, in register_hook
raise RuntimeError(
RuntimeError: cannot register a hook on a tensor that doesn't require gradient
This issue appears similar to this discussion, but in my case, the model weights are not frozen.My model is a Pytorch ResNet18 changed for a given input and output channel dimensions (grayscale input and 2 classes at the end). In the real life main train function I would load the best checkpoint of the model before CAMs plotting.
Any assistance on resolving this issue would be greatly appreciated.
Thank you!
Best regards,