I have a tensor, say a batch of images, with shape
[batch_size, channel, height, width], and a mask tensor with shape
[batch_size, channel]. Every element of the mask tensor is a bool:
True means I want to keep (slice out) that channel for that sample,
and False means I don't. For example:
import torch
image = torch.randn(100, 3, 224, 224)
mask = torch.randn(100, 3).ge(0.5)
I wanted to slice image in the simplest way possible, so I tried
image[:, mask], but that didn't work.
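A likely reason `image[:, mask]` fails: a boolean mask indexes as many dimensions as it has, starting at the position where it appears. With a leading `:`, the 2-D mask of shape [batch_size, channel] is matched against dims 1 and 2 of image (channel and height) instead of dims 0 and 1, so the shapes disagree. Writing `image[mask]` lines the mask up with batch_size and channel. The sketch below uses small, hypothetical sizes to check both forms.

```python
import torch

# Hypothetical small sizes so the example runs quickly
batch_size, channels, height, width = 4, 3, 8, 8
image = torch.randn(batch_size, channels, height, width)
mask = torch.randn(batch_size, channels).ge(0.5)  # bool, shape [4, 3]

# image[:, mask] matches the [4, 3] mask against dims 1 and 2 ([3, 8]),
# so the shapes do not line up and indexing raises an error
try:
    image[:, mask]
    failed = False
except (IndexError, RuntimeError):
    failed = True

# image[mask] matches the mask against dims 0 and 1 and keeps every selected
# (sample, channel) pair, flattened into a single leading dimension
selected = image[mask]
print(failed, selected.shape)  # selected has shape (mask.sum(), height, width)
```

Note that `image[mask]` loses the batch dimension: the result is a flat stack of all kept channels across all samples.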
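I realized that I need an extra condition: every sample must keep the
same number of channels, i.e. each row of the mask (along dim 1) must
contain the same number of True values. Under that condition, the
sliced tensor has shape [batch_size, num_of_true, height, width].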
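Given that condition, one approach is to index with the mask directly and then reshape. Boolean indexing returns the selected elements in row-major order, so all of sample 0's kept channels come first, then sample 1's, and so on; reshaping to [batch_size, num_of_true, height, width] restores the per-sample grouping. This is a minimal sketch, assuming every row of the mask has the same number of True values; the helper name `select_channels` is hypothetical. A random mask such as `torch.randn(100, 3).ge(0.5)` will usually not meet the condition, so the example builds a mask that does.

```python
import torch

def select_channels(image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: keep the channels where mask is True, per sample.

    image: [batch_size, channel, height, width]
    mask:  [batch_size, channel], dtype bool, same number of True per row
    """
    counts = mask.sum(dim=1)
    if not torch.all(counts == counts[0]):
        raise ValueError("each row of mask must contain the same number of True values")
    batch_size, _, height, width = image.shape
    num_of_true = int(counts[0])
    # image[mask] has shape [batch_size * num_of_true, height, width],
    # ordered sample by sample, so a reshape regroups it per sample
    return image[mask].reshape(batch_size, num_of_true, height, width)

# Example: every sample keeps exactly 2 of its 3 channels
image = torch.randn(100, 3, 224, 224)
mask = torch.zeros(100, 3, dtype=torch.bool)
mask[:, :2] = True                             # odd samples keep channels 0 and 1
mask[::2] = torch.tensor([True, False, True])  # even samples keep channels 0 and 2
out = select_channels(image, mask)
print(out.shape)  # torch.Size([100, 2, 224, 224])
```

The explicit check turns a ragged mask into a clear error instead of a confusing reshape failure.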