I am using this code to do experiments on MNIST. I am wondering how to save the images that receive wrong predictions, together with the wrongly predicted labels (e.g. a 7 predicted as a 1). Thank you!
You could adapt the test code to find all wrong predictions and store the corresponding images, as shown here:
```python
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

            # Store wrongly predicted images
            wrong_idx = (pred != target.view_as(pred)).nonzero()[:, 0]  # row indices within this batch
            wrong_samples = data[wrong_idx]
            wrong_preds = pred[wrong_idx]
            actual_preds = target.view_as(pred)[wrong_idx]

            for i in range(len(wrong_idx)):  # range() needs an int, not the index tensor
                sample = wrong_samples[i]
                wrong_pred = wrong_preds[i]
                actual_pred = actual_preds[i]
                # Undo normalization (std=0.3081, mean=0.1307)
                sample = sample * 0.3081
                sample = sample + 0.1307
                sample = sample * 255.
                sample = sample.byte()
                img = TF.to_pil_image(sample.cpu())
                # Include the batch index so files from different batches don't overwrite each other
                img.save('wrong_batch{}_idx{}_pred{}_actual{}.png'.format(
                    batch_idx, wrong_idx[i].item(), wrong_pred.item(), actual_pred.item()))
```
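The "undo normalization" step simply inverts `transforms.Normalize`, which computes `(x - mean) / std`, so multiplying by `std` and adding `mean` recovers the original `[0, 1]` pixel values. Here is a minimal sketch checking that round trip, assuming the loader used `Normalize((0.1307,), (0.3081,))` as in the PyTorch MNIST example (adjust the constants to your own transform). Note that it rounds and clamps before `.byte()`, a small deviation from the snippet above, which truncates.

```python
import torch

# Assumed constants: the mean/std passed to transforms.Normalize((0.1307,), (0.3081,))
# in the PyTorch MNIST example; use whatever your own loader used.
MEAN, STD = 0.1307, 0.3081


def normalize(x):
    # What transforms.Normalize computes per channel: (x - mean) / std
    return (x - MEAN) / STD


def denormalize(x):
    # Inverse of normalize: x * std + mean
    return x * STD + MEAN


def to_uint8(x):
    # Map [0, 1] floats to [0, 255] bytes; round() and clamp() avoid off-by-one
    # truncation and out-of-range values caused by float error before .byte()
    return (x * 255.).round().clamp(0, 255).byte()


pixels = torch.tensor([[0.0, 0.25, 1.0]])  # toy "image" with values in [0, 1]
restored = denormalize(normalize(pixels))
u8 = to_uint8(restored)
print(torch.allclose(restored, pixels))  # True
```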
Thank you! Can I ask a further question: what if I want to save the indices of the wrongly predicted samples across the whole test loader? I actually want to compare the performance of two networks, i.e. find which samples predicted wrongly by one network are predicted correctly by the other.
In that case you could write a custom `Dataset`, deriving from the `MNIST` dataset, that returns the sample index together with the data and target.
In the provided code snippet, you would then get the additional sample index and could store it for further processing.
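A minimal sketch of that idea follows. `IndexedDataset` and `collect_wrong_indices` are hypothetical names, and instead of subclassing `MNIST` this version wraps any map-style dataset, so it works the same with `datasets.MNIST(..., train=False, ...)`. Because the index travels with each sample, the result stays correct even with `shuffle=True`. Toy one-hot data and two fixed `nn.Linear` layers stand in for the MNIST test set and the two networks. The test loop then unpacks three values per batch, and Python set operations answer the comparison question directly.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset


class IndexedDataset(Dataset):
    """Hypothetical wrapper: returns (data, target, index) for any map-style dataset."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        data, target = self.dataset[index]
        return data, target, index  # index = position of the sample in the dataset


def collect_wrong_indices(model, loader, device):
    """Hypothetical helper: set of dataset indices the model misclassifies."""
    model.eval()
    wrong = set()
    with torch.no_grad():
        for data, target, idx in loader:  # default collate_fn batches idx into a tensor
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            mask = (pred != target).cpu()
            wrong.update(idx[mask].tolist())
    return wrong


# Toy stand-ins for the MNIST test set and the two networks (assumptions for illustration).
# With real data you would wrap e.g. datasets.MNIST(..., train=False, transform=...) instead.
x = torch.eye(3)[[0, 1, 2, 0]]  # four one-hot "images"
y = torch.tensor([0, 2, 1, 1])  # their labels
loader = DataLoader(IndexedDataset(TensorDataset(x, y)), batch_size=2, shuffle=True)

model_a = nn.Linear(3, 3, bias=False)
model_b = nn.Linear(3, 3, bias=False)
with torch.no_grad():
    model_a.weight.copy_(torch.eye(3))  # predicts the position of the 1
    model_b.weight.copy_(torch.tensor([[0., 0., 1.],
                                       [1., 0., 0.],
                                       [0., 1., 0.]]))  # predicts a shifted class

device = torch.device('cpu')
wrong_a = collect_wrong_indices(model_a, loader, device)
wrong_b = collect_wrong_indices(model_b, loader, device)

only_a_wrong = wrong_a - wrong_b  # wrong for model_a, correct for model_b
only_b_wrong = wrong_b - wrong_a  # wrong for model_b, correct for model_a
both_wrong = wrong_a & wrong_b
print(sorted(only_a_wrong), sorted(only_b_wrong), sorted(both_wrong))  # [1, 3] [0] [2]
```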
I am getting an error in the line `for i in range(wrong_idx):`
It says: `TypeError: only integer scalar arrays can be converted to a scalar index`
Try using `for i in range(len(wrong_idx)):` instead, since `range()` expects an integer count, not the index tensor itself.
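To illustrate the fix: `wrong_idx` is a 1-D tensor holding the batch positions of the wrong predictions, so `range()` needs its length rather than the tensor. A sketch with small made-up tensors standing in for `pred` and `target` from the test loop also shows an alternative that avoids manual indexing by zipping the tensors after converting them to Python lists.

```python
import torch

# Toy stand-ins for the tensors built inside the test loop (made-up values)
pred = torch.tensor([[7], [1], [3], [1]])  # shape (N, 1), as with keepdim=True
target = torch.tensor([7, 7, 3, 4])

wrong_idx = (pred != target.view_as(pred)).nonzero()[:, 0]  # tensor([1, 3])

# range() expects a Python int, so pass the number of wrong samples
for i in range(len(wrong_idx)):
    print(i, wrong_idx[i].item())

# Alternative: iterate over the values directly with zip, no manual indexing
wrong_preds = pred[wrong_idx].flatten().tolist()
actual = target.view_as(pred)[wrong_idx].flatten().tolist()
pairs = list(zip(wrong_idx.tolist(), wrong_preds, actual))
print(pairs)  # [(1, 1, 7), (3, 1, 4)]
```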