Multiclass semantic segmentation with RGB masks steps clarification


I am really confused about the steps for multi-class semantic segmentation.

I have

  • 1000 Images
  • 6 classes of ground truth RGB masks as PNGs

I am using nn.CrossEntropyLoss as my loss function.
Evaluation metric is IoU

My first approach was to

  1. Convert the original ground truth masks to one-hot encoding based on pixel colour
  2. Convert the images and masks to tensors with compose.ToTensor()
  3. Train and use the IoU metric for evaluation

Then I learnt that nn.CrossEntropyLoss is supposed to take categorical class-index values such as 1, 2, 3, 4, 5, 6.

My current approach

  1. Convert the original ground truth masks to class categories, so that each pixel belongs to a specific class
  2. Convert the class-categorised masks to tensors using torch.tensor()
  3. Convert the images to tensors using compose.ToTensor()
  4. NB: Note the different functions used to convert to tensors.
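Step 1 of this approach could look roughly like the sketch below. The palette is hypothetical (the thread does not list the actual mask colours), and the helper name `rgb_mask_to_class_index` is made up for illustration. It assumes the mask has already been loaded as a uint8 tensor of shape [H, W, 3].

```python
import torch

# Hypothetical palette: row i holds the RGB colour of class i.
# Replace these with the colours actually used in your masks.
PALETTE = torch.tensor(
    [
        [0, 0, 0],
        [255, 0, 0],
        [0, 255, 0],
        [0, 0, 255],
        [255, 255, 0],
        [255, 0, 255],
    ],
    dtype=torch.uint8,
)


def rgb_mask_to_class_index(mask_rgb, palette=PALETTE):
    """Map a uint8 RGB mask of shape [H, W, 3] to a long tensor of shape [H, W]."""
    # Compare every pixel with every palette colour: [H, W, 1, 3] vs [1, 1, C, 3].
    matches = (mask_rgb[:, :, None, :] == palette[None, None, :, :]).all(dim=-1)  # [H, W, C]
    if not bool(matches.any(dim=-1).all()):
        raise ValueError("mask contains a colour that is not in the palette")
    # Palette colours are unique, so one entry per pixel is True;
    # argmax returns the position of that entry, i.e. the class index.
    return matches.long().argmax(dim=-1)


# Tiny 2x2 example mask.
mask = torch.tensor(
    [[[0, 0, 0], [255, 0, 0]],
     [[0, 0, 255], [255, 0, 255]]],
    dtype=torch.uint8,
)
target = rgb_mask_to_class_index(mask)
print(target.shape, target.dtype)
```

The result has no class dimension and uses dtype long, with indices starting at 0 rather than 1.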

Issues & questions

  1. However, I seem to lose the class dimension. Why is that?
  2. Metric to use: I see that IoU takes pixel-wise inputs. Do I need to convert the predicted class indices and the ground-truth class-index masks back to their original colour codes before evaluating?
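On the second question, IoU can be computed directly on class-index maps, so converting back to colours is not required in general. The sketch below is a generic per-class IoU written for illustration; `per_class_iou` is a hypothetical helper and does not describe the input format expected by the metric classes of any particular library.

```python
import torch


def per_class_iou(pred, target, num_classes):
    """IoU per class for two index maps of identical shape, e.g. [N, H, W]."""
    ious = []
    for c in range(num_classes):
        pred_c = pred == c
        target_c = target == c
        intersection = (pred_c & target_c).sum().item()
        union = (pred_c | target_c).sum().item()
        # A class absent from both maps has no defined IoU; mark it as NaN.
        ious.append(intersection / union if union > 0 else float("nan"))
    return torch.tensor(ious)


pred = torch.tensor([[0, 1], [1, 2]])
target = torch.tensor([[0, 1], [2, 2]])
iou = per_class_iou(pred, target, num_classes=3)
# Average only over classes that actually occur.
mean_iou = iou[~iou.isnan()].mean()
print(iou, mean_iou)
```

Here `pred` would typically come from taking the argmax of the model output over the class dimension.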

I assume you are wondering about the shape of the target which should be [batch_size, height, width] (note the missing “class” dimension) containing class indices in the range [0, nb_classes-1]?
If so, then note that the class dimension would be useless since you are already using indices to represent the class.
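A minimal shape check of this, using random tensors as stand-ins for the model output and the masks (the sizes are arbitrary assumptions), might look like:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, nb_classes, height, width = 2, 6, 4, 4

# Stand-in for the model output: raw logits WITH a class dimension.
logits = torch.randn(batch_size, nb_classes, height, width)
# Target: one class index per pixel, dtype long, WITHOUT a class dimension.
target = torch.randint(0, nb_classes, (batch_size, height, width))

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)
print(logits.shape, target.shape, loss.shape)

# The index map carries the same information as a one-hot mask:
# taking argmax over the class dimension of the one-hot version recovers it.
one_hot = F.one_hot(target, nb_classes).permute(0, 3, 1, 2)  # [N, C, H, W]
recovered = one_hot.argmax(dim=1)                            # [N, H, W]
```

The last two lines illustrate why the class dimension is redundant in the target: the index already says which channel would be "hot".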

So do I need to do something like this?

for x, y in iterator:
  x, y =,
  y = torch.unsqueeze(y, x.shape[1])  # not sure if this is correct?

Adopted from

class TrainEpoch(Epoch):

    def __init__(self, model, loss, metrics, optimizer, logger, device='cpu', verbose=True, writer=None ):
        self.writer = writer
        self.optimizer = optimizer
        self.logger = logger
        self.log_on_start = True

    def on_epoch_start(self):

    def batch_update(self, x, y):
        print('Shape',  x.shape, y.shape)

        prediction = self.model.forward(x)

        if isinstance(prediction, dict):
            prediction = prediction['out']

            if self.log_on_start:
                Warning("prediction is a dictionary, using 'out' key")
                self.logger.warning("prediction is a dictionary, using 'out' key")
                self.log_on_start = False
        print(prediction.shape, y.shape)
        loss = self.loss(prediction, y)
        return loss, prediction

I get some errors like this

for x, y in iterator:
    x, y =,
    print('Shape iterator',  x.shape, y.shape, torch.unsqueeze(y, dim=x.shape[1]).shape, x.shape[1])

    # Add number of channels to y
    # Get the number of channels from tensor x
    num_channels = x.size(1)

    # If y lacks the channel dimension, add it with the same number of channels as x
    y = y.unsqueeze(1).expand(-1, num_channels, -1, -1)

Shape iterator torch.Size([2, 3, 512, 512]) torch.Size([2, 512, 512]) torch.Size([2, 512, 512, 1]) 3
new shape torch.Size([2, 3, 512, 512]) torch.Size([2, 3, 512, 512])
Shape Train torch.Size([2, 3, 512, 512]) torch.Size([2, 3, 512, 512])
prediction shape torch.Size([2, 8, 512, 512]) torch.Size([2, 3, 512, 512])

train Epoch 0:   0%|          | 0/35 [00:04<?, ?it/s]
Traceback (most recent call last):
  File "/kristina/dev/training/UNet/", line 96, in <module>
    train_logs =, epoch=i)
  File "/kristina/dev/training/UNet/../smp/utils/", line 65, in run
    loss, y_pred = self.batch_update(x, y)
  File "/kristina/dev//-training/UNet/../smp/utils/", line 139, in batch_update
    loss = self.loss(prediction, y)
  File "/miniconda3/envs/conda_env/lib/python3.10/site-packages/torch/nn/modules/", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/miniconda3/envs/conda_env/lib/python3.10/site-packages/torch/nn/modules/", line 1174, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/home/kristina/miniconda3/envs/conda_env/lib/python3.10/site-packages/torch/nn/", line 3029, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size: : [2, 3, 512, 512]


The error message prints the invalid target shape so check my previous post to see which shape is expected.

I see! I think I misunderstood what I have to do. I guess I should not have out_classes in my output? So, for:

# Final Convolution UNET
self.conv_last = nn.Conv2d(64, out_classes, kernel_size=1)

# Apply softmax activation along the 'out_classes' dimension
prediction_probs = F.softmax(prediction, dim=1)

# Take argmax along the 'out_classes' dimension to get the final class predictions
predicted_classes = torch.argmax(prediction_probs, dim=1)

loss = self.loss(predicted_classes, y.float())
print('loss', loss)

Then I get

RuntimeError: "host_softmax" not implemented for 'Long'
srun: error: gcn4: task 0: Exited with exit code 1
srun: Terminating StepId=3138493.0

Thanks for the help in advance!

and this code
loss = self.loss(predicted_classes.float(), y.float())

gives me this error

    loss, y_pred = self.batch_update(x, y)
  File "/dev/UNet/../smp/utils/", line 141, in batch_update
  File "/ miniconda3/envs/conda_env/lib/python3.10/site-packages/torch/", line 487, in 
  File "/ miniconda3/envs/conda_env/lib/python3.10/site-packages/torch/autograd/", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
srun: error: gcn4: task 0: Exited with exit code

No, the model output is expected to contain raw logits (so remove the F.softmax) in the shape [batch_size, nb_classes, height, width] while the target should contain class indices in the range [0, nb_classes-1] and the shape [batch_size, height, width]. The previous error pointed to the target, not the model output.
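As a rough sketch of that advice, the snippet below passes raw logits to nn.CrossEntropyLoss and keeps argmax for predictions and metrics only. The single 1x1 convolution is just a stand-in for a real segmentation network, and all sizes are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

nb_classes = 6
# Stand-in for a segmentation model ending in a 1x1 convolution.
model = nn.Conv2d(3, nb_classes, kernel_size=1)
criterion = nn.CrossEntropyLoss()

x = torch.randn(2, 3, 8, 8)                   # images: [batch_size, 3, height, width]
y = torch.randint(0, nb_classes, (2, 8, 8))   # class indices: [batch_size, height, width]

logits = model(x)            # raw logits: [batch_size, nb_classes, height, width]
loss = criterion(logits, y)  # the loss applies log-softmax internally
loss.backward()              # gradients flow back to the model parameters

# argmax is not differentiable, so its result is detached from the graph.
# Use it for predictions and metrics, never as the input to the loss.
with torch.no_grad():
    predicted_classes = logits.argmax(dim=1)  # [batch_size, height, width]

print(logits.argmax(dim=1).requires_grad)
```

The final print shows that the argmax output does not require grad, which is consistent with the "does not require grad and does not have a grad_fn" error seen when the loss was computed on `predicted_classes`.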