Hi everyone,
I am implementing a CNN to distinguish between two types of images, and I would like to use the pytorch-grad-cam library (jacobgil/pytorch-grad-cam on GitHub: "Advanced AI Explainability for computer vision. Support for CNNs, Vision Transformers, Classification, Object detection, Segmentation, Image similarity and more.") to explain the network's predictions.
The network I used is a pre-trained VGG followed by two fully connected layers.
I need to tell the grad-cam function the target layer, but I can’t figure out the right way to do it.
For VGG, the README suggests `target_layer = model.features[-1]`. However, because I wrapped everything in `nn.Sequential`, my model has no `features` attribute, so that line doesn't work.
Do you know how to do it in this case?
I attach my code below.
Thanks a lot for your help.
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import pandas as pd
import cv2
from torchvision import models
import numpy as np
from torch.autograd import Variable
from PIL import Image
from PIL import ImageFile
# Allow PIL to open image files that are truncated on disk instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# ImageNet-pretrained VGG16; every top-level child except the last one
# (the classifier head) is kept as the feature-extraction backbone.
vgg16 = models.vgg16(pretrained=True)
mod = nn.Sequential(
    *list(vgg16.children())[:-1]
)
class Net(nn.Module):
    """Classification head placed after the VGG16 backbone.

    Flattens the backbone output per sample, applies one hidden fully
    connected layer with ReLU, and returns raw logits (no softmax).

    Args:
        in_features: Number of values per sample after flattening. The
            default 512 * 7 * 7 matches the size the original code expected
            from the backbone output.
        hidden: Width of the hidden fully connected layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=512 * 7 * 7, hidden=32, num_classes=2):
        # Must be ``__init__`` (not ``init``); otherwise nn.Module never runs
        # this constructor and fc1/fc2 do not exist when forward() is called.
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        """Map a batch of backbone outputs ``x`` to logits of shape (N, num_classes)."""
        # Collapse every dimension except the batch one; fc1 raises if the
        # flattened size does not equal ``in_features``.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# Full network: backbone followed by the classification head.
model = nn.Sequential(mod, Net())

# Total number of scalar weights that will receive gradients.
params = sum(
    p.numel() for p in model.parameters() if p.requires_grad
)
print(params)
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
# pytorch-grad-cam expects ``target_layers`` to be a *list* of modules.
# model = Sequential(mod, Net()), so model[0] is ``mod``; mod was built from
# vgg16.children() minus the classifier, so model[0][0] is the first VGG16
# child (its conv ``features`` stack in torchvision). Taking [-1] is the
# equivalent of the README's ``model.features[-1]`` for VGG.
target_layers = [model[0][0][-1]]

# ``use_cuda`` is not passed: older pytorch-grad-cam releases defaulted it to
# False, and newer releases removed the argument (the device follows the model).
cam = HiResCAM(model=model, target_layers=target_layers)

# No ``targets`` given, so the CAM explains the highest-scoring class.
# NOTE(review): input_tensor is defined outside this snippet; assumed to be a
# single preprocessed (C, H, W) image tensor, hence the unsqueeze to a batch of 1.
grayscale_cam = cam(input_tensor=input_tensor.unsqueeze(0))
# Keep the map for the first (only) image in the batch.
grayscale_cam = grayscale_cam[0, :]

# NOTE(review): show_cam_on_image presumably needs test_image as a float RGB
# array scaled to [0, 1] with the same height/width as the CAM; confirm.
visualization = show_cam_on_image(test_image, grayscale_cam)
imgplot = plt.imshow(visualization)
plt.show()