I want to compare the predicted label with true label and if the predicted label was not mach, plot GradCAM image and original image in beside each other. I developed a pice of code but I received an error.
It is my code:
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
print(labels)
print(predicted)
for i in range(len(predicted)):
if predicted[i] != labels[i]:
resize_size =(227,227)
directory = ‘/content/drive/MyDrive/Herbarium Data-2019/Data/sample/val’
files = os.listdir(directory)
for foldername in files:
img_files = os.listdir(‘/content/drive/MyDrive/Herbarium Data-2019/Data/sample/val/’+foldername)
for filename in img_files:
input_image = Image.open(‘/content/drive/MyDrive/Herbarium Data-2019/Data/sample/val/’+foldername +‘/’ +filename)
preprocess = transforms.Compose(
[
resize_to,
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
), # IMPORTANT: normalize for pre-trained models
]
)
input_tensor = preprocess(input_image)
print(“Input Tensor Shape:”, input_tensor.shape)
input_batch = input_tensor.unsqueeze(0)
targets = [ClassifierOutputTarget(6)]
grayscale_cam = cam(input_batch, targets=None,aug_smooth=True)
grayscale_cam = grayscale_cam[0, :]
img=np.array(input_image.resize(resize_size),np.float32)
img = img.reshape(img.shape[1],img.shape[0],img.shape[2])
#print("img shape",img.shape,img.max())
img = img/255
visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True)
#cam_images = [show_cam_on_image(img, grayscale, use_rgb=True) for img, grayscale in zip(input_image, grayscale_cam)]
visualization = Image.fromarray(visualization)
out_file_name ="/content/drive/MyDrive/Herbarium Data-2019/test222/" +modelname+ "_"+gradvariantname+ "_" + foldername_to_class[os.path.basename(foldername)] + "_" +os.path.basename(filename)
visualization.save(out_file_name)
#print("Visualization saved- now trying to show (GUI mode)")
im = Image.open(out_file_name)
im.show()
image = np.transpose(images.cpu().squeeze().numpy(), (1, 2, 0))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title(f'True label: {labels}, Pred label: {predicted}')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(cam, cmap='jet')
plt.axis('off')
plt.show()
And it is an error: