Convert segmentation mask of shape [224,224,3] to mask [224,224,classes]

Hi, I am having a problem converting an RGB mask of shape [224,224,3] to a mask of shape [224,224,classes]. I have attached the code below.

I am getting masks of shape [224,224,classes], but the class information in the channels is lost: only one channel contains some mask, while the others don’t.


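import os

import natsort
import numpy as np
from PIL import Image
from torch.utils.data import Dataset


class CamVid_Dataset(Dataset):

    def __init__(self, img_pth, mask_pth, transform):
        self.img_pth = img_pth
        self.mask_pth = mask_pth
        self.transform = transform
        all_imgs = os.listdir(self.img_pth)
        all_masks = os.listdir(self.mask_pth)
        # Natural sort so that images and masks line up by file name
        self.total_imgs = natsort.natsorted(all_imgs)
        self.total_masks = natsort.natsorted(all_masks)

    def __len__(self):
        return len(self.total_imgs)

    def __getitem__(self, idx):
        img_loc = os.path.join(self.img_pth, self.total_imgs[idx])
        image = Image.open(img_loc).convert("RGB")
        tensor_image = self.transform(image)

        mask_loc = os.path.join(self.mask_pth, self.total_masks[idx])
        mask = Image.open(mask_loc).convert("RGB")
        # The same transform (Resize + ToTensor) is applied to the mask
        tensor_mask = self.transform(mask)
        # [3, H, W] -> [H, W, 3], then map each color to its one-hot class channel
        tensor_mask = rgb_to_mask(np.array(tensor_mask).transpose(1, 2, 0), id2code)
        return tensor_image, tensor_mask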

import torch
from torchvision import transforms

# Define transforms for the training data and validation data
train_transforms = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

# Pass transform here-in
train_data = CamVid_Dataset(img_pth=path + 'train/', mask_pth=path + 'train_labels/', transform=train_transforms)

# Data loaders
trainloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False)

inputs, mask = next(iter(trainloader))


def rgb_to_mask(img, color_map):
    """
    Converts an RGB image mask of shape [h, w, 3] to a binary (one-hot) mask of shape [h, w, classes]

    Args:
        img: An RGB image mask as a numpy array of shape [h, w, 3]
        color_map: Dictionary mapping each class index to its (r, g, b) color

    Returns:
        out: A binary mask of shape [h, w, classes]
    """
    num_classes = len(color_map)
    shape = img.shape[:2] + (num_classes,)
    out = np.zeros(shape, dtype=np.int8)
    for i, cls in enumerate(color_map):
        # Mark every pixel whose three channels equal the color of class i
        out[:, :, i] = np.all(img.reshape((-1, 3)) == color_map[i], axis=1).reshape(shape[:2])
    return out
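If the goal is a class-index target of shape [h, w] (what nn.CrossEntropyLoss expects for multi-class segmentation) rather than one-hot channels, a variant of rgb_to_mask could write the class index directly. This is a minimal sketch, not the poster's code: the function name rgb_to_class_index, the -1 marker for unmatched colors, and the two-entry id2code are all hypothetical.

```python
import numpy as np


def rgb_to_class_index(img, color_map):
    """Map an RGB mask of shape [h, w, 3] to class indices of shape [h, w].

    Assumes color_map maps class index -> (r, g, b) with 0-255 values.
    Pixels whose color is not in color_map get -1 (hypothetical convention).
    """
    img = np.asarray(img)
    out = np.full(img.shape[:2], -1, dtype=np.int64)
    for cls_idx, color in color_map.items():
        # Boolean [h, w] array: True where all three channels match this color
        matches = np.all(img == np.array(color, dtype=img.dtype), axis=-1)
        out[matches] = cls_idx
    return out


# Hypothetical two-class color map (the real CamVid map has 32 entries)
id2code = {0: (0, 0, 0), 1: (128, 64, 128)}
rgb = np.array([[[0, 0, 0], [128, 64, 128]],
                [[128, 64, 128], [0, 0, 0]]], dtype=np.uint8)
print(rgb_to_class_index(rgb, id2code))
# [[0 1]
#  [1 0]]
```

Unmatched colors are marked -1 here so they are easy to spot; they would need handling (for example through the ignore_index argument of nn.CrossEntropyLoss) before training.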

I assume your mask images contain a specific color code, which would be used to create the corresponding class indices.
If that’s the case, be careful with transforms.Resize((224,224)), as the default interpolation would be a bilinear interpolation, which could corrupt some mask colors.
To resize a mask image, you should use PIL.Image.NEAREST as the interpolation method instead.
Let me know, if this helps.
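A small synthetic check shows the effect: upscaling a two-color RGB mask with bilinear interpolation creates blended colors that match no entry in a color map, while nearest-neighbour resampling keeps only the original colors. The colors and sizes below are arbitrary and hypothetical. In torchvision, the corresponding setting is the interpolation argument of transforms.Resize (whether it takes a PIL constant or InterpolationMode.NEAREST depends on the torchvision version).

```python
import numpy as np
from PIL import Image

# Hypothetical 2x2 RGB mask with two class colors: left column red, right column blue
mask = np.array([[[255, 0, 0], [0, 0, 255]],
                 [[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)
pil_mask = Image.fromarray(mask)  # mode "RGB"


def unique_colors(pil_img):
    """Return the set of distinct (r, g, b) tuples in a PIL RGB image."""
    arr = np.array(pil_img).reshape(-1, 3)
    return {tuple(int(v) for v in px) for px in arr}


bilinear = pil_mask.resize((4, 4), resample=Image.BILINEAR)
nearest = pil_mask.resize((4, 4), resample=Image.NEAREST)

print(sorted(unique_colors(nearest)))  # only the two original colors
print(len(unique_colors(bilinear)))    # more than two: blended colors appear
```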


Hi, Thanks for responding.

  • I tried what you suggested, but it didn’t work. However, using transforms.PILToTensor (for the masks) instead of transforms.Tensor seems to work fine, as the masks don’t lose information then. Can you tell me why?
  • Another thing: I am working on the CamVid dataset with 32 classes. The initial mask provided was of size [h,w,3], and I converted it into [h,w,32], where each channel represents the corresponding class. Is this approach correct?
  • Also, I am not sure whether I should use a Sigmoid or a Softmax function. Currently I am using BCE loss.
    The predictions on the test set are not good.


import torch
import torch.nn as nn


def Double_Conv(input_channel, output_channel):
    # Two 3x3 convolutions (padding=1 keeps the spatial size), each followed by a ReLU
    return nn.Sequential(
        nn.Conv2d(input_channel, output_channel, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(output_channel, output_channel, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )


class UNet(nn.Module):
    def __init__(self, no_of_cls):
        super().__init__()

        # Encoder
        self.down1 = Double_Conv(3, 64)
        self.down2 = Double_Conv(64, 128)
        self.down3 = Double_Conv(128, 256)
        self.down4 = Double_Conv(256, 512)
        self.down5 = Double_Conv(512, 1024)

        self.max_pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder: input channels = upsampled features + skip connection
        self.up4 = Double_Conv(512 + 1024, 512)
        self.up3 = Double_Conv(256 + 512, 256)
        self.up2 = Double_Conv(128 + 256, 128)
        self.up1 = Double_Conv(64 + 128, 64)

        # 1x1 convolution producing one output channel per class
        self.lastConv = nn.Conv2d(64, no_of_cls, 1)

    def forward(self, x):
        conv1 = self.down1(x)
        x = self.max_pool(conv1)

        conv2 = self.down2(x)
        x = self.max_pool(conv2)

        conv3 = self.down3(x)
        x = self.max_pool(conv3)

        conv4 = self.down4(x)
        x = self.max_pool(conv4)

        x = self.down5(x)

        # Upsample and concatenate with the matching encoder output along channels
        x = self.upsample(x)
        x = torch.cat([x, conv4], dim=1)

        x = self.up4(x)
        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)

        x = self.up3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)

        x = self.up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)

        x = self.up1(x)
        out = self.lastConv(x)
        return torch.sigmoid(out)

  1. I’m not aware of the transforms.PILToTensor or transforms.Tensor methods. Do you mean transforms.ToTensor? If so, what are you comparing it against?

  2. That’s incorrect for a multi-class segmentation, where each pixel belongs to one specific class only. In that case the target tensor should have the shape [batch_size, height, width] and contain the class indices in the range [0, nb_classes-1]. Since your tensor seems to be a one-hot encoded mask, you could simply call target = one_hot_target.argmax(2).

  3. For a multi-class segmentation you should use nn.CrossEntropyLoss, remove the last sigmoid and pass the raw logits in the shape [batch_size, nb_classes, height, width] to the criterion.
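Points 2 and 3 can be sketched together: the raw logits of shape [batch_size, nb_classes, height, width] go straight into nn.CrossEntropyLoss together with a long target of shape [batch_size, height, width]. Random tensors stand in for the UNet output and the labels here; the sizes are hypothetical.

```python
import torch
import torch.nn as nn

batch_size, nb_classes, height, width = 2, 32, 8, 8

# Stand-in for the UNet output *without* the final sigmoid: raw logits
logits = torch.randn(batch_size, nb_classes, height, width, requires_grad=True)

# Target: one class index per pixel, dtype long, values in [0, nb_classes - 1]
target = torch.randint(0, nb_classes, (batch_size, height, width))

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)  # applies log-softmax + NLL internally
loss.backward()

# At inference time, the predicted class per pixel is the argmax over dim 1
pred = logits.argmax(dim=1)  # shape [batch_size, height, width]
print(loss.item(), pred.shape)
```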

  1. Sorry, I meant transforms.ToTensor, not transforms.Tensor. It converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
  2. Can you explain a bit about how to deal with masks in the case of multi-class segmentation? I am particularly confused about this. Currently each of my 32 channels has pixel values of 0 or 1, and each channel of the mask represents a particular class.
  3. I’ll do that.
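The scaling described in point 1 would also explain the symptom from the original question: transforms.ToTensor divides uint8 values by 255, so the mask colors end up in [0.0, 1.0] while the color codes in id2code are (presumably) 0-255 triplets, and only a class whose color is (0, 0, 0) can still match. transforms.PILToTensor keeps the uint8 values unscaled. This sketch mimics both conversions with plain NumPy rather than calling torchvision, reuses the rgb_to_mask logic from the question, and uses a hypothetical two-class id2code.

```python
import numpy as np


def rgb_to_mask(img, color_map):
    """One-hot encode an RGB mask [h, w, 3] into [h, w, num_classes] (same logic as the question)."""
    num_classes = len(color_map)
    shape = img.shape[:2] + (num_classes,)
    out = np.zeros(shape, dtype=np.int8)
    for i in range(num_classes):
        out[:, :, i] = np.all(img.reshape(-1, 3) == color_map[i], axis=1).reshape(shape[:2])
    return out


# Hypothetical color map: class 0 is black, class 1 is a purple "road" color
id2code = {0: (0, 0, 0), 1: (128, 64, 128)}
rgb = np.array([[[0, 0, 0], [128, 64, 128]],
                [[128, 64, 128], [128, 64, 128]]], dtype=np.uint8)

# Mimics PILToTensor: uint8 values are kept as they are
unscaled = rgb
# Mimics ToTensor: values are converted to float and divided by 255
scaled = rgb.astype(np.float32) / 255.0

print(rgb_to_mask(unscaled, id2code).sum(axis=(0, 1)))  # [1 3] pixels per class
print(rgb_to_mask(scaled, id2code).sum(axis=(0, 1)))    # [1 0] class 1 is lost
```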

Since your current mask is a one-hot encoded tensor (each channel represents a class, where 1 denotes an active class), you could transform it into the desired class mask via:

target = torch.argmax(mask, dim=1)

Note that dim=1 assumes that the channel dimension (the class channels) is dim1, as in a batched mask of shape [batch_size, nb_classes, height, width].
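A quick round trip can confirm that argmax over dim=1 turns a one-hot mask of shape [N, C, H, W] back into the index mask. This sketch builds the one-hot tensor from random, hypothetical labels with F.one_hot, which puts the class dimension last, hence the permute.

```python
import torch
import torch.nn.functional as F

# Hypothetical class-index target: batch of 2 masks, 4x4 pixels, 32 classes
num_classes = 32
target = torch.randint(0, num_classes, (2, 4, 4))

# One-hot encode: F.one_hot appends the class dim last -> [N, H, W, C]
one_hot = F.one_hot(target, num_classes)
# Move classes to dim 1 -> [N, C, H, W]
one_hot = one_hot.permute(0, 3, 1, 2)

# argmax over the class channel recovers the index mask [N, H, W]
recovered = torch.argmax(one_hot, dim=1)
print(torch.equal(recovered, target))  # True
```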

Yes, thank you. Finally, I got it.
Can you tell me how I can improve the performance of my network? These are the current results I am getting.

I’m still a bit confused about the difference between torchvision.transforms.Resize and torch.nn.functional.interpolate for masks (assuming, obviously, that the masks are in uint8 format and contain only integers).

The problem is, I kept getting suspicious results using Mask R-CNN:

mask = np.array(, resample=PILImage.NEAREST))

It seems that transforms.Resize is identical to PIL.Image.resize. On the other hand, functional.interpolate(mode='nearest') seems to use OpenCV or some other implementation, and in fact seems to give the correct resizing. I think so because torchvision’s detection transforms use this method both for the image (mode='bilinear', align_corners=False) and for the masks (mode='nearest').

So I understand neither the difference nor the correct use of transforms.Resize for masks, apart from the fact that transforms.Resize is applied to PIL Images and functional.interpolate to tensors…

There might be differences in the interpolation results between PIL’s implementation and the native PyTorch one (I haven’t checked it recently). To make sure the same algorithm is used, you could stick to one library for both the input and the mask.
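One way to compare the two on your own data is to resize the same integer label mask with PIL and with F.interpolate and count where the results differ. The sketch below uses a random, hypothetical 7x9 mask; note that PIL takes sizes as (width, height) while F.interpolate takes (height, width), and that align_corners may not be passed together with mode="nearest". Recent PyTorch releases also document a mode="nearest-exact" described as matching PIL’s nearest-neighbour behaviour; whether it exists depends on your version.

```python
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# Hypothetical label mask (class indices 0..4) of size 7x9, stored as uint8
rng = np.random.default_rng(0)
labels = rng.integers(0, 5, size=(7, 9), dtype=np.uint8)
out_h, out_w = 5, 4

# PIL: size is given as (width, height)
pil_resized = np.array(Image.fromarray(labels).resize((out_w, out_h), resample=Image.NEAREST))

# PyTorch: interpolate expects a float [N, C, H, W] tensor and size as (height, width)
t = torch.from_numpy(labels)[None, None].float()
torch_resized = F.interpolate(t, size=(out_h, out_w), mode="nearest")[0, 0].to(torch.uint8).numpy()

# Neither method blends labels, but they may pick different source pixels
print(set(np.unique(pil_resized)) <= set(np.unique(labels)))
print(set(np.unique(torch_resized)) <= set(np.unique(labels)))
print("pixels that differ:", int((pil_resized != torch_resized).sum()))
```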


Thanks. Does that mean that F.interpolate uses a different algorithm from PIL.Image.resize? I looked at the source code, but couldn’t identify one.

Does this also mean that if I don’t use masks (e.g. for classification), I can stick to transforms.Resize?

It could use a different algorithm, and you could compare the outputs of the different interpolation methods. If I remember correctly, at least OpenCV and PIL used different methods, and I guess that the internal PyTorch method might have used one of them.


Thanks. My main concern was related to masks, not images, as I got the impression that methods other than F.interpolate produced suspicious results, incl. transforms.Resize (also numpy.resize, PILImage.resize, etc.).

Which concern do you have about the mentioned methods? What do these suspicious results look like, or what makes them suspicious?


The other methods somehow returned other pixel values, e.g. [3,4] instead of [4]. It’s all good now.