Convert multi-class segmentation mask into integer-indexed

I want to use ConfusionMatrix in pytorch_lightning.metrics: pytorch_lightning.metrics.functional.confusion_matrix — PyTorch Lightning 1.1.6 documentation.

I’m working on a multi-class segmentation problem where my mask is originally RGB-valued (N, 3, H, W), and I converted it into one-hot (N, #Class, H, W) for training. However, pytorch_lightning.metrics.utils._input_format_classification, used in ConfusionMatrix, requires that my mask is integer-indexed: (N, H, W). I’m wondering how I can efficiently convert an RGB-valued mask into an integer-indexed one?

Hi zzzf!

To convert a one-hot-encoded class-label mask into an integer-class-label
mask (e.g., shape = [N, nClass, H, W] -> shape = [N, H, W]), you
take torch.argmax (one_hot_mask, dim = 1) along the class dimension.

As for your question about how to convert an RGB-valued mask, the issue
is how your class labels are encoded into your color channels (which is
something you haven’t told us). You have said that you can convert RGB
into one-hot, so you could apparently use a two-step procedure where you
first convert RGB to one-hot, and then apply argmax().
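For concreteness, here is a minimal sketch of that argmax() step (the tensor names and sizes are made up for illustration):

```python
import torch

# hypothetical one-hot mask of shape [N, nClass, H, W]
N, n_class, H, W = 2, 5, 4, 4
labels = torch.randint (n_class, (N, H, W))
one_hot_mask = torch.nn.functional.one_hot (
    labels, num_classes = n_class
).permute (0, 3, 1, 2)   # -> [N, nClass, H, W]

# collapse the class dimension back to integer labels: [N, H, W]
label_mask = torch.argmax (one_hot_mask, dim = 1)
print (label_mask.shape)   # torch.Size([2, 4, 4])
```

Because one_hot() round-trips with argmax(), label_mask here equals the original labels tensor.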

Best.

K. Frank

Hi KFrank,

Thanks for your reply!

This two-step approach is indeed a solution, but somewhat suboptimal. Is there a way to directly convert the RGB-valued mask to an integer-indexed mask? Just for example, the class RGB values are (255, 0, 0), (0, 255, 0), (0, 0, 255), (0, 0, 0), (255, 255, 255).

Hi zzzf!

There are a couple of ways of going about this. Probably the most
flexible and most logically straightforward approach will be to index
into a lookup table with your three-channel colors.

Here is a demonstration script:

import torch
print (torch.__version__)
# convert rgb mask to integer class labels
#   red    =  class-0
#   green  =  class-1
#   blue   =  class-2
#   black  =  class-3
#   white  =  class-4
#   (bad)  =  class-5
# create example mask tensor of shape [nBatch, nRGB = 3, height]
nBatch = 2
height = 4
rgb_mask = 255 * torch.ones ((nBatch, 3, height), dtype = torch.int64)
rgb_mask[0, :, 0]       =    0   # black
rgb_mask[0, [1, 2], 1]  =    0   # red
rgb_mask[0, [0, 2], 2]  =    0   # green
rgb_mask[0, [0, 1], 3]  =    0   # blue
rgb_mask[1, :, 0]       =  255   # white
rgb_mask[1, :, 1]       =    0   # black
rgb_mask[1, [0, 1], 2]  =    0   # blue
rgb_mask[1, [1, 2], 3]  =    0   # red
print (rgb_mask.shape)
print (rgb_mask)
# build lookup table
lut_dim = torch.tensor (rgb_mask.shape)
lut_dim[1] = -1
lut = torch.tensor ([3, 0, 1, 5, 2, 5, 5, 4])  # color-class encoding
lut = lut.unsqueeze (-1).expand (lut_dim.tolist())
# convert rgb colors to 0-7 indices (treat nonzero R, G, B as bits 0, 1, 2)
label_mask = torch.sign (rgb_mask)   # 255 -> 1, 0 -> 0
powers2 = torch.tensor ([2])**torch.arange (3)   # [1, 2, 4]
label_mask = (powers2 * label_mask.transpose (1, -1)).transpose (1, -1).sum (dim = 1, keepdim = True)
# index into lut to get integer class labels
label_mask = torch.gather (lut, 1, label_mask).squeeze()
print (label_mask.shape)
print (label_mask)

And here is its output:

1.7.1
torch.Size([2, 3, 4])
tensor([[[  0, 255,   0,   0],
         [  0,   0, 255,   0],
         [  0,   0,   0, 255]],

        [[255,   0,   0, 255],
         [255,   0,   0,   0],
         [255,   0, 255,   0]]])
torch.Size([2, 4])
tensor([[3, 0, 1, 2],
        [4, 3, 2, 0]])
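As noted, there are other ways of doing this. One alternative sketch (my own variation, not part of the script above) compares every pixel against a palette tensor by broadcasting, and takes argmax() over the class dimension to recover the matching index. It uses a full [nBatch, 3, H, W] mask with the same five colors:

```python
import torch

# palette: row i holds the color of class i
palette = torch.tensor ([
    [255,   0,   0],   # class-0: red
    [  0, 255,   0],   # class-1: green
    [  0,   0, 255],   # class-2: blue
    [  0,   0,   0],   # class-3: black
    [255, 255, 255],   # class-4: white
])

# example rgb mask of shape [nBatch, 3, H, W]
rgb_mask = torch.zeros ((1, 3, 2, 2), dtype = torch.int64)
rgb_mask[0, 0, 0, 0] = 255   # red pixel at (0, 0)
rgb_mask[0, :, 0, 1] = 255   # white pixel at (0, 1)
rgb_mask[0, 2, 1, 0] = 255   # blue pixel at (1, 0)
                             # (1, 1) stays black

# broadcast-compare pixels against all palette colors:
#   [nBatch, 1, 3, H, W] == [nClass, 3, 1, 1] -> [nBatch, nClass, 3, H, W]
matches = rgb_mask.unsqueeze (1) == palette.view (-1, 3, 1, 1)

# a pixel matches a class only if all three channels agree
label_mask = matches.all (dim = 2).to (torch.int64).argmax (dim = 1)
print (label_mask)   # tensor([[[0, 4], [2, 3]]])
```

One caveat of this variant: a color not in the palette silently maps to class 0 (argmax of an all-zeros column), whereas the lookup-table approach flags such "bad" colors with a dedicated class-5 label.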

Best.

K. Frank
