How to save wrong prediction results of a CNN on MNIST

I am using this code to do experiments on MNIST. How can I save the misclassified images together with their wrong predictions (e.g. a 7 predicted as 1)? Thank you!

You could adapt the test function to find all wrong predictions in each batch and save the corresponding input images, as shown here:

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF

def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

            # Store wrongly predicted images
            wrong_idx = (pred != target.view_as(pred)).nonzero()[:, 0]  # indices within this batch
            wrong_samples = data[wrong_idx]
            wrong_preds = pred[wrong_idx]
            actual_preds = target.view_as(pred)[wrong_idx]

            for i in range(len(wrong_idx)):
                sample = wrong_samples[i].cpu()
                wrong_pred = wrong_preds[i]
                actual_pred = actual_preds[i]
                # Undo normalization (MNIST mean 0.1307, std 0.3081)
                sample = sample * 0.3081
                sample = sample + 0.1307
                sample = sample * 255.
                sample = sample.byte()
                img = TF.to_pil_image(sample)
                # The original filename pattern was lost; this one is illustrative
                img.save('wrong_idx{}_pred{}_actual{}.png'.format(
                    wrong_idx[i].item(), wrong_pred.item(), actual_pred.item()))
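The "undo normalization" step can also be pulled into a small helper. The sketch below assumes the usual `Normalize((0.1307,), (0.3081,))` transform from the PyTorch MNIST example; the name `denormalize_to_uint8` is hypothetical. It rounds before casting, because `.byte()` truncates and float error can turn 255 into 254, and it clamps values that fall slightly outside [0, 255].

```python
import torch

# Assumed MNIST normalization constants (as in the PyTorch MNIST example)
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081


def denormalize_to_uint8(sample, mean=MNIST_MEAN, std=MNIST_STD):
    """Hypothetical helper: undo Normalize((mean,), (std,)) and return a uint8 image tensor."""
    img = sample.detach().cpu() * std + mean    # back to roughly [0, 1]
    img = (img * 255.0).round().clamp(0, 255)   # round instead of truncating, then clamp
    return img.to(torch.uint8)
```

Inside the loop above, `img = TF.to_pil_image(denormalize_to_uint8(wrong_samples[i]))` would replace the four manual arithmetic lines.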

Thank you! A follow-up question: what if I want to save the indices of the wrongly predicted samples across the whole test set? I want to compare the performance of two networks, i.e. find which samples misclassified by one network are correctly predicted by the other.

In that case you could write a custom Dataset, deriving from the MNIST dataset, that returns the sample index along with the data and target.
In the provided code snippet, you would then get this additional sample index for each batch and could store it for further processing.
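A minimal sketch of this idea follows. Instead of subclassing `datasets.MNIST`, it uses a generic wrapper (the names `IndexedDataset`, `collect_wrong_indices` and `compare_wrong` are hypothetical) so it works with any map-style dataset that returns `(data, target)`. Because each sample carries its own dataset index, the collected indices stay correct even with `shuffle=True`, and two networks can then be compared with plain set operations.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class IndexedDataset(Dataset):
    """Hypothetical wrapper: returns (data, target, index) for each sample."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        data, target = self.base[index]
        return data, target, index


def collect_wrong_indices(model, loader, device):
    """Return the sorted dataset indices of all misclassified samples."""
    model.eval()
    wrong = []
    with torch.no_grad():
        for data, target, idx in loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            mask = (pred != target).cpu()  # idx lives on the CPU
            wrong.extend(idx[mask].tolist())
    return sorted(wrong)


def compare_wrong(wrong_a, wrong_b):
    """Split indices into: only A wrong, only B wrong, both wrong."""
    a, b = set(wrong_a), set(wrong_b)
    return sorted(a - b), sorted(b - a), sorted(a & b)
```

For MNIST, the loader would be built as e.g. `DataLoader(IndexedDataset(datasets.MNIST('../data', train=False, transform=transform)), batch_size=1000)`; the first list returned by `compare_wrong` holds the samples network A gets wrong but network B predicts correctly.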


I am getting an error in the line for i in range(wrong_idx):

It says: TypeError: only integer scalar arrays can be converted to a scalar index

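Try using for i in range(len(wrong_idx)): instead. range() expects an integer, while wrong_idx is a tensor holding the batch indices of the wrong predictions, so you need its length.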
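The snippet below (with made-up tensors standing in for one batch) reproduces the error and shows an alternative that avoids indexing altogether: converting the tensors to Python lists and zipping them. One subtle point, hedged: if a batch happens to contain exactly one wrong prediction, `range(wrong_idx)` may not raise at all and would silently loop over the wrong range, which is another reason to prefer `len()` or `zip()`.

```python
import torch

# Made-up per-batch results: wrong_idx is 1-D, preds/targets have shape [k, 1]
wrong_idx = torch.tensor([3, 7, 12])
wrong_preds = torch.tensor([[1], [9], [4]])
actual_preds = torch.tensor([[7], [4], [9]])

# range() needs a Python int; a multi-element tensor cannot be converted to one
try:
    range(wrong_idx)
except TypeError as err:
    print("TypeError:", err)

# Iterate over plain Python values instead of indexing the tensors one by one
records = []
for idx, p, t in zip(wrong_idx.tolist(),
                     wrong_preds.view(-1).tolist(),
                     actual_preds.view(-1).tolist()):
    records.append((idx, p, t))
print(records)  # [(3, 1, 7), (7, 9, 4), (12, 4, 9)]
```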