Multiclass segmentation U-net masks format

I’m quite new to PyTorch, so I might say things that are not completely correct.
Out of curiosity, I’m trying to implement multiclass segmentation using the U-Net code found here. Since I have my own dataset of images, I’ve modified the class “SimDataset” to load my own images and masks in the following way:

The images are 224x224 RGB PNG images and the masks are indexed images with four classes (indices 1, 2, 3 are features and 0 is background).

import PIL.Image
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models

class SimDataset(Dataset):
    def __init__(self, image_paths, mask_paths, count, transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths

    def transforms(self, image, mask):
        #img = img.resize((wsize, baseheight), PIL.Image.ANTIALIAS)
        #image = transforms.Resize(size=(64, 64))(image)
        #mask = transforms.Resize(size=(64, 64))(mask)
        image = image.resize((64, 64), PIL.Image.NEAREST)
        mask = mask.resize((64, 64), PIL.Image.NEAREST)
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        return [image, mask]

    def __getitem__(self, index):
        # Open the image and its mask with PIL
        image = PIL.Image.open(self.image_paths[index])
        mask = PIL.Image.open(self.mask_paths[index])

        x, y = self.transforms(image, mask)
        return [x, y]

    def __len__(self):
        return len(self.image_paths)

trans = transforms.Compose([
    transforms.ToTensor(), transforms.RandomHorizontalFlip()
])

train_set = SimDataset(train_paths, train_masks_paths, 100, transform=trans)
val_set = SimDataset(train_paths, train_masks_paths, 11, transform=trans)

image_datasets = {
    'train': train_set, 'val': val_set
}

batch_size = 1

dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)
}

dataset_sizes = {
    x: len(image_datasets[x]) for x in image_datasets.keys()
}


I’ve also been reading here that passing the mask through ToTensor() is not a good habit, since it rescales the class index values (and indeed it does), leading to instabilities. I would like to ask how to pass the masks correctly during training, if possible using the code snippet above.

Also, since I want three classes to be recognized, is the number of classes to train nb_class = 3 (n_total_index - 1, i.e. indices 0, 1, 2, 3 minus the background)?

Thanks in advance!

Instead of to_tensor you could create the mask tensor via torch.from_numpy().
For 3 classes and a background class, I would assume your model should use nb_class = 4, which would most likely correspond to the number of output units.
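A minimal sketch of that suggestion, assuming the mask is a single-channel image whose pixel values are already the class indices. The helper name mask_to_tensor and the synthetic mask are hypothetical and only stand in for a real indexed PNG:

```python
import numpy as np
import torch
from PIL import Image

def mask_to_tensor(mask_img, size=(64, 64)):
    """Convert a single-channel PIL mask into a LongTensor of class indices."""
    # NEAREST copies existing pixel values, so no new (blended) indices appear.
    mask_img = mask_img.resize(size, Image.NEAREST)
    # np.array on a single-channel image returns the raw pixel values as uint8.
    mask_np = np.array(mask_img, dtype=np.uint8)
    # from_numpy keeps the values untouched (no division by 255);
    # .long() gives the int64 dtype that nn.CrossEntropyLoss expects.
    return torch.from_numpy(mask_np).long()

# Synthetic 224x224 mask with indices 0..3 (hypothetical data; a mode "L"
# image is used here in place of a palette PNG loaded with Image.open).
raw = np.zeros((224, 224), dtype=np.uint8)
raw[:112, 112:] = 1
raw[112:, :112] = 2
raw[112:, 112:] = 3
demo_mask = Image.fromarray(raw)

mask = mask_to_tensor(demo_mask)
print(mask.shape, mask.dtype, mask.unique())
```

Unlike to_tensor, this path never rescales the values, so index 3 stays 3 rather than becoming 3/255.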

Hi, Diauro,
I use transforms.ToTensor() in my custom dataset implementation and have not run into any problems yet.

Thanks for the tips! Now I actually have a doubt about how the masks are loaded in the code I’m using. Since I have four classes in the masks, resizing with PIL.Image.NEAREST basically gives me a 1x64x64 image, while, if I’m not mistaken, it should be a tensor like 1 x nb_class x 64 x 64. Or am I wrong?

I was looking here pytorch_unet at the line:

import numpy as np
import torchvision.utils

def reverse_transform(inp):
    inp = inp.numpy().transpose((1, 2, 0))
    inp = np.clip(inp, 0, 1)
    inp = (inp * 255).astype(np.uint8)
    return inp

# Get a batch of training data
inputs, masks = next(iter(dataloaders['train']))

print(inputs.shape, masks.shape)
for x in [inputs.numpy(), masks.numpy()]:
    print(x.min(), x.max(), x.mean(), x.std())


where basically the output is 25x6x192x192. So I guess the way I’m loading the masks is wrong. Any idea?

Hi songyuc,
Thanks for your feedback! Do you use multiclass segmentation by chance? If you could check the reply I gave to ptrblck and tell me how you deal with mask loading and manipulation, that would be great. Sorry, I’m quite new to PyTorch, so I might ask some very basic questions.


If you are using a multi-class segmentation use case and therefore nn.CrossEntropyLoss or nn.NLLLoss, your mask should not contain a channel dimension, but instead contain the class indices in the shape [batch_size, height, width].
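As a rough illustration of those shapes, here is a small sketch with random tensors in place of a real model and dataset (nb_class = 4 and the 64x64 size are taken from the thread, the batch size is arbitrary):

```python
import torch
import torch.nn as nn

batch_size, nb_class, height, width = 2, 4, 64, 64

# Model output: raw logits with one channel per class (background included).
logits = torch.randn(batch_size, nb_class, height, width)

# Target: class indices in [0, nb_class), no channel dimension, dtype int64.
target = torch.randint(0, nb_class, (batch_size, height, width))

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)

# If the dataset returns masks as [batch_size, 1, H, W], drop the channel
# dimension before computing the loss.
target_with_channel = target.unsqueeze(1)
loss_squeezed = criterion(logits, target_with_channel.squeeze(1))

print(logits.shape, target.shape, loss.item())
```

The model still outputs nb_class channels; only the target omits the channel dimension.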

PIL.NEAREST is a valid option, as it won’t distort your color codes or class indices.
Since you are loading the image via PIL, I assume your mask is an RGB image, where each color represents a class?

If so, you should map the color values to the corresponding class index.
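One possible way to do such a mapping is sketched below. The colour table COLOR_TO_CLASS and the helper rgb_mask_to_class are made up for illustration and would need to match the colours actually used in the masks:

```python
import numpy as np
import torch

# Hypothetical colour coding: RGB triplet -> class index.
COLOR_TO_CLASS = {
    (0, 0, 0): 0,      # background
    (255, 0, 0): 1,
    (0, 255, 0): 2,
    (0, 0, 255): 3,
}

def rgb_mask_to_class(mask_rgb):
    """Map a uint8 RGB mask of shape [H, W, 3] to a LongTensor [H, W]."""
    # Pixels whose colour is not in the table stay at index 0.
    class_map = np.zeros(mask_rgb.shape[:2], dtype=np.int64)
    for color, idx in COLOR_TO_CLASS.items():
        # True where all three channels match this colour.
        matches = np.all(mask_rgb == np.array(color, dtype=np.uint8), axis=-1)
        class_map[matches] = idx
    return torch.from_numpy(class_map)

# Tiny 2x2 example mask.
demo = np.zeros((2, 2, 3), dtype=np.uint8)
demo[0, 1] = (255, 0, 0)
demo[1, 0] = (0, 255, 0)
demo[1, 1] = (0, 0, 255)

classes = rgb_mask_to_class(demo)
print(classes)
```

With a real file, the array would come from something like np.array(Image.open(path).convert("RGB")), resized with NEAREST beforehand so that no blended colours appear.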

Let me know, if that would work for you or if you need more information. :slight_smile:


Yes, I am learning about semantic segmentation, too.
I read your and ptrblck’s replies,
and I think there might be an error, since you said you wanted a 1x64x64 image but got a 25x6x192x192 tensor.
I think ptrblck’s suggestion is good. Do the mask loading as:

  1. Reading the image file;
  2. Mapping the pixel value into the train class ID.

I saw an example here (How make customised dataset for semantic segmentation?) of how to build a dataset, but I must admit that I still have doubts.

Basically I have PNG images that are already labelled with index 0, 1, 2, 3 23. It’s a custom palette that I made, so technically I should be able to skip the mapping used in the example, or am I wrong?

self.mapping = {
    0: 0,
    255: 1
}

def mask_to_class(self, mask):
    for k in self.mapping:
        mask[mask==k] = self.mapping[k]
    return mask

Do you have a snippet of a code that can create the class indices in the shape [batch_size, height, width] as you said?

Currently I am trying, for just one class, to apply mask = torch.from_numpy(mask), but I’m not sure if it’s OK.

def transforms(self, image, mask):
    #img = img.resize((wsize, baseheight), PIL.Image.ANTIALIAS)

    image = image.resize((64, 64), PIL.Image.NEAREST)
    mask = mask.resize((64, 64), PIL.Image.NEAREST)
    mask = np.array(mask)
    image = TF.to_tensor(image)
    mask = torch.from_numpy(mask) 
    return [image, mask]

Now the mask is actually between zero and one, as it should be (I still need to solve the problem for more indices, of course), but once I try to train I get this:

ValueError: Target size (torch.Size([1, 64, 64])) must be the same as input size (torch.Size([1, 1, 64, 64]))

I tried to add in the calculation of the loss: target_ = torch.empty(batch_size, 1,64,64) target =

def calc_loss(pred, target, metrics, bce_weight=0.5):
    target_ = torch.empty(batch_size, 1,64,64)
    target =
    bce = F.binary_cross_entropy_with_logits(pred, target)
    pred = F.sigmoid(pred)
    dice = dice_loss(pred, target)

This does not crash, but I get NaN as the loss. So yes, I guess I’m quite confused. I’m trying to read as much as I can but cannot figure out a lot at the moment, so any help would be great :slight_smile: Thanks again!

In that case you don’t need any mapping, as your mask will already contain the class indices.

torch.empty uses uninitialized memory to create the tensor, so it might contain invalid values (e.g. NaNs).
For a multi-class segmentation (each pixel belongs to one class only), you should use nn.CrossEntropyLoss instead of nn.BCEWithLogitsLoss.
The latter criterion can be used for a multi-label classification/segmentation (each pixel can belong to zero, one, or more classes).
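To make the difference between the two criteria concrete, here is a small sketch with random tensors (the shapes follow the thread; the one-hot target for the multi-label case is only there to show the expected shape and dtype):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

nb_class = 4
logits = torch.randn(1, nb_class, 64, 64)           # model output
class_idx = torch.randint(0, nb_class, (1, 64, 64))  # one class per pixel

# Multi-class: integer class indices with shape [batch_size, H, W].
ce_loss = nn.CrossEntropyLoss()(logits, class_idx)

# Multi-label: float target with the same shape as the logits, one 0/1 map
# per class. Here it happens to be one-hot, but several channels could be 1
# for the same pixel.
multi_hot = F.one_hot(class_idx, num_classes=nb_class)  # [1, 64, 64, 4]
multi_hot = multi_hot.permute(0, 3, 1, 2).float()       # [1, 4, 64, 64]
bce_loss = nn.BCEWithLogitsLoss()(logits, multi_hot)

print(ce_loss.item(), bce_loss.item())
```

The size mismatch reported earlier in the thread comes from this second requirement: nn.BCEWithLogitsLoss wants the target to have exactly the same shape as the model output.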

Thanks for the quick reply!

I must admit that it is not clear to me how to create the mask. Once I load the mask with PIL, I pass it as

mask = torch.from_numpy(mask)

but then it is not clear how to create this one:

Because I guess that once I create a tensor like [batch_size, num_classes, height, width], I can skip this step in the training loop

That gives me the error with the NaN, due to what you said about the memory that might contain those values.


Hi Songyuc,

If it’s not a problem, could you post (or send me a link to) some examples of how you create the multiclass masks?

Many thanks!

Hey Diauro, you can find many examples and pre-trained models at
Great resource!

Yes, here is an example of transforming the mask data into train IDs,

Thanks everyone for the help! Following your advice, I managed to modify the U-Net and run it correctly (more or less: the loss kind of saturates and does not go down much, but that might be my dataset).

Quick question: is there a way to easily implement dropout in the U-Net snippet found here?

Best and thanks again for the help!
