Multi-class segmentation

Hello. I am working on multi-class segmentation.
Until now I have only handled the binary case of semantic segmentation.
In the binary case, I use a binary mask as the target.
However, in the multi-class case, it looks like I need to change something.

This is my mask.
I have 5 classes: red, green, blue, white, and black.
My model output has 5 channels.
I convert the 3-channel color mask with these 5 classes into a 1-channel image using the code below.


After that, I run the training loop below.
I read some articles that said I have to use one-hot encoding for the cross-entropy loss.

        self.criterion = nn.CrossEntropyLoss().to(self.device)

    def run_train_loop(self, epochs):
        # Run training
        for epoch in range(epochs):
            print('Epoch {}/{}'.format(epoch + 1, epochs))
            print('-' * 10)
            for param_group in self.optimizer.param_groups:
                lr = param_group['lr']
            print("Learning rate: " + str(lr))
            running_loss = 0.0
            running_score = 0.0

            for batch_idx, (images, masks) in enumerate(self.train_data_loader):
                # Obtain batches
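                # Assumed completion: the original assignment was truncated
                images, masks = images.to(self.device), masks.to(self.device)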
                output = model(images)
                loss = self.criterion(output, masks.long())

However, my loss drops to 0 after a few epochs. I have 2000 images and masks, so I do not think the problem is a lack of data. I am not using a softmax function, but I read in some answers on this forum that it is already included in the cross-entropy loss. If not, I will have to add it, of course.
I am also not sure whether it is correct for the ground truth to have only 1 channel, because the output of the model has 5 channels. Do you think I should change my mask to 5 channels of 0s and 1s? Or should my model output have 5 channels, given that 0 is the background?

I am also not sure what a class mapping is. Does it just decode the segmentation output for visualization? Or how can I feed the class information into the training process?
I am a little bit confused about it.
Thank you.

No, you shouldn't use one-hot encoded target tensors for nn.CrossEntropyLoss, as it expects a model output in the shape [batch_size, nb_classes, height, width] and a target tensor in the shape [batch_size, height, width] containing the class indices in the range [0, nb_classes-1] for a multi-class segmentation use case.

For 4 valid classes + a background class, the number of classes should be 5.
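The shapes described above can be checked with a minimal sketch. The tensor sizes are arbitrary illustrative values, and the random tensors stand in for a real model output and a real mask. The last line also shows that nn.CrossEntropyLoss applies log-softmax internally, so the model should return raw logits without an extra softmax layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, nb_classes, height, width = 2, 5, 4, 4  # illustrative sizes

# Raw model output (logits), one channel per class:
# [batch_size, nb_classes, height, width]
output = torch.randn(batch_size, nb_classes, height, width)

# Target with one class index per pixel (not one-hot):
# [batch_size, height, width], dtype long, values in [0, nb_classes-1]
target = torch.randint(0, nb_classes, (batch_size, height, width))

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)

# Same value computed by hand: log-softmax over the class dimension,
# followed by the negative log-likelihood loss.
manual_loss = F.nll_loss(F.log_softmax(output, dim=1), target)
```

Both loss and manual_loss are scalar tensors holding the same value.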

A class mapping might be a method which maps the different discrete color codes from the target image to a class index, e.g. red -> class0, green -> class1, blue -> class2, pink -> class3, …
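As a rough sketch of such a class mapping, the snippet below converts an RGB mask into a class-index mask with NumPy. The COLOR_TO_CLASS dictionary and the function name rgb_mask_to_class_indices are hypothetical; the exact colors and their order are assumptions and need to match the colors actually stored in your masks.

```python
import numpy as np

# Hypothetical mapping from RGB color to class index; adjust to your masks.
COLOR_TO_CLASS = {
    (0, 0, 0): 0,        # black (background)
    (255, 0, 0): 1,      # red
    (0, 255, 0): 2,      # green
    (0, 0, 255): 3,      # blue
    (255, 255, 255): 4,  # white
}

def rgb_mask_to_class_indices(mask_rgb):
    """Convert an [H, W, 3] uint8 color mask to an [H, W] int64 index mask."""
    class_mask = np.full(mask_rgb.shape[:2], -1, dtype=np.int64)
    for color, class_idx in COLOR_TO_CLASS.items():
        # True wherever all three channels match this color exactly
        matches = np.all(mask_rgb == np.array(color, dtype=mask_rgb.dtype), axis=-1)
        class_mask[matches] = class_idx
    if (class_mask < 0).any():
        raise ValueError("mask contains a color that is not in COLOR_TO_CLASS")
    return class_mask

# Tiny 2x2 example mask: red, green / black, white
mask_rgb = np.array([[[255, 0, 0], [0, 255, 0]],
                     [[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)
class_mask = rgb_mask_to_class_indices(mask_rgb)
```

The resulting array can be passed to torch.from_numpy to get a long tensor usable as the target. Exact color matching assumes the masks were saved losslessly; lossy compression or interpolation during resizing can introduce colors that are not in the mapping, which is why the sketch raises an error in that case.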


Thank you for your answer, and sorry for the late response. The notification only popped up after 6 days.
I do not know why.
As for the first question: I used code like the one for the CamVid dataset and solved the issue.

def get_class_weights(num_classes, c=1.02):
    # loader is the data loading helper from the example code (defined elsewhere)
    pipe = loader('./datasets/CamVid/train1/', './datasets/CamVid/trainannot1/', batch_size='all')
    _, labels = next(pipe)
    all_labels = labels.flatten()
    # Pixel count per class, then the fraction of all pixels in each class
    each_class = np.bincount(all_labels, minlength=num_classes)
    propensity_score = each_class / len(all_labels)
    class_weights = 1 / (np.log(c + propensity_score))
    return class_weights

class_weights = get_class_weights(num_classes)
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weights).to(device))

I hope this is correct. It also seems to work reasonably well.
Thank you.
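To illustrate how this weighting behaves, here is a small worked example of the formula 1 / ln(c + p_k), where p_k is the fraction of pixels belonging to class k. The function name class_weights_from_labels and the toy label array are made up for illustration; only the formula itself is taken from the code above.

```python
import numpy as np

def class_weights_from_labels(all_labels, num_classes, c=1.02):
    """Weights 1 / ln(c + p_k), with p_k the fraction of pixels in class k."""
    counts = np.bincount(all_labels.ravel(), minlength=num_classes)
    frequency = counts / counts.sum()
    return 1.0 / np.log(c + frequency)

# Toy labels: class 0 covers 90% of the pixels, class 1 covers 10%,
# and class 2 never occurs.
labels = np.array([0] * 90 + [1] * 10)
weights = class_weights_from_labels(labels, num_classes=3)
```

Rare classes get larger weights than frequent ones, and the constant c caps the largest possible weight at 1 / ln(c), which is about 50.5 for c = 1.02 and is reached by a class that never occurs.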

It seems as if the class_weights might only be calculated from a single batch?
If that’s the case, the weights might be quite off and you should try to grab at least a few batches to approximate the valid class distribution.

Thank you for pointing that out. I had not thought about that problem. I followed the example from an ENet implementation.

If you do not mind, how can I change my code to do what you suggest?
For example, if I use a batch size of 4 on 2 GPUs, do I have to concatenate the 2 label batches and then calculate the class weights?

Assuming pipe is a DataLoader object, you could iterate it once and collect all targets via:

targets = []
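for _, target in pipe:
    targets.append(target)
# torch.stack assumes every batch has the same shape;
# torch.cat(targets) also handles a smaller last batch
targets = torch.stack(targets)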

and calculate the class distribution later.

I hope that the target tensors are not too big to fit into your RAM.
If they are, you might need to add a break statement after a certain number of targets.
As long as you shuffle the targets and grab “enough” of them, the estimated class distribution should come close to the real distribution.
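A possible variation on the loop above, sketched under some assumptions: instead of storing every target tensor, it accumulates per-class pixel counts with torch.bincount, so memory use stays constant. The random TensorDataset is only a stand-in for the real data, and max_batches is a hypothetical cap that plays the role of the break statement.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

num_classes = 5

# Stand-in dataset (assumption): 20 random images with 8x8 class-index masks.
images = torch.randn(20, 3, 8, 8)
masks = torch.randint(0, num_classes, (20, 8, 8))
pipe = DataLoader(TensorDataset(images, masks), batch_size=4, shuffle=True)

max_batches = 3  # hypothetical cap; set to None to iterate over every batch
counts = torch.zeros(num_classes, dtype=torch.long)
for batch_idx, (_, target) in enumerate(pipe):
    if max_batches is not None and batch_idx >= max_batches:
        break
    # Accumulate pixel counts per class instead of keeping the targets around
    counts += torch.bincount(target.flatten(), minlength=num_classes)

# Estimated fraction of pixels per class
class_distribution = counts.float() / counts.sum()
```

The resulting class_distribution can then be plugged into whichever weighting formula you use, for example the 1 / ln(c + p) scheme from the earlier post.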

Oh, now I understand what you mean. It really needs to loop over all batches.
Thank you for your answer.