Ignite ConfusionMatrix Metric Error

I have a working ignite setup: precision, recall, and accuracy all work. However, I am getting a bizarre error when I try to add ConfusionMatrix(num_classes=1). Note that the batch size is 5 and this is a binary classifier.

/usr/local/lib/python3.7/site-packages/ignite/metrics/confusion_matrix.py in update(self, output)
92 @reinit__is_reduced
93 def update(self, output: Sequence[torch.Tensor]) -> None:
---> 94 self._check_shape(output)
95 y_pred, y = output
96

/usr/local/lib/python3.7/site-packages/ignite/metrics/confusion_matrix.py in _check_shape(self, output)
78 "y_pred must have shape (batch_size, num_categories, ...) and y must have "
79 "shape of (batch_size, ...), "
---> 80 "but given {} vs {}.".format(y.shape, y_pred.shape)
81 )
82

ValueError: y_pred must have shape (batch_size, num_categories, ...) and y must have shape of (batch_size, ...), but given torch.Size([5, 1]) vs torch.Size([5, 1]).

As you can see, the y_pred and y shapes match... Any ideas what might be going on here?

I agree that the error message

y_pred must have shape (batch_size, num_categories, ...) and y must have shape of (batch_size, ...), but given torch.Size([5, 1]) vs torch.Size([5, 1])

is not clear enough. It is meant to say the following:

  • y_pred must have shape (batch_size, num_categories, ...), and the given torch.Size([5, 1]) => OK
  • y must have shape (batch_size, ...), and the given torch.Size([5, 1]) => Wrong

y should have shape (5,), i.e. without the trailing dimension of size 1.
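To make the rule concrete, here is a simplified, hypothetical pure-Python mirror of the check described by the error message. It is not ignite's source code: shapes are plain tuples, and the helper names `check_confusion_matrix_shapes` and `is_valid` are made up for illustration. The key point is that y drops the class dimension, so it must have exactly one fewer dimension than y_pred.

```python
# Hypothetical, simplified mirror of the shape rule stated in ignite's
# ConfusionMatrix error message. NOT ignite's actual implementation.
def check_confusion_matrix_shapes(y_pred_shape, y_shape, num_classes):
    """Raise ValueError unless y_pred is (batch_size, num_classes, ...)
    and y is (batch_size, ...) with the class dimension removed."""
    if len(y_pred_shape) < 2:
        raise ValueError("y_pred must have shape (batch_size, num_categories, ...)")
    if y_pred_shape[1] != num_classes:
        raise ValueError(f"y_pred has {y_pred_shape[1]} classes, expected {num_classes}")
    # y has no class dimension, so it must have exactly one dimension fewer
    if len(y_shape) + 1 != len(y_pred_shape):
        raise ValueError(
            f"y must have shape (batch_size, ...), but given {y_shape} vs {y_pred_shape}"
        )
    # The remaining dims (batch and any spatial dims) must match y_pred minus its class dim
    expected_y_shape = (y_pred_shape[0],) + tuple(y_pred_shape[2:])
    if tuple(y_shape) != expected_y_shape:
        raise ValueError(f"y shape {y_shape} incompatible with y_pred shape {y_pred_shape}")


def is_valid(y_pred_shape, y_shape, num_classes):
    """Return True if the shapes pass the check above, False otherwise."""
    try:
        check_confusion_matrix_shapes(y_pred_shape, y_shape, num_classes)
        return True
    except ValueError:
        return False


print(is_valid((5, 1), (5, 1), num_classes=1))  # False: y keeps the trailing 1
print(is_valid((5, 1), (5,), num_classes=1))    # True: y squeezed to (5,)
print(is_valid((5, 2), (5,), num_classes=2))    # True: two-class binary setup
```

With real tensors, something like y.squeeze(1) or y.view(-1) turns a (5, 1) target into (5,). One place to do that without touching the training loop is the metric's output_transform argument, e.g. a lambda that returns (y_pred, y.squeeze(1)); treat that as a suggestion to adapt to your own output format.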

Could you please open an issue on the GitHub repository to improve the docs?

Thanks, that did the trick. Any idea what the output of ConfusionMatrix means for num_classes=1?

The documentation is thin on this as well, and it returns a 1x1 tensor / single value (I expected a 2x2 matrix). I can’t make sense of it based on my precision/recall.

Yes, I agree that the docs should be reworked.
If you would like to do binary classification, please set num_classes=2.
The values of the confusion matrix can be normalized with the average option ("samples", "recall", or "precision") so that they match precision, recall, or the number of samples...
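A hand-rolled illustration (not ignite's implementation) of why num_classes=2 is needed and what the average options do. Assuming the metric picks each sample's predicted class by argmax over the class dimension, a single column means every prediction lands in class 0, so num_classes=1 cannot describe a binary problem. The sketch below expands single positive-class probabilities p into [1 - p, p] pairs, builds a 2x2 matrix with rows as the true class and columns as the predicted class, and applies the three normalizations. The batch values and all helper names are made up for illustration.

```python
# Hand-rolled illustration, NOT ignite's code: a 2x2 confusion matrix for a
# binary classifier plus "samples" / "recall" / "precision" normalizations.

def to_two_class(probs):
    """Expand positive-class probabilities p into [1 - p, p] pairs,
    i.e. the (batch_size, 2) layout that num_classes=2 expects."""
    return [[1.0 - p, p] for p in probs]


def confusion_matrix(y_pred, y, num_classes):
    """Rows index the true class, columns the argmax-predicted class."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for scores, target in zip(y_pred, y):
        predicted = max(range(num_classes), key=lambda c: scores[c])
        cm[target][predicted] += 1
    return cm


def normalize(cm, average):
    n = len(cm)
    if average == "samples":  # every cell divided by the total sample count
        total = sum(map(sum, cm))
        return [[v / total for v in row] for row in cm]
    if average == "recall":  # each row divided by its true-class count
        return [[v / sum(row) if sum(row) else 0.0 for v in row] for row in cm]
    if average == "precision":  # each column divided by its predicted-class count
        col_sums = [sum(cm[r][c] for r in range(n)) for c in range(n)]
        return [[cm[r][c] / col_sums[c] if col_sums[c] else 0.0 for c in range(n)]
                for r in range(n)]
    raise ValueError(f"unknown average: {average}")


# Made-up batch of 5: positive-class probabilities and true labels
probs = [0.9, 0.2, 0.7, 0.4, 0.6]
y = [1, 0, 0, 1, 1]

cm = confusion_matrix(to_two_class(probs), y, num_classes=2)
print(cm)                          # [[1, 1], [1, 2]]: [[TN, FP], [FN, TP]]
print(normalize(cm, "samples"))    # [[0.2, 0.2], [0.2, 0.4]]
print(normalize(cm, "recall")[1][1])     # recall of class 1 = 2/3
print(normalize(cm, "precision")[1][1])  # precision of class 1 = 2/3
```

With the recall and precision normalizations, the diagonal entries are the per-class recall and precision, which is how the matrix can be cross-checked against the separate Precision and Recall metrics.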

Thanks, new issue opened on GitHub here:
