Multiclass Segmentation

Hi, is there an example of creating a custom dataset and training a U-Net for multiclass segmentation? I have found many examples for binary segmentation but have yet to find one for multiclass segmentation. Thank you!

I assume you have already found suitable code snippets for a binary segmentation use case?
If so, you could use them as a baseline and make a few changes for a multi-class segmentation use case:

  • use nn.CrossEntropyLoss as your criterion
  • your model should output a tensor with the shape [batch_size, nb_classes, height, width]
  • the target should be a LongTensor with the shape [batch_size, height, width] and contain the class indices for each pixel location in the range [0, nb_classes-1]
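A minimal sketch of these shape and dtype requirements, using random tensors as a stand-in for a real model output and mask batch (the sizes are arbitrary). It also shows a common fix when masks arrive with a singleton channel dimension:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, nb_classes, height, width = 2, 3, 8, 8

# Stand-in for the model output: raw logits, one channel per class.
output = torch.randn(batch_size, nb_classes, height, width, requires_grad=True)

# Masks often arrive as [batch_size, 1, height, width] (and sometimes as floats);
# drop the channel dim and cast to long before passing them to the criterion.
raw_mask = torch.randint(0, nb_classes, (batch_size, 1, height, width)).float()
target = raw_mask.squeeze(1).long()  # [batch_size, height, width]

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)  # scalar, averaged over all pixels
loss.backward()
```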

Depending on the format of your segmentation mask images, you might need to create a mapping e.g. between color codes and the corresponding class indices.

Let us know if and where you get stuck.


Thank you very much for the reply. I have two follow-up questions:
1 - In binary segmentation, the output of the U-Net model goes through a softmax(). Will it be the same for multi-class?
2 - My target has 3 classes. So I should use 0 for background and 1, 2, 3 for the three classes at each pixel location, right?

  1. If you are using nn.BCELoss, the output should use torch.sigmoid as the activation function. Alternatively, don't use any activation function and pass the raw logits to nn.BCEWithLogitsLoss. If you use nn.CrossEntropyLoss for the multi-class segmentation, you should also pass the raw logits without any activation function, since nn.CrossEntropyLoss applies log_softmax internally.

  2. Yes, but then you should deal with 4 classes (background + 3 classes), so the output of your model should be [batch_size, 4, h, w].
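To make both points concrete, here is a small sketch with random logits standing in for the model output. It checks that sigmoid + nn.BCELoss matches nn.BCEWithLogitsLoss, that nn.CrossEntropyLoss on raw logits equals nn.NLLLoss on log_softmax, and shows how a [batch_size, 4, h, w] output is turned into a predicted mask:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Binary case, one output channel: sigmoid + BCELoss vs. raw logits + BCEWithLogitsLoss.
bin_logits = torch.randn(2, 1, 8, 8)
bin_target = torch.randint(0, 2, (2, 1, 8, 8)).float()
loss_sigmoid_bce = nn.BCELoss()(torch.sigmoid(bin_logits), bin_target)
loss_with_logits = nn.BCEWithLogitsLoss()(bin_logits, bin_target)  # more numerically stable

# Multi-class case: background + 3 classes = 4 output channels, raw logits.
nb_classes = 4
logits = torch.randn(2, nb_classes, 8, 8)
target = torch.randint(0, nb_classes, (2, 8, 8))
loss_ce = nn.CrossEntropyLoss()(logits, target)
# CrossEntropyLoss is NLLLoss applied to log_softmax over the class dimension.
loss_nll = nn.NLLLoss()(torch.log_softmax(logits, dim=1), target)

# At inference time, the predicted mask is the argmax over the class channel.
pred = logits.argmax(dim=1)  # [batch_size, h, w], values in [0, nb_classes-1]
```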


I have a follow-up question on this:
1) Do we need to apply any normalisation to the masks?

Following is my custom data loader for image segmentation:

import glob
import os

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image

class DataLoaderSegmentation(data.Dataset):
    def __init__(self,folder_path,transform = None):
        super(DataLoaderSegmentation, self).__init__()
        # sort both lists so that image i and mask i belong to the same sample
        self.img_files = sorted(glob.glob(os.path.join(folder_path,'images','*.tif')))
        self.mask_files = sorted(glob.glob(os.path.join(folder_path,'mask','*.bmp')))
        self.transform = transform
    def __getitem__(self, index):
        img_path = self.img_files[index]
        mask_path = self.mask_files[index]
        data = Image.open(img_path)
        label = Image.open(mask_path)
        if self.transform:
            data = self.transform(data)
        label = np.array(label)
        return data, torch.from_numpy(label).long()
    
    def __len__(self):
        return len(self.img_files)

Also, when I don't use the transforms.ToTensor normalisation but instead use torch.from_numpy(label).long(), the target comes out with the shape [batch_size, height, width, num_channels].
2) Why the difference, and how do I solve this problem?
3) Also, I didn't quite understand this line: "Depending on the format of your segmentation mask images, you might need to create a mapping e.g. between color codes and the corresponding class indices."

Most likely you should not apply any normalization to your segmentation masks, as this will distort the class indices.
Could you print the shape of label before passing it to torch.from_numpy?
I would expect the channel dimension to be the last one ([height, width, channels]) if your masks are RGB images loaded with PIL, while single-channel masks would not have a channel dimension at all.

Have a look at this post where I’ve explained the mapping a bit better.
Basically your targets should contain class indices in [0, nb_classes-1].
However, sometimes your segmentation images use a color code for certain classes, e.g. red could be a car and blue could be a tree. Using a mapping, you would have to transform these color codes to class indices, e.g. red->0, blue->1, …

PS: If you are resizing the mask images, make sure to use nearest neighbor interpolation, as other interpolation techniques might distort the labels/colors.
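A quick illustration of why nearest neighbor matters, using torch.nn.functional.interpolate on a hypothetical 4x4 index mask. With torchvision you would pass interpolation=InterpolationMode.NEAREST to the resize functions in recent versions.

```python
import torch
import torch.nn.functional as F

# Hypothetical 4x4 mask with class indices 0..3.
mask = torch.tensor([[0, 0, 1, 1],
                     [0, 0, 1, 1],
                     [2, 2, 3, 3],
                     [2, 2, 3, 3]])

# interpolate expects a float [N, C, H, W] tensor, so add two dims and cast.
m = mask[None, None].float()
nearest = F.interpolate(m, size=(8, 8), mode="nearest")[0, 0].long()
bilinear = F.interpolate(m, size=(8, 8), mode="bilinear", align_corners=False)[0, 0]

# nearest only contains the original indices 0..3,
# while bilinear produces non-integer values at class boundaries.
```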


Yes, when I print the shape of the image after converting it into a numpy array (reading it with both cv2 and PIL), the output shape is 512x512x3, which I corrected to 3x512x512 using np.transpose(imgs, (2, 0, 1)).

I got your point regarding the mapping, but how do I define the mapping for RGB masks? My masks have 4 classes: blue, red, green, and a black background.

I am not resizing the mask images, but I am randomly cropping the original image to a smaller size, and my final segmented output is the same size as the cropped input image. Is this the right way to do it, and are there any performance issues with it?

Have a look at this post to see, how to create a custom mapping.
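A minimal sketch of such a mapping for the four classes mentioned above. The exact RGB values are an assumption (pure black, blue, red, green) and have to match what your mask files actually contain. The dictionary is fixed, so every image uses the same color-to-index mapping, and unknown colors raise an error instead of silently producing out-of-range targets. The names COLOR_TO_CLASS and rgb_mask_to_class are hypothetical.

```python
import numpy as np
import torch

# Assumed palette: adjust the RGB values to what your masks really use.
COLOR_TO_CLASS = {
    (0, 0, 0): 0,      # black background
    (0, 0, 255): 1,    # blue
    (255, 0, 0): 2,    # red
    (0, 255, 0): 3,    # green
}

def rgb_mask_to_class(mask: np.ndarray) -> torch.Tensor:
    """Map an [H, W, 3] uint8 RGB mask to an [H, W] long tensor of class indices."""
    # Start with -1 so pixels that match no palette color are easy to detect.
    target = torch.full(mask.shape[:2], -1, dtype=torch.long)
    for color, class_idx in COLOR_TO_CLASS.items():
        # A pixel matches only if all three channels equal the palette color.
        matches = np.all(mask == np.array(color, dtype=mask.dtype), axis=-1)
        target[torch.from_numpy(matches)] = class_idx
    if (target < 0).any():
        raise ValueError("mask contains colors that are not in COLOR_TO_CLASS")
    return target

mask = np.zeros((2, 3, 3), dtype=np.uint8)
mask[0, 0] = (255, 0, 0)  # red pixel
mask[1, 2] = (0, 255, 0)  # green pixel
print(rgb_mask_to_class(mask))
# tensor([[2, 0, 0],
#         [0, 0, 3]])
```

If the mask files are stored with lossy compression, colors near class boundaries might not match exactly; in that case the ValueError is a useful signal.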

RandomCrop should work in the default setup without padding. However, you should make sure to use the same "random" locations for both your input and target, as shown here.

Okay, I followed the post and implemented the following dataset class, keeping in mind the same random locations for both input and target.
First, here are the transformations I need, written the way I would apply them if I didn't have to apply them to the image and mask simultaneously:

data_transforms = transforms.Compose([transforms.RandomCrop((512,512)),
                                 transforms.Lambda(gaussian_blur),
                                 transforms.Lambda(elastic_transform),
                                 transforms.RandomRotation([+90,+180]),
                                 transforms.RandomRotation([+180,+270]),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean=train_mean, std=train_std)
                               ])


Now, the thing is that I can apply RandomCrop, RandomRotation, and ToTensor in the class, but I have no idea how to apply gaussian_blur, elastic_transform, and Normalize with a custom mean and std.
Also, below is my dataset class:

class DataLoaderSegmentation(data.Dataset):
    def __init__(self,folder_path):
        super(DataLoaderSegmentation, self).__init__()
        self.img_files = glob.glob(os.path.join(folder_path,'images','*.tif'))
        self.mask_files = glob.glob(os.path.join(folder_path,'mask','*.bmp'))
  
    def mask_to_class(self,mask):
        target = torch.from_numpy(mask)
        h,w = target.shape[0],target.shape[1]
        masks = torch.empty(h, w, dtype=torch.long)
        colors = torch.unique(target.view(-1,target.size(2)),dim=0).numpy()
        target = target.permute(2, 0, 1).contiguous()
        mapping = {tuple(c): t for c, t in zip(colors.tolist(), range(len(colors)))}
        for k in mapping:
            idx = (target==torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
            validx = (idx.sum(0) == 3) 
            masks[validx] = torch.tensor(mapping[k], dtype=torch.long)
        return masks
    
    def transform(self,image,mask):
        i, j, h, w = transforms.RandomCrop.get_params(
        image, output_size=(512, 512))
        image = TF.crop(image, i, j, h, w)
        mask = TF.crop(mask, i, j, h, w)

        # Random horizontal flipping
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
        
        image = TF.rotate(image,90)
        mask = TF.rotate(mask,90)
        image = TF.rotate(image,180)
        mask = TF.rotate(mask,180)
        image = TF.rotate(image,270)
        mask = TF.rotate(mask,270)

        # Transform to tensor
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        return image, mask
    
    def __getitem__(self, index):
        img_path = self.img_files[index]
        mask_path = self.mask_files[index]
        data = Image.open(img_path)
        label = Image.open(mask_path)
        data,label = self.transform(data,label)
        label = np.array(label)
        mask = self.mask_to_class(label)
        return data,mask
           
    def __len__(self):
        return len(self.img_files)

I am getting the following error:

RuntimeError: Assertion `input0 == target0 && input2 == target1 && input3 == target2' failed. size mismatch (got input: 5x4x504x504, target: 5x584x565) at ../aten/src/THNN/generic/SpatialClassNLLCriterion.c:59

Since my label now has the shape [batch_size, height, width], does this imply I have to change something in my network too, or am I messing up somewhere else?

The number of dimensions looks alright. However, the spatial size is different.
While your data is [504 x 504], your target is [584 x 565], which will throw this error.
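One common source of a smaller output is using unpadded 3x3 convolutions, as in the original U-Net paper: each one removes 2 pixels per spatial dimension. A sketch with a hypothetical double_conv block shows the difference padding=1 makes. Since the target still seems to have its uncropped size, it may also be worth checking that the mask goes through the same crop as the image.

```python
import torch
import torch.nn as nn

def double_conv(in_ch, out_ch, padding):
    # Two 3x3 convolutions, as in a typical U-Net stage.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=padding),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=padding),
        nn.ReLU(inplace=True),
    )

x = torch.randn(1, 3, 64, 64)
valid_out = double_conv(3, 8, padding=0)(x)  # torch.Size([1, 8, 60, 60])
same_out = double_conv(3, 8, padding=1)(x)   # torch.Size([1, 8, 64, 64])
```

A cheap guard before computing the loss is assert output.shape[2:] == target.shape[1:].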

Also, be careful about using TF.to_tensor on your mask, as this might normalize the class indices into float values in the range [0, 1].


Thanks a lot, there was a problem in the architecture, which I resolved.
Yes, I removed TF.to_tensor on the mask, though I have yet to figure out how to apply selective transformations to it.
However, I am now encountering the following error when running the current network:

RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.  at/aten/src/THNN/generic/SpatialClassNLLCriterion.c:109

Now I am not able to figure out what I am doing wrong.

Make sure the target has the shape [N, H, W] and contains class indices in the range [0, nb_classes-1], while your output should have the shape [N, nb_classes, H, W].
Apparently some target tensor contains values outside of this range.
You could use a print statement while iterating your DataLoader to track down the problematic batch.
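Instead of plain prints, a small check called right before the criterion can report exactly what is wrong with a batch. In the dataset class above, one possible cause is that mask_to_class starts from torch.empty, so pixels whose color matches none of the keys keep uninitialized values. Another is that the mapping is rebuilt per image from torch.unique, so the same color can get different indices in different images. A minimal sketch (check_batch is a hypothetical helper):

```python
import torch

def check_batch(output, target, nb_classes):
    """Raise a descriptive error if a batch does not fit nn.CrossEntropyLoss."""
    if target.dtype != torch.long:
        raise TypeError(f"target dtype is {target.dtype}, expected torch.long")
    if output.dim() != 4 or target.dim() != 3:
        raise ValueError(f"expected output [N, C, H, W] and target [N, H, W], "
                         f"got {tuple(output.shape)} and {tuple(target.shape)}")
    if output.shape[1] != nb_classes:
        raise ValueError(f"output has {output.shape[1]} channels, expected {nb_classes}")
    if output.shape[0] != target.shape[0] or output.shape[2:] != target.shape[1:]:
        raise ValueError(f"shape mismatch: output {tuple(output.shape)}, "
                         f"target {tuple(target.shape)}")
    bad = (target < 0) | (target >= nb_classes)
    if bad.any():
        raise ValueError(f"{int(bad.sum())} target values outside [0, {nb_classes - 1}]: "
                         f"{target[bad].unique().tolist()}")

# Usage inside the training loop, e.g.:
#   output = model(data)
#   check_batch(output, target, nb_classes=4)
#   loss = criterion(output, target)
```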

from __future__ import print_function
from keras.preprocessing.image import ImageDataGenerator
import numpy as np
import os
import glob
import skimage.io as io
import skimage.transform as trans

def adjustData(img,mask,flag_multi_class,num_class):
    if(flag_multi_class):
        img = img / 255
        mask = mask[:,:,:,0] if(len(mask.shape) == 4) else mask[:,:,0]
        new_mask = np.zeros(mask.shape + (num_class,))
        for i in range(num_class):
            #for one pixel in the image, find the class in mask and convert it into one-hot vector
            #index = np.where(mask == i)
            #index_mask = (index[0],index[1],index[2],np.zeros(len(index[0]),dtype = np.int64) + i) if (len(mask.shape) == 4) else (index[0],index[1],np.zeros(len(index[0]),dtype = np.int64) + i)
            #new_mask[index_mask] = 1
            new_mask[mask == i,i] = 1
        new_mask = np.reshape(new_mask,(new_mask.shape[0],new_mask.shape[1]*new_mask.shape[2],new_mask.shape[3])) if flag_multi_class else np.reshape(new_mask,(new_mask.shape[0]*new_mask.shape[1],new_mask.shape[2]))
        mask = new_mask
    elif(np.max(img) > 1):
        img = img / 255
        mask = mask /255
        mask[mask > 0.5] = 1
        mask[mask <= 0.5] = 0
    return (img,mask)

def trainGenerator(batch_size,train_path,image_folder,mask_folder,aug_dict,image_color_mode = "grayscale",
                   mask_color_mode = "grayscale",image_save_prefix = "image",mask_save_prefix = "mask",
                   flag_multi_class = False,num_class = 2,save_to_dir = None,target_size = (256,256),seed = 1):
    '''
    can generate image and mask at the same time
    use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same
    if you want to visualize the results of generator, set save_to_dir = "your path"
    '''
    image_datagen = ImageDataGenerator(**aug_dict)
    mask_datagen = ImageDataGenerator(**aug_dict)
    image_generator = image_datagen.flow_from_directory(
        train_path,
        classes = [image_folder],
        class_mode = None,
        color_mode = image_color_mode,
        target_size = target_size,
        batch_size = batch_size,
        save_to_dir = save_to_dir,
        save_prefix = image_save_prefix,
        seed = seed)
    mask_generator = mask_datagen.flow_from_directory(
        train_path,
        classes = [mask_folder],
        class_mode = None,
        color_mode = mask_color_mode,
        target_size = target_size,
        batch_size = batch_size,
        save_to_dir = save_to_dir,
        save_prefix = mask_save_prefix,
        seed = seed)
    train_generator = zip(image_generator, mask_generator)
    for (img,mask) in train_generator:
        img,mask = adjustData(img,mask,flag_multi_class,num_class)
        yield (img,mask)

def valGenerator(batch_size,train_path,image_folder,mask_folder,image_color_mode = "grayscale",
                 mask_color_mode = "grayscale",image_save_prefix = "image",mask_save_prefix = "mask",
                 flag_multi_class = False,num_class = 2,save_to_dir = None,target_size = (256,256),seed = 1):
    '''
    can generate image and mask at the same time
    use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same
    if you want to visualize the results of generator, set save_to_dir = "your path"
    '''
    image_datagen = ImageDataGenerator()
    mask_datagen = ImageDataGenerator()
    image_generator = image_datagen.flow_from_directory(
        train_path,
        classes = [image_folder],
        class_mode = None,
        color_mode = image_color_mode,
        target_size = target_size,
        batch_size = batch_size,
        save_to_dir = save_to_dir,
        save_prefix = image_save_prefix,
        seed = seed)
    mask_generator = mask_datagen.flow_from_directory(
        train_path,
        classes = [mask_folder],
        class_mode = None,
        color_mode = mask_color_mode,
        target_size = target_size,
        batch_size = batch_size,
        save_to_dir = save_to_dir,
        save_prefix = mask_save_prefix,
        seed = seed)
    train_generator = zip(image_generator, mask_generator)
    for (img,mask) in train_generator:
        img,mask = adjustData(img,mask,flag_multi_class,num_class)
        yield (img,mask)

def testGenerator(test_path,num_image = 18,target_size = (256,256),flag_multi_class = False,as_gray = True):
    for i in test_path:
        img = io.imread(i, as_gray = as_gray)
        print(i)
        img = img / 255
        img = trans.resize(img,target_size)
        img = np.reshape(img,img.shape+(1,)) if (not flag_multi_class) else img
        img = np.reshape(img,(1,)+img.shape)
        yield img

def testGenerator2(batch_size,test_path,target_size = (256,256), image_color_mode = "grayscale", num_class = 2,flag_multi_class = False,as_gray = True):
    image_datagen = ImageDataGenerator()
    image_generator = image_datagen.flow_from_directory(
        test_path,
        class_mode = None,
        color_mode = image_color_mode,
        target_size = target_size,
        batch_size = batch_size
    )
    return image_generator

def geneTrainNpy(image_path,mask_path,flag_multi_class = False,num_class = 2,image_prefix = "image",mask_prefix = "mask",image_as_gray = True,mask_as_gray = True):
    image_name_arr = glob.glob(os.path.join(image_path,"%s*.png"%image_prefix))
    image_arr = []
    mask_arr = []
    for index,item in enumerate(image_name_arr):
        img = io.imread(item,as_gray = image_as_gray)
        img = np.reshape(img,img.shape + (1,)) if image_as_gray else img
        mask = io.imread(item.replace(image_path,mask_path).replace(image_prefix,mask_prefix),as_gray = mask_as_gray)
        mask = np.reshape(mask,mask.shape + (1,)) if mask_as_gray else mask
        img,mask = adjustData(img,mask,flag_multi_class,num_class)
        image_arr.append(img)
        mask_arr.append(mask)
    image_arr = np.array(image_arr)
    mask_arr = np.array(mask_arr)
    return image_arr,mask_arr

def labelVisualize(num_class,color_dict,img):
    img = img[:,:,0] if len(img.shape) == 3 else img
    img_out = np.zeros(img.shape + (3,))
    for i in range(num_class):
        img_out[img == i,:] = color_dict[i]
    return img_out

def labelVisualizeBinary(img):
    img[img > 0.5] = 1
    img[img <= 0.5] = 0
    return img

def saveResult(save_path,results_filename,npyfile,flag_multi_class = False,num_class = 2):
    print(results_filename[0])
    for i,item in enumerate(npyfile):
        # COLOR_DICT (class index -> RGB color) is not defined in this snippet and must be defined elsewhere
        img = labelVisualize(num_class,COLOR_DICT,item) if flag_multi_class else item[:,:,0]
        print(i)
        io.imsave(os.path.join(save_path,results_filename[i]),img)

My problem:

This is my data code, and I have a problem changing it to multiclass segmentation. I divided the code into 3 parts: one for the data (to transform it for the number of classes), one for training (epochs, path_to_train, etc.), and one for the model (U-Net).

In my model I changed the loss function to sparse_categorical_crossentropy and the last layer to softmax. I know my problem is in the data code (how to transform it for multiple classes). I have 3 outputs (black, white, red), which means my num_classes = 3, but I don't know how to change the code. I just changed num_classes, but nothing happened. Maybe you can solve my problem. Thank you :slight_smile: (code above)

It seems you are using Keras, not PyTorch, so I think you will get a better answer on StackOverflow.
I’m not familiar enough with the data loading in Keras. :confused:

Okay, sir, do you have complete code for multiclass segmentation (data, model, train)? Could you share it with me? Thank you, sir :slight_smile:

Could you please explain why there are 4 channels (classes) in the output? If we have 3 output channels and all 3 are empty (meaning no segmentation), we can automatically treat that case as background, right? So it would work with 3 channels, i.e. the output could be [batch_size, 3, h, w] for 3 classes + a BG class. Please correct me if I am wrong.

This wouldn’t work for a multi-class segmentation with nn.CrossEntropyLoss or nn.NLLLoss, since the target has to contain a class index.
With three classes, the class indices would be [0, 1, 2], and exactly one of these classes would be active for each pixel.

If you consider your use case a multi-label segmentation, you could use nn.BCEWithLogitsLoss, which would allow you to define zero, one, or more active classes for each pixel.
This approach would work for 3 explicit channels, as no active class could be considered the “background class”. However, since each pixel could also have all classes set to active, this would be a multi-label segmentation.
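A short sketch contrasting the two setups with random tensors as stand-ins for model outputs and masks: multi-class uses 4 output channels with one class index per pixel, while multi-label uses 3 output channels with a multi-hot target, where an all-zero pixel acts as "background" but several active classes are also allowed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n, h, w = 2, 4, 4

# Multi-class: 3 classes + background = 4 channels, one index per pixel (0 = background).
mc_logits = torch.randn(n, 4, h, w)
mc_target = torch.randint(0, 4, (n, h, w))
mc_loss = nn.CrossEntropyLoss()(mc_logits, mc_target)
mc_pred = mc_logits.argmax(dim=1)                  # [n, h, w], exactly one class per pixel

# Multi-label: 3 channels, multi-hot float target of the same shape as the logits.
ml_logits = torch.randn(n, 3, h, w)
ml_target = torch.randint(0, 2, (n, 3, h, w)).float()
ml_loss = nn.BCEWithLogitsLoss()(ml_logits, ml_target)
ml_pred = torch.sigmoid(ml_logits) > 0.5           # [n, 3, h, w], zero or more classes per pixel
background = ~ml_pred.any(dim=1)                   # pixels with no active class
```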