Convert pixel-wise class tensor to image segmentation

I have two networks which both output a tensor of shape batch_size x num_classes x height x width with num_classes = 256 (there are actually just 21 classes in VOC12, but the background is given label 255 - I will improve on this later).

Since the label has shape batch_size x 1 x height x width, I can calculate the cross-entropy loss with:

criterion = nn.CrossEntropyLoss()

# shape: batch_size x 1 x height x width = 1 x 1 x 256 x 256
inputs = autograd.Variable(images)
# shape: batch_size x 1 x height x width = 1 x 1 x 256 x 256
targets = autograd.Variable(labels)
# shape: batch_size x num_classes x height x width = 1 x 256 x 256 x 256
outputs = model(inputs)

optimizer.zero_grad()
loss = criterion(outputs.view(-1, 256), targets.view(-1))
loss.backward()
optimizer.step()

See source for context.

I know that outputs[0][0][i][j] corresponds to the probability that the pixel at (i, j) belongs to the first class. So if I want to transform outputs of shape 1 x 256 x 256 x 256 to 1 x 1 x 256 x 256, I would need to find the maximum (probability) of every pixel and assign it to the corresponding class value.

I could do this manually by iterating over every class and pixel with numpy but I wonder if there is any better way using tensor operations?


Note that some losses accept 2D (spatial) inputs (and CrossEntropyLoss will be updated soon to support them as well). So a more efficient way of computing the loss would be something like:

import torch.nn as nn
import torch.nn.functional as F
nllcrit = nn.NLLLoss2d(size_average=True)  # need to write a functional interface for it
def criterion(input, target):
    return nllcrit(F.log_softmax(input), target)
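
For reference, here is a minimal sketch of the shapes this expects, with made-up sizes. Note that in recent PyTorch versions nn.NLLLoss handles spatial targets directly (NLLLoss2d is deprecated) and log_softmax takes an explicit dim argument:

import torch
import torch.nn as nn
import torch.nn.functional as F

N, C, H, W = 2, 21, 32, 32                # hypothetical batch / class / image sizes
logits = torch.randn(N, C, H, W)          # raw network output
target = torch.randint(0, C, (N, H, W))   # one class index per pixel (LongTensor)

nll = nn.NLLLoss()                        # replaces NLLLoss2d in recent versions
loss = nll(F.log_softmax(logits, dim=1), target)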

Now, if you want to compute the confusion matrix for your predictions, you can use a variant of ConfusionMeter from tnt, but replacing the input by something like

output = output.permute(0, 2, 3, 1).contiguous().view(-1, ncls)  # make contiguous before view
target = target.view(-1)
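
With the output and target flattened like that, a rough usage sketch of ConfusionMeter could look like the following (assuming torchnet is installed and ncls, output and target are defined as above; in old Variable-based code you may need to pass .data):

from torchnet.meter import ConfusionMeter

meter = ConfusionMeter(ncls)   # accumulates an ncls x ncls confusion matrix
meter.add(output, target)      # output: (N*H*W, ncls) scores, target: (N*H*W,) class indices
conf_matrix = meter.value()    # confusion matrix as a numpy array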

Thanks for the hint with ConfusionMeter. I think outputs.data[0].numpy().argmax(0) does what I need.

In torch, max and argmax are computed together and returned as a tuple. So outputs.max(0)[1] is the native torch way to do this.
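
As a small sketch, assuming outputs holds the scores for a single image with shape (num_classes, H, W):

# scores and pred both have shape (H, W);
# pred[i, j] is the index of the highest-scoring class for pixel (i, j)
scores, pred = outputs.max(0)

# for a batched tensor of shape (batch_size, num_classes, H, W),
# take the max over the class dimension (dim 1) instead:
# pred = outputs.max(1)[1]    # shape: (batch_size, H, W)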


Any ideas what could be wrong?

Code is here.

This happens with both models (UNet and 1-layer conv) so I guess there must be something wrong with loss or optimization.

@fmassa Does F.log_softmax take care of the fact that we need to take the softmax along the 1st axis (the class dimension)?

For example, the input would be a tensor of size (batch_size, n_classes, H, W); we need to apply the softmax along the n_classes dimension for every pixel of the H x W image.

From what I read here, I’m not sure if F.log_softmax does that.

I'm currently permuting and reshaping the outputs to (batch_size * h * w, n_classes) and then using F.cross_entropy directly. Is this approach correct?

# outputs.shape = (batch_size, n_classes, img_cols, img_rows)
outputs = outputs.permute(0, 2, 3, 1)
# outputs.shape = (batch_size, img_cols, img_rows, n_classes)
outputs = outputs.contiguous().view(batch_size * img_cols * img_rows, n_classes)
labels = labels.view(batch_size * img_cols * img_rows)
loss = F.cross_entropy(outputs, labels)

Yes, F.log_softmax will apply the softmax along the 1st axis (the class dimension) if your input is 4D.
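
A quick sanity check (assuming a recent PyTorch where the dim argument is passed explicitly):

import torch
import torch.nn.functional as F

x = torch.randn(1, 21, 4, 4)      # (batch_size, n_classes, H, W)
logp = F.log_softmax(x, dim=1)    # normalize over the class dimension
print(logp.exp().sum(dim=1))      # ~1.0 for every pixel

Also note that recent versions of F.cross_entropy accept a 4D input of shape (batch_size, n_classes, H, W) together with a (batch_size, H, W) target directly, so the permute/reshape step above is not strictly necessary there.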
