DataLoader not reading in masks correctly?

Hi,

I got a multi-class segmentation model running end to end, but it seems to be performing incorrectly. After some debugging, I believe the cause is that the DataLoader is reading in the masks incorrectly.

For example, here is a mask with 3 classes:

And here is how the DataLoader perceives this mask. It is read in as black and white, and only one class seems to be preserved (the pink X at the bottom is gone).

This is how `__getitem__` is defined in the Dataset that my DataLoader uses:

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index].replace(".jpg", "_mask.png"))
        image = np.array(Image.open(img_path).convert("RGB"), dtype=np.float32)
        mask = np.array(Image.open(mask_path).convert("RGB"), dtype=np.float32)
        target = torch.from_numpy(mask)
        h,w = target.shape[0], target.shape[1]
        mask = torch.empty(h, w, dtype = torch.long)
        colors = torch.unique(target.view(-1, target.size(2)), dim=0).numpy()
        target = target.permute(2, 0, 1).contiguous()
        mapping = {tuple(c): t for c, t in zip(colors.tolist(), range(len(colors)))}
        
        for k in mapping:
            # Get all indices for current class
            idx = (target==torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
            validx = (idx.sum(0) == 3)  # Check that all channels match
            mask[validx] = torch.tensor(mapping[k], dtype=torch.long)

        if self.transform is not None:
            mask = mask.numpy()
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]
            
        print(f"image shape: {image.shape}\n")
        print(f"mask shape: {mask.shape}\n")
        torchvision.utils.save_image(image, "data/inside_data_loader/" + self.images[index])
        torchvision.utils.save_image(mask.float(), "data/inside_data_loader/" + self.images[index].replace(".jpg", "_mask.png"))
        return image, mask

As you can see, I load the image and its corresponding mask, extract the unique colors from the mask, and build a color-to-class mapping from them. Before returning, I save the image and the mask so that I can get an idea of how the DataLoader sees them.
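One thing worth double-checking in this approach: because `mapping` is rebuilt for every mask from `torch.unique`, the class index a color receives depends on which colors happen to appear in that particular mask (`torch.unique` returns the colors in sorted order, so in a mask with no pink pixels, white would become class 1 instead of 2). Below is a minimal sketch of a fixed, global color-to-class lookup. It assumes the three colors reported in this thread are the complete set of classes and that the mask arrives as an (H, W, 3) uint8 tensor (i.e. without the `dtype=np.float32` cast); `PALETTE` and `rgb_mask_to_class_index` are hypothetical names.

```python
import torch

# Hypothetical fixed palette (assumption): row i is the RGB color of class i.
# These are the three mask colors reported in this thread.
PALETTE = torch.tensor(
    [[10, 133, 1],      # class 0: green
     [243, 5, 247],     # class 1: pink
     [255, 255, 255]],  # class 2: white
    dtype=torch.uint8,
)


def rgb_mask_to_class_index(rgb: torch.Tensor, palette: torch.Tensor = PALETTE) -> torch.Tensor:
    """Map an (H, W, 3) uint8 RGB mask to an (H, W) long tensor of class indices.

    Pixels whose color is not in the palette get -1 so they are easy to spot.
    """
    # (H, W, 1, 3) vs (K, 3) broadcasts to (H, W, K, 3); all three channels must match.
    matches = (rgb.unsqueeze(2) == palette).all(dim=-1)  # (H, W, K) bool
    class_idx = matches.long().argmax(dim=-1)            # (H, W) index of the matching color
    class_idx[~matches.any(dim=-1)] = -1                 # colors that are not in the palette
    return class_idx
```

Inside `__getitem__`, this could stand in for the `torch.unique` / `mapping` / loop section, e.g. `mask = rgb_mask_to_class_index(torch.from_numpy(np.array(Image.open(mask_path).convert("RGB"))))`, so the same color always maps to the same class across the whole dataset.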

As for the transformations, I simply normalize and convert to a tensor as follows:

    train_transform = A.Compose(
        [
            #A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            #A.Rotate(limit=35, p=1.0),
            #A.HorizontalFlip(p=0.5),
            #A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    val_transforms = A.Compose(
        [
            #A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

Any idea why my model is perceiving my mask this way? Is this behavior correct?

Thank you very much!

Can you print some examples of what the colors in mapping are? I wonder if they are valid/displayable colors. You might also want to disable the “augmentations” on the mask as that could be affecting the colors. You might also consider whether it makes sense to apply augmentations to the mask, since my understanding is the mask is the label for the segmentation task.

Thank you for your response. I printed mapping right after it got defined and here is what the map contains:

{(10.0, 133.0, 1.0): 0, (243.0, 5.0, 247.0): 1, (255.0, 255.0, 255.0): 2}

So we see green, pink, and white as expected.

As for the transformations, I simply normalize and then convert the training image or mask to a tensor (I apply the same transformations to the image and the corresponding mask).

The only transformation that could seem problematic is the normalization, but even after disabling it and keeping only the tensor conversion, I unfortunately still get the same behavior.

Taking a closer look, I see that mask is initialized with torch.empty. You might want to do torch.zeros to prevent any unlabeled pixels from looking strange unless you are certain every single pixel will be labeled.
I would then check that mask makes sense for each of the individual classes (e.g., can you try viewing the mask with just the first class, just the second, and so on?) to make sure that nothing strange like different classes overwriting each other is happening.
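One way to view each class on its own, as a minimal sketch: split the (H, W) class-index mask into one binary image per class, where white means "this pixel has class c". The function name is hypothetical; the resulting (1, H, W) float tensors are already in [0, 1], so they can be written out with `torchvision.utils.save_image` just like the debug images above.

```python
import torch


def per_class_views(class_idx: torch.Tensor, num_classes: int):
    """Split an (H, W) class-index mask into one (1, H, W) float image per class.

    Each image is 1.0 (white) where the pixel belongs to that class and 0.0 (black) elsewhere.
    """
    return [(class_idx == c).float().unsqueeze(0) for c in range(num_classes)]


# Possible usage after the mapping loop in __getitem__ (file names are illustrative):
#     for c, view in enumerate(per_class_views(mask, num_classes=3)):
#         torchvision.utils.save_image(view, f"class_{c}.png")
```

If every pixel is labeled, the white pixels across all the views should add up to exactly H × W, and each view should show only the shape of its own class.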

As a sanity check, could you print the shapes of target, idx, and validx?

Thank you for your response.

I am certain every pixel is labeled but I switched to torch.zeros to be sure.

How do you recommend I check the mask for each class?

In the meantime, here is my updated code:

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index].replace(".jpg", "_mask.png"))
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("RGB"))
        target = torch.from_numpy(mask)
        h,w = target.shape[0], target.shape[1]
        mask = torch.zeros(h, w, dtype = torch.long)
        colors = torch.unique(target.view(-1, target.size(2)), dim=0).numpy()
        target = target.permute(2, 0, 1).contiguous()
        mapping = {tuple(c): t for c, t in zip(colors.tolist(), range(len(colors)))}
        print(f"target shape: {target.shape}\n")

        for k in mapping:
            # Get all indices for current class
            idx = (target==torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
            validx = (idx.sum(0) == 3)  # Check that all channels match
            mask[validx] = torch.tensor(mapping[k], dtype=torch.long)
            print(f"idx shape: {idx.shape}\n")
            print(f"validx shape: {validx.shape}\n")
            
        image = transforms.ToTensor()(image)   
        print(f"image shape: {image.shape}\n")
        print(f"mask shape: {mask.shape}\n")
        torchvision.utils.save_image(image, "data/inside_data_loader/" + self.images[index])
        torchvision.utils.save_image(mask.float(), "data/inside_data_loader/" + self.images[index].replace(".jpg", "_mask.png"))
        return image, mask

Here is the output:

target shape: torch.Size([3, 1000, 1000])

idx shape: torch.Size([3, 1000, 1000])

validx shape: torch.Size([1000, 1000])

image shape: torch.Size([3, 1000, 1000])

mask shape: torch.Size([1000, 1000])

@ptrblck This procedure is based on some solutions you gave for creating masks for multi-class segmentation. If you have the time, would you mind taking a look at this? The problem is that my mask does not seem to be read into the model correctly (color and class information are lost), which in turn hurts my model's performance. Thank you very much!

The fact that mask is only two dimensional is an issue if it should have color information. Can you try something like

    mask = torch.zeros(3, h, w, dtype = torch.long)

and then

    mask[0,:,:][validx] = torch.tensor(k[0], dtype=torch.long)
    mask[1,:,:][validx] = torch.tensor(k[1], dtype=torch.long)
    mask[2,:,:][validx] = torch.tensor(k[2], dtype=torch.long)

There might be a way to clean up the indexing here but I think this illustrates the intent.

The issue is that mapping goes from colors to classes, but the loop assigns class indices to the mask rather than colors. So here we write the key of the mapping into the mask instead, since the key is what actually holds the color.

On the other hand, if mask is supposed to have a format like 1000x1000 or 1x1000x1000, where each position stores the class id, then you might want to reuse part of the code here to generate something that has 3 colors and is viewable :wink:
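Going the other way, here is a minimal sketch of turning an (H, W) class-index mask back into a viewable RGB image by indexing a palette tensor with the mask. It assumes class i has the i-th color below (the colors reported in this thread) and scales them to [0, 1], because `torchvision.utils.save_image` expects float images in that range; `PALETTE_01` and `colorize` are hypothetical names.

```python
import torch

# Assumed class-to-color table (row i = color of class i), scaled from 0..255 to 0..1.
PALETTE_01 = torch.tensor(
    [[10, 133, 1], [243, 5, 247], [255, 255, 255]], dtype=torch.float32
) / 255.0


def colorize(class_idx: torch.Tensor, palette: torch.Tensor = PALETTE_01) -> torch.Tensor:
    """Map an (H, W) long class-index mask to a (3, H, W) float RGB image in [0, 1]."""
    # Move the palette to the mask's device so CUDA masks also work.
    rgb = palette.to(class_idx.device)[class_idx]  # advanced indexing: (H, W) -> (H, W, 3)
    return rgb.permute(2, 0, 1).contiguous()       # channels first, as save_image expects
```

The class-index mask itself stays in the (H, W) format for training; `colorize` is only for viewing, e.g. `torchvision.utils.save_image(colorize(mask), "mask_rgb.png")`.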

Thank you for your prompt response!

I made those corresponding changes but now when I save the mask at the end, it gets saved as a white PNG:

Instead of this:

@eqy @ptrblck Is there something I am fundamentally misunderstanding about how masks are fed to multi-class segmentation models? How should the mask look after we update it with the mapping? Should it have as many channels as there are classes, where each channel is true or false at each pixel depending on whether that pixel belongs to that class? Should I be able to save my mask as an image after the mapping and still view it normally? Sorry for all the questions; I am just confused.

The white image might be an issue due to the color encoding. For a floating point image you might want to scale all of the colors to be between 0.0 and 1.0 instead of 0 and 255.
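To see why the scale matters, here is a rough emulation of the float-to-8-bit conversion that `torchvision.utils.save_image` applies before writing the file (multiply by 255, round, clamp to [0, 255]). The helper is a hypothetical stand-in written for illustration, not torchvision's own code: anything at or above 1.0 ends up as 255, so an unscaled color like (10, 133, 1) saves as pure white.

```python
import torch


def to_uint8_like_save_image(img: torch.Tensor) -> torch.Tensor:
    """Rough stand-in for the conversion save_image applies: scale by 255, round, clamp."""
    return img.float().mul(255).add(0.5).clamp(0, 255).to(torch.uint8)


green_unscaled = torch.tensor([10.0, 133.0, 1.0])
green_scaled = green_unscaled / 255.0           # the fix: colors in [0, 1]
class_indices = torch.tensor([0.0, 1.0, 2.0])   # an (H, W) class-index mask cast to float

print(to_uint8_like_save_image(green_unscaled))  # every channel clamps to 255 -> white
print(to_uint8_like_save_image(green_scaled))    # the original color survives
print(to_uint8_like_save_image(class_indices))   # 0 -> black, 1 and 2 -> white
```

The last line is the same effect as the black-and-white mask from the first post: class 0 maps to 0, while classes 1 and 2 both clamp to 255.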

For your second question, it depends on which loss function (criterion) you use. For a classification model, a typical label has shape N (N = batch size), where each label is a class index, so we only need one value per example. If your label format is N,H,W, then it holds one class index per pixel. Note that masks in this second format are NOT renderable as images by default: even if we remove the batch dimension, the H,W format has no color dimension. I think the black and white image is the result of the class indices being interpreted as colors: class 0 is shown as black, while classes 1 and 2 are simply clamped to 1.0 and displayed as white. In other words, there is no indication that the mask is being loaded incorrectly, since it is not expected to display correctly as an image until it is converted into a format with color information.
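For `nn.CrossEntropyLoss` specifically, a minimal shape check with random tensors and arbitrary sizes: the model output holds one raw score per class per pixel, (N, C, H, W), while the target holds one class index per pixel, (N, H, W), with dtype long. No softmax or sigmoid is applied before this criterion, since it applies log-softmax internally.

```python
import torch
import torch.nn as nn

N, C, H, W = 2, 3, 4, 5  # batch, classes, height, width (arbitrary sizes)

logits = torch.randn(N, C, H, W)         # raw model output, no softmax/sigmoid
target = torch.randint(0, C, (N, H, W))  # class index per pixel, dtype torch.long

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)         # scalar, averaged over all N*H*W pixels
print(loss.item())
```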

I think you are spot on.

I am using nn.CrossEntropyLoss as my criterion, so my label format is (N, H, W) (see the "Multiclass Segmentation - #2 by ptrblck" post).

So does this mean I should go back to the following mapping?

        mask = torch.zeros(h, w, dtype = torch.long)
        colors = torch.unique(target.view(-1, target.size(2)), dim=0).numpy()
        target = target.permute(2, 0, 1).contiguous()
        mapping = {tuple(c): t for c, t in zip(colors.tolist(), range(len(colors)))}
        print(f"target shape: {target.shape}\n")

        for k in mapping:
            # Get all indices for current class
            idx = (target==torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
            validx = (idx.sum(0) == 3)  # Check that all channels match
            mask[validx] = torch.tensor(mapping[k], dtype=torch.long)

That way each pixel holds its class index? Then, when it is time to output the prediction at the end of the model, we just convert back to RGB as you demonstrated in my other post?

Yes, I think it is fine to use this label format for now and see if the output from the model makes sense after training.

Thank you for your response. That makes perfect sense.

Unfortunately, after training on 1 image and testing on 1 image, the model's output is incorrect. I know I need more data, but I at least expect to see some colors (green or pink).

The model predicted the following mask to be all white:

As follows:

Based on your help in my previous post, I have the following code, which also saves each class's segment as an image so that I can see what each class prediction looks like.

    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            pred = torch.sigmoid(model(x))
            out = (pred > 0.5).float()
            class_to_color = [torch.tensor([10, 133, 1], device='cuda'), torch.tensor([243, 5, 247], device='cuda'), torch.tensor([255, 255, 255], device='cuda')]
            output = torch.zeros(1, 3, out.size(-2), out.size(-1), dtype=torch.float, device='cuda')
            for class_idx, color in enumerate(class_to_color):
                mask = out[:,class_idx,:,:] == torch.max(out, dim=1)[0]
                mask = mask.unsqueeze(1) # should have shape 1, 1, 100, 100
                curr_color = color.reshape(1, 3, 1, 1)
                segment = mask*curr_color # should have shape 1, 3, 100, 100
                torchvision.utils.save_image(segment.float(), f"{folder}/segment_{class_idx}.png")
                output += segment
        torchvision.utils.save_image(output, f"{folder}/pred_{idx}.png")
        torchvision.utils.save_image(y.float(), f"{folder}{idx}.png")

But each saved segment is a white PNG. Even with only one training image and one test image, I think I should still see some pink or green (the other two classes) in the resulting prediction.

This could be a similar issue as before. Can you see what happens when the colors are scaled to be between 0.0 and 1.0 rather than 0 and 255?

Thank you for your response.

How would I do that, and where exactly do you suggest doing it?

For example, can you try dividing all of the values in [torch.tensor([10, 133, 1], device='cuda'), torch.tensor([243, 5, 247], device='cuda'), torch.tensor([255, 255, 255], device='cuda')] by 255?

Thank you for your response and sorry for my late one.

I just tried that, and now I get a colored output:

But when I look at the saved segments:

    segment = mask*curr_color
    torchvision.utils.save_image(segment.float(), f"{folder}/segment_{class_idx}.png")

I get the following for classes 0, 1, and 2 respectively.

Do you think something is still wrong? I’m not sure why classes 0 and 2 come out black. Nonetheless, this is still good progress, because at least the model output has some color in it now!

How long has the model been training? You might want to print out what proportion of the pixels are predicted for each class; it might be stuck on the same class for all pixels.
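A minimal sketch of that check, assuming the model returns raw scores of shape (N, C, H, W): take the argmax over the class dimension and count how often each class wins. `class_proportions` is a hypothetical helper name.

```python
import torch


def class_proportions(logits: torch.Tensor) -> torch.Tensor:
    """Fraction of pixels assigned to each class, given (N, C, H, W) model scores."""
    pred = logits.argmax(dim=1)  # (N, H, W) predicted class per pixel
    counts = torch.bincount(pred.flatten(), minlength=logits.size(1))
    return counts.float() / pred.numel()


# Possible usage inside the evaluation loop:
#     with torch.no_grad():
#         print(class_proportions(model(x)))
# A result such as tensor([0., 1., 0.]) would mean every pixel is predicted as class 1.
```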

The model only trains for a few seconds because I am only using 1 training image and testing on 1 image. I know that is almost nothing, but shouldn't we still see some plausible, if garbage, output? Especially since I am using a U-Net for multi-class segmentation: I thought one of the key strengths of U-Nets is their ability to learn quickly from small amounts of data.

Interesting, in that case something might be wrong here. Can you take a look at the raw predictions from the models (just the values of the tensor) and see if they make sense?

Hi, thank you for your response. I printed the following tensors:

    pred = torch.sigmoid(model(x))
    out = (pred > 0.5).float()

    pred: tensor([[[[0.5138, 0.5134, 0.5132,  ..., 0.5134, 0.5133, 0.5129],
              [0.5128, 0.5126, 0.5120,  ..., 0.5126, 0.5123, 0.5122],
              [0.5135, 0.5133, 0.5126,  ..., 0.5129, 0.5126, 0.5124],
              ...,
              [0.5130, 0.5129, 0.5122,  ..., 0.5128, 0.5128, 0.5131],
              [0.5128, 0.5122, 0.5120,  ..., 0.5126, 0.5127, 0.5132],
              [0.5129, 0.5136, 0.5128,  ..., 0.5136, 0.5141, 0.5139]],

             [[0.5214, 0.5217, 0.5219,  ..., 0.5217, 0.5211, 0.5212],
              [0.5217, 0.5225, 0.5224,  ..., 0.5225, 0.5208, 0.5211],
              [0.5218, 0.5222, 0.5216,  ..., 0.5220, 0.5208, 0.5208],
              ...,
              [0.5212, 0.5216, 0.5210,  ..., 0.5211, 0.5205, 0.5212],
              [0.5211, 0.5212, 0.5207,  ..., 0.5207, 0.5205, 0.5208],
              [0.5206, 0.5209, 0.5202,  ..., 0.5202, 0.5202, 0.5208]],

             [[0.4996, 0.4995, 0.4993,  ..., 0.4992, 0.4993, 0.4991],
              [0.4998, 0.4996, 0.4993,  ..., 0.4992, 0.4995, 0.4989],
              [0.5001, 0.4996, 0.4997,  ..., 0.4997, 0.4996, 0.4994],
              ...,
              [0.5001, 0.5000, 0.5000,  ..., 0.4999, 0.5003, 0.4995],
              [0.5000, 0.4998, 0.5001,  ..., 0.5001, 0.5001, 0.4994],
              [0.4999, 0.4997, 0.4999,  ..., 0.4997, 0.5001, 0.4999]]]],
           device='cuda:0')

    out: tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],
              [1., 1., 1.,  ..., 1., 1., 1.],
              [1., 1., 1.,  ..., 1., 1., 1.],
              ...,
              [1., 1., 1.,  ..., 1., 1., 1.],
              [1., 1., 1.,  ..., 1., 1., 1.],
              [1., 1., 1.,  ..., 1., 1., 1.]],

             [[1., 1., 1.,  ..., 1., 1., 1.],
              [1., 1., 1.,  ..., 1., 1., 1.],
              [1., 1., 1.,  ..., 1., 1., 1.],
              ...,
              [1., 1., 1.,  ..., 1., 1., 1.],
              [1., 1., 1.,  ..., 1., 1., 1.],
              [1., 1., 1.,  ..., 1., 1., 1.]],

             [[0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              [1., 0., 0.,  ..., 0., 0., 0.],
              ...,
              [1., 1., 1.,  ..., 0., 1., 0.],
              [0., 0., 1.,  ..., 1., 1., 0.],
              [0., 0., 0.,  ..., 0., 1., 0.]]]], device='cuda:0')

I don’t think pred or out makes sense, especially pred: shouldn’t its values be 0, 1, or 2?
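The values in pred are not expected to be 0, 1, or 2. With nn.CrossEntropyLoss, the model outputs one score per class for every pixel, so the (1, 3, H, W) tensor above holds three scores per pixel, and the predicted class is whichever channel is largest (in the printed entries that is channel 1, at about 0.52). Applying sigmoid and thresholding each channel at 0.5 treats the classes as independent yes/no decisions, so several channels can be "on" at the same pixel; in the printed entries, channels 0 and 1 both sit just above 0.5. Below is a minimal sketch of decoding with argmax instead, reusing the same three colors scaled to [0, 1]. `decode_prediction` is a hypothetical name, and softmax is optional here because it does not change which channel is largest.

```python
import torch


def decode_prediction(logits: torch.Tensor, palette_01: torch.Tensor):
    """Turn (N, C, H, W) model scores into class indices and an RGB image.

    Returns (class_idx, rgb): class_idx is (N, H, W) long with values in 0..C-1,
    rgb is (N, 3, H, W) float in [0, 1], ready for torchvision.utils.save_image.
    """
    class_idx = logits.argmax(dim=1)               # one class per pixel
    rgb = palette_01.to(logits.device)[class_idx]  # (N, H, W, 3)
    return class_idx, rgb.permute(0, 3, 1, 2).contiguous()


# Assumed class colors (row i = class i), scaled from 0..255 to 0..1.
palette_01 = torch.tensor([[10, 133, 1], [243, 5, 247], [255, 255, 255]],
                          dtype=torch.float32) / 255.0
```

In the evaluation loop, this would replace the sigmoid, the 0.5 threshold, and the per-class summing, e.g. `class_idx, rgb = decode_prediction(model(x), palette_01)` followed by `torchvision.utils.save_image(rgb, f"{folder}/pred_{idx}.png")`.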