Is there a better way to find wrong test results?

I am just starting with PyTorch and following the MNIST CNN example code I found. The test method looks like this:

import torch
import torch.nn.functional as F

def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
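The loop above only accumulates totals. Turning them into an average loss and an accuracy is a division by the dataset size. The sketch below runs the same loop end to end, assuming a hypothetical stand-in model (one linear layer plus `LogSoftmax`, since `F.nll_loss` expects log-probabilities) and random MNIST-shaped data in place of the real network and test loader:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for the MNIST CNN: flatten a 1x28x28 image and
# map it to 10 log-probabilities.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))
device = torch.device("cpu")

# Random MNIST-shaped data in place of the real test set.
images = torch.randn(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
test_loader = DataLoader(TensorDataset(images, labels), batch_size=32)

model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        test_loss += F.nll_loss(output, target, reduction='sum').item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

n = len(test_loader.dataset)
avg_loss = test_loss / n          # mean loss per sample
accuracy = 100.0 * correct / n    # percentage of correct predictions
print(f"Average loss: {avg_loss:.4f}, Accuracy: {correct}/{n} ({accuracy:.1f}%)")
```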

I wanted to display the wrongly classified images. After some stumbling through the documentation, I arrived at this solution:

            z = torch.zeros_like(pred)
            o = torch.ones_like(pred)
            ne = pred.ne(target.view_as(pred))  # True where the prediction is wrong
            w = torch.where(ne, o, z)           # 1 for wrong predictions, 0 otherwise
            wrong = torch.nonzero(w)[:, 0]      # row indices of the wrong predictions
            wrongdata = torch.index_select(data, 0, wrong)
            plt.figure(figsize=(15,15))
            plt.subplot(1,2,1)
            plt.axis("off")
            plt.title("Real Images")
            deviceData = wrongdata.to(device)[:64]

            plt.imshow(np.transpose(vutils.make_grid(deviceData, padding=5, normalize=True).cpu(),(1,2,0)))
            plt.show()

I think gathering all that data could be done more elegantly or simply, but I am not sure how. Could someone help me out?
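The `torch.where` step in the question is not strictly needed: `pred.ne(...)` already returns a boolean tensor, and PyTorch accepts a boolean mask directly for indexing. The sketch below uses a made-up batch (shapes assumed to match MNIST, with `pred` kept as `[batch, 1]` as in the question) and checks that the index-based selection and plain mask indexing pick the same images:

```python
import torch

torch.manual_seed(0)

# Made-up batch: 8 MNIST-shaped images, predictions as [batch, 1]
# (keepdim=True, as in the question) and targets as [batch].
data = torch.randn(8, 1, 28, 28)
pred = torch.tensor([[3], [1], [4], [1], [5], [9], [2], [6]])
target = torch.tensor([3, 1, 4, 0, 5, 9, 2, 7])

# Approach from the question: 0/1 tensor via torch.where, then row indices.
ne = pred.ne(target.view_as(pred))                     # bool, shape [8, 1]
w = torch.where(ne, torch.ones_like(pred), torch.zeros_like(pred))
wrong = torch.nonzero(w)[:, 0]                         # indices of wrong rows
by_index = torch.index_select(data, 0, wrong)

# Same result with a boolean mask: flatten the mask to [batch] and index.
by_mask = data[ne.view(-1)]

print(wrong.tolist())                    # [3, 7]
print(torch.equal(by_index, by_mask))    # True
```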

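Below is a quicker way to get this done, using boolean mask indexing. I have used a binary classification problem as the example, but the same line works unchanged for a multi-class problem, because argmax returns the predicted class index whatever the number of classes.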

import numpy as np
import torch

data = torch.randn(64, 224, 224)
target = torch.from_numpy(np.random.randint(0, 2, 64))  # binary labels; randint's upper bound is exclusive
predicted = torch.from_numpy(np.random.randn(64, 2))    # scores for the 2 classes
pred = predicted.argmax(dim=1)  # keepdim is False, so pred has shape [64], same as target
# Boolean mask indexing selects the incorrectly classified samples in one line
incorrect_data = data[pred != target, :, :]
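Applied to the MNIST test loop, the same mask can collect the misclassified images from every batch, together with the predicted and true labels. This is a sketch under stated assumptions: the model and data below are hypothetical stand-ins (an untrained linear layer and random MNIST-shaped tensors), and `argmax` is called without `keepdim` so that `pred` and `target` have the same shape:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins: random MNIST-shaped data and an untrained linear model.
torch.manual_seed(0)
images = torch.randn(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
test_loader = DataLoader(TensorDataset(images, labels), batch_size=32)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))
device = torch.device("cpu")

wrong_images, wrong_preds, wrong_targets = [], [], []
model.eval()
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        pred = model(data).argmax(dim=1)     # keepdim=False -> shape [batch]
        mask = pred != target                # bool mask, works for any number of classes
        wrong_images.append(data[mask].cpu())
        wrong_preds.append(pred[mask].cpu())
        wrong_targets.append(target[mask].cpu())

# Concatenate the per-batch pieces (a batch with no mistakes adds an empty tensor).
wrong_images = torch.cat(wrong_images)
wrong_preds = torch.cat(wrong_preds)
wrong_targets = torch.cat(wrong_targets)
print(wrong_images.shape, wrong_preds[:10].tolist(), wrong_targets[:10].tolist())
```

The collected `wrong_images[:64]` can then be passed to `vutils.make_grid` exactly as in the question, and `wrong_preds` / `wrong_targets` give the labels for captions. Collecting first and plotting once after the loop avoids opening a new figure for every batch.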