I have a tensor, say representing some images, with shape [batch_size, channel, height, width], and a mask tensor with shape [batch_size, channel]. All elements of the mask tensor are bools: True means I want to slice out that channel, while False means I don't. For example:
import torch

image = torch.randn(100, 3, 224, 224)
mask = torch.randn(100, 3).ge(0.5)
I wanted to slice image in the easiest way, so I tried image[:, mask], but this didn't work.
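A likely reason for the failure, sketched here with small stand-in shapes (the sizes 4, 3, 8, 8 are illustrative assumptions, not the sizes from the post): placing a 2-D boolean mask after `:` lines it up with dims 1 and 2 of the tensor rather than dims 0 and 1, so the shapes do not match. Indexing with the mask alone lines it up with the leading two dims, but returns the selected channels flattened into one dimension.

```python
import torch

# Small stand-in for [batch_size, channel, height, width]
image = torch.randn(4, 3, 8, 8)
# Bool mask of shape [batch_size, channel]
mask = torch.randn(4, 3).ge(0.5)

# image[:, mask] tries to match the mask (4, 3) against dims 1 and 2 of
# image (3, 8), so PyTorch rejects it with a shape-mismatch IndexError.
try:
    image[:, mask]
    failed = False
except IndexError as err:
    failed = True
    print("indexing failed:", err)

# image[mask] matches the mask against dims 0 and 1 instead. The result has
# one row per True entry, so the batch and channel dims are merged.
selected = image[mask]
print(selected.shape)  # [number of True entries in mask, 8, 8]
```

The merged leading dimension is why an extra condition is needed before the result can be reshaped back into a per-sample layout.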
I realized that I have to add an extra condition: the number of Trues along dim = 1 must be the same for every sample. The shape of the sliced tensor is then [batch_size, num_of_true, height, width].
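Under that condition, one possible approach is to index with the mask alone and reshape the result. The following is a minimal sketch, not a definitive answer: the small sizes, the name `keep`, and the way the mask is built (exactly `num_of_true` randomly chosen channels per sample) are assumptions made for illustration.

```python
import torch

batch_size, channels, height, width = 4, 3, 8, 8
num_of_true = 2  # assumed: every sample keeps the same number of channels

image = torch.randn(batch_size, channels, height, width)

# Build an illustrative mask with exactly num_of_true True entries per row
# by picking num_of_true random channel indices for each sample.
keep = torch.rand(batch_size, channels).argsort(dim=1)[:, :num_of_true]
mask = torch.zeros(batch_size, channels, dtype=torch.bool)
mask[torch.arange(batch_size).unsqueeze(1), keep] = True

# Verify the condition before reshaping: same count of True in every row.
assert bool((mask.sum(dim=1) == num_of_true).all())

# image[mask] has shape [batch_size * num_of_true, height, width], with
# entries in row-major order, so it can be reshaped back per sample.
sliced = image[mask].reshape(batch_size, num_of_true, height, width)
print(sliced.shape)  # torch.Size([4, 2, 8, 8])
```

Because boolean indexing walks the mask in row-major order, the selected channels of each sample stay together and keep their original channel order after the reshape.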