# Convert multi-class segmentation mask into integer-indexed form

I want to use ConfusionMatrix in pytorch_lightning.metrics: pytorch_lightning.metrics.functional.confusion_matrix — PyTorch Lightning 1.1.6 documentation.

I’m working on a multi-class segmentation problem where my mask is originally RGB-valued (N, 3, H, W), and I converted it into one-hot (N, #Class, H, W) for training. However, pytorch_lightning.metrics.utils._input_format_classification, which ConfusionMatrix uses, requires that my mask be integer-indexed: (N, H, W). How can I efficiently convert an RGB-valued mask into an integer-indexed one?

Hi zzzf!

To convert a one-hot-encoded class-label mask into an integer-class-label
mask (e.g., `shape = [N, nClass, H, W]` → `shape = [N, H, W]`), you
take `torch.argmax (one_hot_mask, dim = 1)` along the class dimension.
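
As a minimal sketch of that round trip (the tiny `labels` tensor and `nClass = 3` are made up for illustration, and `F.one_hot()` is just one way to build a one-hot mask):

```python
import torch
import torch.nn.functional as F

# hypothetical integer labels, shape [N, H, W] = [1, 2, 2], with nClass = 3
labels = torch.tensor([[[0, 2], [1, 1]]])

# build a one-hot mask and move the class dimension to dim 1: [N, nClass, H, W]
one_hot_mask = F.one_hot(labels, num_classes=3).permute(0, 3, 1, 2)

# argmax along the class dimension recovers the integer labels
recovered = torch.argmax(one_hot_mask, dim=1)
print(recovered.shape)                 # torch.Size([1, 2, 2])
print(torch.equal(recovered, labels))  # True
```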

The open question is how your class labels are encoded into your color channels (which is
something you haven’t told us). You have said that you can convert RGB
into one-hot, so you could apparently use a two-step procedure where you
first convert RGB to one-hot, and then apply `argmax()`.
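
A minimal sketch of that two-step procedure might look like the following. The three-color `palette` and the tiny `rgb_mask` are assumptions for illustration, since the actual encoding isn't known here, and comparing against a palette is only one possible way to do the RGB-to-one-hot step:

```python
import torch

# hypothetical palette: one RGB color per class, shape [nClass, 3]
palette = torch.tensor([[255, 0, 0], [0, 255, 0], [0, 0, 255]])

# hypothetical RGB mask of shape [N, 3, H, W] = [1, 3, 1, 3]
# pixels, left to right: blue, red, green
rgb_mask = torch.tensor([[[[0, 255, 0]], [[0, 0, 255]], [[255, 0, 0]]]])

# step 1: one-hot mask of shape [N, nClass, H, W]; a pixel belongs to a
# class when all three channels equal that class's color
matches = rgb_mask.unsqueeze(1) == palette.view(1, -1, 3, 1, 1)
one_hot_mask = matches.all(dim=2)

# step 2: argmax along the class dimension gives integer labels [N, H, W]
label_mask = torch.argmax(one_hot_mask.long(), dim=1)
print(label_mask)  # tensor([[[2, 0, 1]]])
```

Note that a pixel whose color is not in the palette produces an all-zero one-hot vector, for which `argmax()` still returns a valid class index, so unexpected colors would need to be checked separately.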

Best.

K. Frank

Hi KFrank,

This two-step approach is indeed a solution, but it is not optimal. Is there a way to directly convert the RGB-valued mask to an integer-indexed mask? For example, the class RGB values are (255, 0, 0), (0, 255, 0), (0, 0, 255), (0, 0, 0), (255, 255, 255).

Hi zzzf!

The most flexible and most logically straightforward approach will be to index into a lookup table.

Here is a demonstration script:

```python
import torch
print (torch.__version__)
# convert rgb mask to integer class labels
#   red    =  class-0
#   green  =  class-1
#   blue   =  class-2
#   black  =  class-3
#   white  =  class-4
# create example mask tensor of shape [nBatch, nRGB = 3, height]
nBatch = 2
height = 4
rgb_mask = 255 * torch.ones ((nBatch, 3, height), dtype = torch.int64)
rgb_mask[0, :, 0]       =    0   # black
rgb_mask[0, [1, 2], 1]  =    0   # red
rgb_mask[0, [0, 2], 2]  =    0   # green
rgb_mask[0, [0, 1], 3]  =    0   # blue
rgb_mask[1, :, 0]       =  255   # white
rgb_mask[1, :, 1]       =    0   # black
rgb_mask[1, [0, 1], 2]  =    0   # blue
rgb_mask[1, [1, 2], 3]  =    0   # red
print (rgb_mask.shape)
print (rgb_mask)
# build lookup table, expanded to shape [nBatch, 8, height] for gather()
lut_dim = torch.tensor (rgb_mask.shape)
lut_dim[1] = -1   # keep the lut's own length (8) along dim 1
lut = torch.tensor ([3, 0, 1, 5, 2, 5, 5, 4])  # color-class encoding (5 = unused color)
lut = lut.unsqueeze (-1).expand (lut_dim.tolist())
# convert rgb colors to 0-7 indices (R + 2 * G + 4 * B)
label_mask = rgb_mask // 255   # channel values 0 / 255 -> 0 / 1
powers2 = torch.tensor ([2])**torch.arange (3)
label_mask = (powers2 * label_mask.transpose (1, -1)).transpose (1, -1).sum (dim = 1, keepdim = True)
# index into lut to get integer class labels
label_mask = lut.gather (1, label_mask).squeeze (1)
print (label_mask.shape)
print (label_mask)
```

And here is its output:

```
1.7.1
torch.Size([2, 3, 4])
tensor([[[  0, 255,   0,   0],
         [  0,   0, 255,   0],
         [  0,   0,   0, 255]],

        [[255,   0,   0, 255],
         [255,   0,   0,   0],
         [255,   0, 255,   0]]])
torch.Size([2, 4])
tensor([[3, 0, 1, 2],
        [4, 3, 2, 0]])
```
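
For the `[N, 3, H, W]` masks from the original question, the same lookup-table idea can probably be written more compactly by indexing the 1-d lut directly with the index tensor, instead of using `expand()` and `gather()`. This is only a sketch under the same assumptions as the script (channel values are exactly 0 or 255, same color-class encoding); the helper name `rgb_mask_to_labels` is hypothetical:

```python
import torch

def rgb_mask_to_labels(rgb_mask):
    """Map an RGB mask [N, 3, H, W] with channel values 0 / 255 to labels [N, H, W]."""
    # same color-class encoding as the script above; 5 marks unused colors
    lut = torch.tensor([3, 0, 1, 5, 2, 5, 5, 4], device=rgb_mask.device)
    powers2 = torch.tensor([1, 2, 4], device=rgb_mask.device).view(1, 3, 1, 1)
    # per-pixel 0-7 index: R + 2 * G + 4 * B, with channels reduced to 0 / 1
    index = ((rgb_mask // 255).long() * powers2).sum(dim=1)
    # indexing a 1-d tensor with an integer tensor looks up every element
    return lut[index]

# small example mask of shape [1, 3, 2, 2], starting out all black
rgb_mask = torch.zeros((1, 3, 2, 2), dtype=torch.uint8)
rgb_mask[0, 0, 0, 1] = 255   # red   at pixel (0, 1)
rgb_mask[0, 1, 1, 0] = 255   # green at pixel (1, 0)
rgb_mask[0, :, 1, 1] = 255   # white at pixel (1, 1)

label_mask = rgb_mask_to_labels(rgb_mask)
print(label_mask)  # tensor([[[3, 0], [1, 4]]])
```

Any color outside the five classes maps to the placeholder label 5, so checking `(label_mask == 5).any()` is one way to spot unexpected colors.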

Best.

K. Frank
