Confusion matrix

To calculate the confusion matrix you need the predicted classes, not raw scores. Currently it looks like pred contains the logits or probabilities for the two classes.
Call torch.argmax(pred, 1) to reduce over the class dimension and get the predicted class for each element.
Here is a small example:

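import torch
from sklearn.metrics import confusion_matrix  # assumption: confusion_matrix is scikit-learn's

output = torch.randn(1, 2, 4, 4)  # logits: [batch, classes, height, width]
pred = torch.argmax(output, 1)  # predicted class per pixel: [batch, height, width]
target = torch.empty(1, 4, 4, dtype=torch.long).random_(2)  # random labels in {0, 1}
# scikit-learn expects (y_true, y_pred): rows are true classes, columns are predictions
confusion_matrix(target.view(-1), pred.view(-1))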
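
If you would rather stay in PyTorch and skip the extra dependency, the same counts can be built with torch.bincount. This is only a minimal sketch: confusion_matrix_torch is a hypothetical helper name, not a PyTorch function, and it assumes the labels are integers in [0, num_classes).

```python
import torch

def confusion_matrix_torch(target, pred, num_classes):
    # Hypothetical helper (not part of PyTorch).
    # Rows index the true class, columns index the predicted class.
    target = target.reshape(-1)
    pred = pred.reshape(-1)
    # Encode each (true, predicted) pair as a single index into a flat table.
    idx = target * num_classes + pred
    counts = torch.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

# Tiny hand-checkable example: 4 pixels, 2 classes.
target = torch.tensor([[[0, 1], [1, 1]]])
pred = torch.tensor([[[0, 1], [0, 1]]])
cm = confusion_matrix_torch(target, pred, num_classes=2)
print(cm)
```

Here one class-0 pixel is predicted correctly, one class-1 pixel is predicted as class 0, and two class-1 pixels are predicted correctly, so cm[1, 0] is 1 and cm[1, 1] is 2.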