Hi there,

I would like to index a tensor using a mask that contains a variable number of `True` elements in each batch row. Specifically, consider the following situation:

```
import torch

x = torch.randn((2, 5, 768))
mask = torch.tensor([[False, True, False, False, False], [False, True, False, False, True]])
```

I would like to write a function that extracts the elements of `x` at the positions where `mask` is `True`, and stores the result in a tensor that preserves the batch dimension. That is, I would like an output `y` of size `(batch_size, 2, 768)` (here the middle dimension is 2 because the largest number of `True` values in any row is 2).
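
One way to build such a `y` without a Python loop is to compute, for every `True` entry, its slot within its own row (a running count via `cumsum`) and scatter it into a pre-filled output. The sketch below assumes PyTorch; the helper name `masked_select_padded` and its `pad_value` parameter are hypothetical, chosen for illustration. It also returns a boolean `valid` mask so later code can tell real entries from padding.

```python
import torch


def masked_select_padded(x, mask, pad_value=0.0):
    """Hypothetical helper: gather x[b, t] where mask[b, t] is True into a
    padded (B, K, D) tensor, where K is the largest per-row count of True.

    Returns the padded tensor and a boolean mask marking the real slots.
    """
    B, T, D = x.shape
    counts = mask.sum(dim=1)                        # number of True entries per row
    max_k = int(counts.max()) if B > 0 else 0
    y = x.new_full((B, max_k, D), pad_value)        # start with all padding
    # Slot of each True element within its row: 0, 1, 2, ... (left-aligned)
    slot = mask.long().cumsum(dim=1) - 1
    b_idx, t_idx = mask.nonzero(as_tuple=True)      # coordinates of True entries
    y[b_idx, slot[b_idx, t_idx]] = x[b_idx, t_idx]  # copy (N, D) rows into place
    # valid[b, k] is True where y[b, k] holds a real element, not padding
    valid = torch.arange(max_k, device=x.device) < counts.unsqueeze(1)
    return y, valid


x = torch.randn((2, 5, 768))
mask = torch.tensor([[False, True, False, False, False],
                     [False, True, False, False, True]])
y, valid = masked_select_padded(x, mask)
print(y.shape)  # torch.Size([2, 2, 768])
```

With the example inputs, `y[0, 0]` is `x[0, 1]`, `y[0, 1]` is padding, and `y[1]` holds `x[1, 1]` and `x[1, 4]`, in their original order.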

In the easiest case, where every batch row has the same number of `True` values, I can simply do `x[mask]` and then reshape the result. But when the counts differ, does this require some sort of padding?
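
Padding is indeed one option. Because `x[mask]` returns the selected rows in batch order, its output can be split back into per-row chunks using the per-row counts and then padded with `torch.nn.utils.rnn.pad_sequence`. A minimal sketch, assuming zero padding is acceptable for the downstream computation:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

x = torch.randn((2, 5, 768))
mask = torch.tensor([[False, True, False, False, False],
                     [False, True, False, False, True]])

selected = x[mask]                 # (total number of True, 768), rows in batch order
counts = mask.sum(dim=1)           # number of True entries per batch row

if bool((counts == counts[0]).all()):
    # Equal counts: a plain reshape restores the batch dimension
    y = selected.reshape(x.shape[0], -1, x.shape[-1])
else:
    # Unequal counts: split per row, then zero-pad the shorter rows
    chunks = list(torch.split(selected, counts.tolist()))
    y = pad_sequence(chunks, batch_first=True, padding_value=0.0)

print(y.shape)  # torch.Size([2, 2, 768])
```

`counts.tolist()` moves the counts to the host, so on a GPU this adds a synchronization point; a `cumsum`-based scatter avoids the split but not the need to know the maximum count.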