Hi there,
I would like to index a tensor with a boolean mask that has a different number of True elements in each batch row. Specifically, suppose you have the following:
x = torch.randn((2, 5, 768))
mask = torch.tensor([[False, True, False, False, False], [False, True, False, False, True]])
I would like to write a function that extracts the elements of x at the positions where mask is True and stores the result in a tensor that preserves the batch dimension. Specifically, I would like an output y of size (batch_size, 2, 768) (here it is 2 because the maximum number of True values in any row is 2).
In the easiest case, where both batch elements have the same number of True values, I can simply do x[mask] and then reshape the tensor. But if the counts differ, will it require some sort of padding?
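For reference, here is a minimal sketch of the kind of padded extraction I have in mind. The function name masked_select_padded, the pad_value argument, and the extra valid mask it returns are my own assumptions for illustration, not an existing PyTorch API. It pads short rows with zeros up to the largest True count in the batch:

```python
import torch

def masked_select_padded(x, mask, pad_value=0.0):
    """Hypothetical helper: gather x[b, t] where mask[b, t] is True,
    padding each row up to the largest True count in the batch.

    x:    (batch_size, seq_len, dim)
    mask: (batch_size, seq_len), dtype bool
    """
    counts = mask.sum(dim=1)                 # number of True values per row
    max_count = int(counts.max())            # padded length of the output

    # Output filled with the padding value; selected rows are written below.
    y = x.new_full((x.size(0), max_count, x.size(2)), pad_value)

    # For each True entry, its destination slot is its rank among the
    # True entries of the same row (0 for the first, 1 for the second, ...).
    slot = mask.long().cumsum(dim=1) - 1
    b_idx, t_idx = mask.nonzero(as_tuple=True)
    y[b_idx, slot[b_idx, t_idx]] = x[b_idx, t_idx]

    # valid[b, k] is True where y[b, k] holds real data rather than padding.
    valid = torch.arange(max_count, device=x.device).unsqueeze(0) < counts.unsqueeze(1)
    return y, valid

x = torch.randn((2, 5, 768))
mask = torch.tensor([[False, True, False, False, False],
                     [False, True, False, False, True]])
y, valid = masked_select_padded(x, mask)
print(y.shape)   # torch.Size([2, 2, 768])
```

With the example above, the second slot of the first row is padding, so valid would be needed downstream to tell real entries from padded ones.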