Suppose I have a tensor A with the following shape:
torch.Size([5, 16, 5000, 3])
I also have a boolean mask M of the same shape:
torch.Size([5, 16, 5000, 3])
If I apply this mask M directly to the tensor A via
A = A[M]
I end up with a flattened, one-dimensional tensor.
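The flattening can be reproduced with random data of the same shapes. The tensor contents and the 0.5 threshold are arbitrary and serve only as an illustration. Indexing with a boolean mask of the full shape keeps only the elements where the mask is True, and it returns them as a single 1-D tensor.

```python
import torch

# Random tensor and boolean mask with the shapes from the question
# (values and the 0.5 threshold are arbitrary, for illustration only).
A = torch.randn(5, 16, 5000, 3)
M = torch.rand(5, 16, 5000, 3) > 0.5

# A full-shape boolean mask selects individual elements, so the result
# is 1-D with one entry per True value in M.
flat = A[M]
print(flat.dim(), flat.numel() == int(M.sum()))
```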
However, I would like to mask out only along dimension 2. In other words, I would like to get a tensor of the shape
torch.Size([5, 16, 5000 - N, 3])
where N is the number of indices along dimension 2 for which mask M is False.
What is the best way to do this?
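A minimal sketch of one possible approach follows. It assumes that M only varies along dimension 2, meaning it is identical for every index in dimensions 0, 1 and 3. Without that assumption the result could not be a rectangular tensor. The names `keep`, `keep_1d`, `B` and `N` are hypothetical and chosen for illustration. The idea is to reduce M to a 1-D mask over dimension 2 and index only that dimension. Boolean indexing then leaves the other dimensions intact.

```python
import torch

# Hypothetical example data with the shapes from the question.
A = torch.randn(5, 16, 5000, 3)

# Assumption: the mask varies only along dimension 2. Build such a mask
# for illustration and broadcast it to the full shape, as in the question.
keep = torch.rand(5000) > 0.1
M = keep.view(1, 1, -1, 1).expand_as(A)

# Recover a 1-D mask over dim 2 and check that M really is constant
# across dims 0, 1 and 3.
keep_1d = M[0, 0, :, 0]
assert torch.equal(M, keep_1d.view(1, 1, -1, 1).expand_as(M))

# Indexing a single dimension with a 1-D boolean tensor drops the
# masked-out positions along that dimension only.
B = A[:, :, keep_1d, :]
N = int((~keep_1d).sum())
print(B.shape == (5, 16, 5000 - N, 3))
```

An equivalent alternative is `torch.index_select(A, 2, keep_1d.nonzero().squeeze(1))`, which converts the boolean mask into integer indices first.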