How to specify colors for each class with draw_segmentation_masks for segmentation data

I’m currently working with a segmentation dataset. I decided to use the draw_segmentation_masks function in torchvision to visualize the data, but the function assigns different colors to the same object category in different images, and I can’t find a way to specify the classes/labels in the function arguments. Is there any workaround for this?

If I understand correctly, the issue is that not every image has the same classes segmented/masked. Could you define your own mapping from classes to colors, build a tuple of colors for the classes that show up in each image (in the same order as the masks appear in the masks tensor), and pass that via the colors argument?

Apologies for the late response. I tried what you suggested, but I’m not sure it’s possible. I’m reading the data with a custom Dataset class, as described here: TorchVision Object Detection Finetuning Tutorial — PyTorch Tutorials 2.0.1+cu117 documentation. The issue is that the code gets the number of unique objects in an image, not how many times each one occurs, and I can’t think of any way to get the number of occurrences of each object in an image. If I could, I would be able to use the class-to-color mapping you suggested.

I don’t think you need to count anything. Since there is one label per mask, you can simply do something like
colors = [label2color[label] for label in labels.tolist()] and pass that as the colors argument, where label2color is your own predefined mapping from classes to colors.