In this snippet, which tests the model and visualizes its output, I also need to print the name of each input image and of its mask (ground truth), so that I know which test image and ground truth each probability output belongs to.
# Test the model on the test set and plot each prediction next to its ground
# truth and input image.
#
# Identifying which files an output belongs to:
# make the test Dataset's __getitem__ also return the file names after the
# tensors, for example
#     return image, mask, os.path.basename(image_path), os.path.basename(mask_path)
# NOTE(review): the Dataset is not shown here; adapt that line to its fields.
# Any such extra items are picked up below. A Dataset that still returns only
# (image, mask) keeps working and is identified by its batch index instead.
model.eval()
total = 0
test_loss = 0
correct = 0
count = 0
# iterate through test dataset
for ii, data in enumerate(test_loader):
    # Everything after (image, mask) is treated as identifying names.
    t_image, mask, *names = data
    # DataLoader's default collate batches strings into a list; join them so
    # a batch prints as "a.png, b.png" rather than "['a.png', 'b.png']".
    names = [
        ", ".join(map(str, n)) if isinstance(n, (list, tuple)) else str(n)
        for n in names
    ]
    if len(names) >= 2:
        source = f"image: {names[0]} | mask: {names[1]}"
    elif names:
        source = f"sample: {names[0]}"
    else:
        source = f"batch {ii}"

    # t_image shape observed by the author: torch.Size([1, 1, 240, 320])
    t_image, mask = t_image.to(device), mask.to(device)
    with torch.no_grad():
        outputs = model(t_image)
        # outputs shape observed by the author: torch.Size([1, 2, 240, 320])
        # Accumulates the mean per-batch loss over the whole loader.
        test_loss += criterion(outputs, mask).item() / len(test_loader)
    # NOTE(review): exp() gives per-class probabilities only if the model ends
    # in log_softmax; confirm against the model definition.
    probs = torch.exp(outputs)
    # The outputs are scores for the 2 classes; the class with the highest
    # score along dim 1 is the predicted label for each pixel.
    predicted = outputs.argmax(dim=1)
    total += mask.nelement()
    correct += predicted.eq(mask).sum().item()
    accuracy = 100 * correct / total
    count += 1
    print(count, source, "Test Loss: {:.3f}".format(test_loss), "Test Accuracy: %d %%" % (accuracy))

    # Visualise this batch. The [0, ...] indexing and squeeze() calls assume
    # batch size 1 (NOTE(review): confirm test_loader uses batch_size=1).
    plt.figure()
    plt.suptitle(source)
    plt.subplot(1, 5, 1)
    # One grid tile per class map. Move to CPU first: matplotlib cannot
    # display CUDA tensors.
    prob_imgs = make_grid(probs.permute(1, 0, 2, 3)).detach().cpu()
    plt.imshow(prob_imgs.permute(1, 2, 0))
    plt.subplot(1, 5, 2)
    plt.imshow(probs[0, 0, :, :].detach().cpu(), cmap='gray')
    plt.title('predicted for class 0')
    plt.subplot(1, 5, 3)
    plt.imshow(probs[0, 1, :, :].detach().cpu(), cmap='gray')
    plt.title('predicted for class 1')
    plt.subplot(1, 5, 4)
    plt.imshow(mask.detach().cpu().squeeze(), cmap='gray')
    plt.title('ground truth')
    plt.subplot(1, 5, 5)
    plt.imshow(t_image.detach().cpu().squeeze(), cmap='gray')
    plt.title('input image')
    plt.show()