Index a tensor with a mask that has a variable number of True values per batch row

Hi there,

I would like to index a tensor using a mask that contains a variable number of True elements in each batch row. Specifically, consider the following situation:

```python
import torch

x = torch.randn((2, 5, 768))
mask = torch.tensor([[False, True, False, False, False], [False, True, False, False, True]])
```
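For reference, plain boolean indexing with a (B, T) mask on a (B, T, D) tensor flattens the first two dimensions, so the batch grouping is lost. A small check with the tensors above (the name `counts` is just for illustration):

```python
import torch

x = torch.randn((2, 5, 768))
mask = torch.tensor([[False, True, False, False, False],
                     [False, True, False, False, True]])

# Boolean indexing over the first two dims concatenates the selected
# rows of every batch element into one (num_true_total, 768) tensor.
flat = x[mask]
print(flat.shape)  # torch.Size([3, 768])

# Per-row number of True values; needed to recover the batch grouping.
counts = mask.sum(dim=1)
print(counts.tolist())  # [1, 2]
```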

I would like to write a function that extracts the entries of x at the positions where mask is True and stores the result in a tensor that preserves the batch dimension. Specifically, I would like an output y of size (batch_size, 2, 768), where the second dimension is 2 because the maximum number of True values in any row is 2.
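One possible vectorized way to get that (B, K, D) output is sketched below; `masked_select_padded` is a hypothetical helper name, not a PyTorch API. The idea: build a sort key that places the True positions of each row first (keeping their original order), gather the first K indices with `torch.gather`, and overwrite the slots beyond each row's count with a padding value. It also returns a `valid` mask so later code can tell real entries from padding.

```python
import torch


def masked_select_padded(x: torch.Tensor, mask: torch.Tensor, pad_value: float = 0.0):
    """Hypothetical helper (sketch): left-align x[b, t] for all t with mask[b, t] True.

    x: (B, T, D), mask: (B, T) bool.
    Returns y: (B, K, D) and valid: (B, K) bool, with K = max True count over rows.
    """
    B, T = mask.shape
    counts = mask.sum(dim=1)                         # (B,) True values per row
    k = int(counts.max()) if counts.numel() > 0 else 0

    # Unique sort keys: True positions keep their index t, False positions get
    # t + T, so an ascending argsort lists True positions first, in order.
    pos = torch.arange(T, device=mask.device)
    key = torch.where(mask, pos, pos + T)
    order = key.argsort(dim=1)[:, :k]                # (B, K) indices into dim 1

    # Broadcast the indices over the feature dim and gather along dim 1.
    idx = order.unsqueeze(-1).expand(-1, -1, x.size(-1))
    y = torch.gather(x, 1, idx)                      # (B, K, D)

    # Slots past each row's count picked up unmasked positions: pad them.
    valid = torch.arange(k, device=x.device).unsqueeze(0) < counts.unsqueeze(1)
    y = y.masked_fill(~valid.unsqueeze(-1), pad_value)
    return y, valid


x = torch.randn((2, 5, 768))
mask = torch.tensor([[False, True, False, False, False],
                     [False, True, False, False, True]])
y, valid = masked_select_padded(x, mask)
print(y.shape)  # torch.Size([2, 2, 768])
```

Note that `int(counts.max())` forces a device-to-host sync on GPU; that is hard to avoid here because the output size depends on the data.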

In the simplest case, where every batch row has the same number of True values, I can just do x[mask] and reshape the result. When the counts differ, though, does this require some sort of padding?
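A minimal sketch of both cases, assuming the (B, T, D) layout above. With equal counts, `x[mask]` emits the selected rows batch by batch, so a reshape restores the batch dimension. With unequal counts, padding is indeed needed; one simple option is to split the flat result by the per-row counts and pad with `torch.nn.utils.rnn.pad_sequence`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

B, T, D = 2, 5, 768
x = torch.randn((B, T, D))

# Equal-count case: every row has exactly 2 True values, so the flattened
# x[mask] (emitted row by row) reshapes directly to (B, 2, D).
mask_equal = torch.tensor([[False, True, False, True, False],
                           [True, False, False, False, True]])
y_equal = x[mask_equal].reshape(B, -1, D)

# Variable-count case: split the flat selection by per-row counts, then pad
# the shorter pieces with zeros up to the longest one.
mask = torch.tensor([[False, True, False, False, False],
                     [False, True, False, False, True]])
counts = mask.sum(dim=1).tolist()            # per-row True counts as Python ints
pieces = torch.split(x[mask], counts)        # tuple of (count_b, D) tensors
y = pad_sequence(list(pieces), batch_first=True, padding_value=0.0)
print(y.shape)  # torch.Size([2, 2, 768])
```

This version is short and readable; the `.tolist()` call moves the counts to the CPU, which is usually fine for moderate batch sizes.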