Having trouble with VOC 2012 segmentation with the void = 255 label

Hello, I’m trying to implement a segmentation model using the VOC 2012 data, and I was having trouble with calculating the loss.

I am using cross entropy loss, but then I get the error

RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.  at ..\aten\src\THNN/generic/ClassNLLCriterion.c:92

when I try to set the number of classes to 22 for the output. This is because the data uses the label 255 for void, which falls outside the valid target range of 0 to n_classes - 1. How can I fix this? I loaded the data using torchvision.datasets.VOCSegmentation. I was also unsure whether people actually try to predict the void pixels in practice. Thanks so much!

You could pass ignore_index=255 to the initialization of your criterion, if your model does not predict the void class.
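As a minimal sketch of that suggestion, the snippet below builds a dummy batch with some void pixels and checks that the loss with ignore_index=255 matches the mean loss over the non-void pixels only. The tensor shapes and n_classes = 21 (20 object classes plus background) are assumptions for illustration, not taken from the original post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n_classes = 21  # assumption: 20 VOC object classes + background
criterion = nn.CrossEntropyLoss(ignore_index=255)

# Dummy batch: logits are [N, C, H, W], targets are [N, H, W]
logits = torch.randn(2, n_classes, 4, 4)
target = torch.randint(0, n_classes, (2, 4, 4))
target[:, 0, :] = 255  # mark the first row of every image as void

# Void pixels contribute neither to the loss nor to the gradient
loss = criterion(logits, target)

# Cross-check: average the per-pixel loss over the non-void pixels only
valid = target != 255
per_pixel = F.cross_entropy(logits, target, ignore_index=255, reduction='none')
manual = per_pixel[valid].mean()
```

With the default reduction='mean', the average is taken over the non-ignored pixels, so loss and manual should agree.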


That worked! Thanks so much

Oh another issue. How would you recommend computing the confusion matrix to calculate IoU? Would sklearn’s confusion_matrix function take care of this if we feed in a list of all the classes?

sklearn's confusion matrix should most likely work.
If you would like to fill the confusion matrix in your evaluation loop, have a look at this post.

I’ve noticed that both the scikit-learn method and the for loop method are a bit slow. Is there a vectorized way of implementing this, or is this just the nature of constructing the confusion matrix? Also, the confusion matrix is updated for every mini-batch, correct?

Yes, the confusion matrix is updated for each mini-batch, so you avoid storing the predictions and targets of the whole dataset.
Also there is a vectorized (and slightly more complicated way) to fill the confusion matrix using .put_:

import time

import torch

# Create dummy data
n_samples = 1000
n_classes = 5

conf_matrix = torch.zeros(n_classes, n_classes)

preds = torch.randn(n_samples, n_classes)
preds = torch.argmax(preds, 1)
labels = torch.randint(0, n_classes, (n_samples,))

# For loop
t0 = time.perf_counter()
for p, t in zip(preds, labels):
    conf_matrix[p, t] += 1
print('Took {}s'.format(time.perf_counter() - t0))

# put_ method
conf_matrix2 = torch.zeros(n_classes, n_classes)

t0 = time.perf_counter()
lin_index = preds * n_classes + labels  # You need to calculate the linear index here
conf_matrix2.put_(lin_index, torch.tensor(1.).expand_as(lin_index), accumulate=True)  # accumulate into the new matrix, not the for-loop one
print('Took {}s'.format(time.perf_counter() - t0))
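To tie this back to the void label and the IoU question, here is a hedged sketch of another vectorized variant that uses torch.bincount on the same linear index, drops pixels labeled 255 first, and derives the per-class IoU from the result. The helper names confusion_matrix and iou_per_class are hypothetical, and the layout (rows are predictions, columns are targets) simply follows the snippet above.

```python
import torch


def confusion_matrix(preds, labels, n_classes, ignore_index=255):
    """Count (prediction, target) pairs, skipping void targets."""
    valid = labels != ignore_index            # boolean mask, also flattens the input
    lin_index = preds[valid] * n_classes + labels[valid]
    counts = torch.bincount(lin_index, minlength=n_classes * n_classes)
    return counts.view(n_classes, n_classes)  # rows: predictions, columns: targets


def iou_per_class(conf_matrix):
    """IoU = TP / (predicted + actual - TP) for every class."""
    tp = conf_matrix.diag().float()
    union = conf_matrix.sum(0).float() + conf_matrix.sum(1).float() - tp
    return tp / union  # NaN for classes absent from both predictions and targets


# Tiny example with one void pixel
preds = torch.tensor([0, 0, 1, 1, 2, 2])
labels = torch.tensor([0, 1, 1, 255, 2, 0])

cm = confusion_matrix(preds, labels, n_classes=3)
iou = iou_per_class(cm)
```

In an evaluation loop you would add the matrix of each mini-batch to a running total and compute the IoU once at the end, rather than averaging per-batch IoU values.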