I am just starting with PyTorch and am following an MNIST CNN example I found. The test method looks like this:
```python
import torch
import torch.nn.functional as F

def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
```
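To see what `pred` and the `eq` comparison hold, here is a toy illustration. The hard-coded log-probabilities stand in for `model(data)`; the values and the batch size of 4 are made up, not real MNIST output.

```python
import torch

# Made-up log-probabilities for a batch of 4 samples over 3 classes.
output = torch.tensor([[-0.1, -2.0, -3.0],
                       [-2.0, -0.2, -1.9],
                       [-1.5, -1.2, -0.3],
                       [-0.4, -0.9, -2.5]])
target = torch.tensor([0, 2, 2, 1])

# keepdim=True keeps the class dimension, so pred has shape (4, 1), not (4,).
pred = output.argmax(dim=1, keepdim=True)
print(pred.shape)  # torch.Size([4, 1])

# target has shape (4,), so it is reshaped to (4, 1) before the element-wise comparison.
matches = pred.eq(target.view_as(pred))
correct = matches.sum().item()
print(correct)  # 2
```

Because `pred` keeps that extra dimension, any mask derived from it is also 2-D, which is why `torch.nonzero` on it returns two index columns.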
Next, I wanted to display the misclassified images. After some stumbling through the documentation, I arrived at this solution:
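```python
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.utils as vutils

# pred has shape (N, 1), so target is reshaped to (N, 1) for the comparison
z = torch.zeros_like(pred)
o = torch.ones_like(pred)
ne = pred.ne(target.view_as(pred))
w = torch.where(ne, o, z)
# nonzero on a 2-D tensor returns (K, 2) indices; column 0 is the batch index
wrong = torch.nonzero(w)[:, 0]
wrongdata = torch.index_select(data, 0, wrong)
plt.figure(figsize=(15, 15))
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
deviceData = wrongdata.to(device)[:64]  # at most 64 misclassified images
# make_grid returns (C, H, W); imshow expects (H, W, C) on the CPU
plt.imshow(np.transpose(vutils.make_grid(deviceData, padding=5, normalize=True).cpu(), (1, 2, 0)))
plt.show()
```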
I suspect that gathering the misclassified data could be done more elegantly or simply, but I am not sure how. Could someone help me out?
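One possible simplification, sketched under assumptions: comparing the predictions with `target` directly yields a boolean mask, and indexing `data[mask]` selects the misclassified samples without the `zeros_like`/`ones_like`/`where`/`nonzero`/`index_select` chain. The sketch also collects across all batches, since a snippet placed after the test loop would only see the `data` and `pred` of the last batch. `collect_wrong` and `max_images` are hypothetical names, not part of the MNIST example.

```python
import torch


def collect_wrong(model, device, test_loader, max_images=64):
    """Hypothetical helper: return (images, predictions, targets) for misclassified samples."""
    model.eval()
    wrong_images, wrong_preds, wrong_targets = [], [], []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            # Without keepdim, pred has shape (N,) and lines up with target directly.
            pred = model(data).argmax(dim=1)
            mask = pred != target  # boolean tensor of shape (N,)
            # Boolean indexing selects rows along dim 0; results are moved to the CPU
            # so selections from many batches do not accumulate on the GPU.
            wrong_images.append(data[mask].cpu())
            wrong_preds.append(pred[mask].cpu())
            wrong_targets.append(target[mask].cpu())
    # Concatenate the per-batch selections, then keep at most max_images of them.
    images = torch.cat(wrong_images)[:max_images]
    preds = torch.cat(wrong_preds)[:max_images]
    targets = torch.cat(wrong_targets)[:max_images]
    return images, preds, targets
```

The returned `images` could then go to `vutils.make_grid(images, padding=5, normalize=True)` and `plt.imshow` in the same way as in the snippet above, with `preds` and `targets` available for labelling the plot.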