Merge different channels into one

Hi all! I am following this tutorial on finetuning an object detector. As one of the steps, I split my mask image into separate masks and filter out masks corresponding to unwanted labels. Afterwards, I would like to merge the masks back into one channel and display it as a PIL image. The problem is: I cannot for the life of me figure out how to do this!

I’ve tried torch.cat, but this increases the size of the image along one dimension instead of merging the channels. E.g. torch.cat(tuple(t), dim=1), where t.shape == (2, 640, 230), creates a tensor of shape (640, 460), while what I want is a tensor of shape (640, 230) with the values of the respective channels in the correct positions. The masks do not overlap at all, so I feel like this should be possible.
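To make the shape mismatch concrete, here is a minimal sketch of what the question describes (the tensor t and its shape are taken from the question; the zero values are just placeholders):

```python
import torch

t = torch.zeros(2, 640, 230)       # two single-channel masks stacked along dim 0
cat = torch.cat(tuple(t), dim=1)   # concatenates the two masks side by side
print(cat.shape)                   # torch.Size([640, 460]) -- width doubled, not merged
```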

Any help is much appreciated, and if there are questions please let me know!

If each mask has no channel dimension and contains its class index, you could simply add the masks, assuming the background class has index 0 and no classes overlap.

Thanks, this is what I initially wanted! However, it turns out the classes overlap after all :sweat: Is there another (standard?) way of merging the masks back into one image, I guess giving preference to masks with a higher confidence score?

How would the confidence of the mask be defined?
If you just want to keep the largest class index, you could use torch.max:

import torch

a = torch.randint(0, 4, (2, 4, 4))
b = torch.randint(0, 4, (2, 4, 4))
c = torch.max(a, b)  # elementwise maximum of the two tensors
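Applied to a stacked mask tensor like the t from the question, the same idea becomes a reduction over the first dimension (the shape (2, 640, 230) is taken from the question; the random values are placeholders):

```python
import torch

t = torch.randint(0, 4, (2, 640, 230))  # stacked class-index masks, (N, H, W)
merged = t.max(dim=0).values            # keep the largest class index per pixel
print(merged.shape)                     # torch.Size([640, 230])
```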