Batched boolean indexing

Hello, I have the following problem, and it would be nice to solve it efficiently with vectorized operations.
In my dataset, I have a discrete set of items, each tagged with a subset of 20 possible categories.
During training, I output a distribution over the items in my dataset, from which I sample.

When training with batches, I am given the tensor items_per_category of shape [num_categories, num_items]. It is only rectangular because of padding: each category has a variable number of items tagged with it.
I also have a batch tensor binary_categories of shape [batch_size, num_categories], in which each row marks with 1s the categories of one batch element.
Say a batch element is index = [1, 0, 1]. I would then like to perform categories = items_per_category[index], which retrieves rows 0 and 2 of items_per_category.
For clarity, the following works as intended:

>>> import torch
>>> items_per_category = torch.tensor([[1, 2, 3], [1, 4, 7], [2, 3, 5]], dtype=torch.int32)
>>> index = torch.tensor([1, 0, 1], dtype=torch.bool)
>>> items_per_category[index, :]
tensor([[1, 2, 3],
        [2, 3, 5]], dtype=torch.int32)

However, if I have a batch of indices, it does not.

>>> batch_index = torch.tensor([[1, 0, 0],
...                             [1, 0, 1],
...                             [0, 0, 1]], dtype=torch.bool)
>>> items_per_category[batch_index, :]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: too many indices for tensor of dimension 2

Now I would like to perform this indexing in a batched way. Is it possible to obtain, for each batch element, a flat tensor of the unique items in the selected rows, so that the overall result is a rectangular (possibly padded) tensor? Otherwise, since each batch element has a variable number of 1s, I am not sure what shape to expect from this batched indexing.
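One way to get exactly that shape is sketched below. It assumes item ids are integers in [0, num_total_items), where num_total_items is the number of distinct items in the dataset (not the padded row length), and that padding entries in items_per_category are -1; the -1 padding value, num_total_items and the helper name batched_unique_items are all hypothetical. Instead of gathering rows, it builds a [num_categories, num_total_items] membership matrix, takes the union of the selected categories with one matrix product, and turns each resulting mask into a sorted, left-aligned row of item ids padded with -1:

```python
import torch


def batched_unique_items(items_per_category, binary_categories, num_total_items, pad=-1):
    """Hypothetical helper: unique item ids per batch element, padded with `pad`.

    items_per_category: [num_categories, K] item ids, padded with `pad` (assumed -1)
    binary_categories:  [batch_size, num_categories] 0/1 mask
    num_total_items:    number of distinct items; ids assumed to lie in [0, num_total_items)
    """
    num_categories, k = items_per_category.shape
    batch_size = binary_categories.size(0)

    # membership[c, i] is True if item i is tagged with category c (padding skipped).
    valid = items_per_category != pad
    rows = torch.arange(num_categories).unsqueeze(1).expand(num_categories, k)
    membership = torch.zeros(num_categories, num_total_items, dtype=torch.bool)
    membership[rows[valid], items_per_category[valid].long()] = True

    # Union of the selected categories: [B, C] @ [C, N] -> [B, N]; > 0 means "selected".
    selected = (binary_categories.float() @ membership.float()) > 0

    # Replace unselected ids with a sentinel larger than any real id, sort so the
    # selected ids come first in ascending order, then trim to the widest row.
    ids = torch.arange(num_total_items).repeat(batch_size, 1)
    ids = ids.masked_fill(~selected, num_total_items)
    sorted_ids, _ = ids.sort(dim=1)
    max_count = int(selected.sum(dim=1).max())
    out = sorted_ids[:, :max_count]
    return out.masked_fill(out == num_total_items, pad)


items_per_category = torch.tensor([[1, 2, 3], [1, 4, 7], [2, 3, 5]], dtype=torch.int32)
binary_categories = torch.tensor([[1, 0, 0],
                                  [1, 0, 1],
                                  [0, 0, 1]])
# expected rows: [1, 2, 3, -1], [1, 2, 3, 5], [2, 3, 5, -1]
print(batched_unique_items(items_per_category, binary_categories, num_total_items=8))
```

The membership matrix depends only on items_per_category, so it could be built once outside the training loop. The matrix product costs roughly batch_size × num_categories × num_total_items, which is small with 20 categories but worth keeping in mind for a very large item set.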

EDIT: As an alternative, would I be better off converting the binary indices to integer positions, e.g. index => [0, 2]? I'm unsure about the approach I'm taking.
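Converting a boolean mask to integer positions is what torch.nonzero does. A minimal sketch on the example tensors (redefined so it runs on its own): for a single mask it yields the positions [0, 2], and for the 2-D batch mask it yields (batch, category) pairs in row-major order, which let you gather all selected rows at once while keeping track of which batch element each row belongs to:

```python
import torch

items_per_category = torch.tensor([[1, 2, 3], [1, 4, 7], [2, 3, 5]], dtype=torch.int32)
batch_index = torch.tensor([[1, 0, 0],
                            [1, 0, 1],
                            [0, 0, 1]], dtype=torch.bool)

# Single mask -> integer positions, then ordinary integer indexing.
index = batch_index[1]                          # [True, False, True]
positions = index.nonzero(as_tuple=True)[0]     # tensor([0, 2])
rows = items_per_category[positions]            # rows 0 and 2

# Batched: nonzero on the 2-D mask returns one (batch, category) pair per 1.
batch_ids, category_ids = batch_index.nonzero(as_tuple=True)
gathered = items_per_category[category_ids]     # [num_selected, K], one row per pair
```

The gathered rows are still a flat [num_selected, K] tensor, so the variable number of 1s per batch element shows up as a variable number of rows per batch id rather than as a ragged dimension.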

If you are using a batched index tensor, you would have to either add a batch dimension to the items tensor:

items_per_category = items_per_category.unsqueeze(0).repeat(batch_index.size(0), 1, 1)  # one copy per batch element

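or you could index both dimensions at once by removing the : from the indexing operation, i.e. items_per_category[batch_index]. Note that this selects individual elements rather than whole rows.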
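A sketch of both suggestions on the example tensors, assuming batch_index is already a bool tensor. It uses expand() in place of repeat(), which is a substitution: expand() returns a view of the same memory, whereas repeat() copies the data once per batch element. Because the boolean mask consumes both the batch and category dimensions, the selected rows come back concatenated, and torch.split with the per-row counts of 1s recovers one tensor per batch element:

```python
import torch

items_per_category = torch.tensor([[1, 2, 3], [1, 4, 7], [2, 3, 5]], dtype=torch.int32)
batch_index = torch.tensor([[1, 0, 0],
                            [1, 0, 1],
                            [0, 0, 1]], dtype=torch.bool)
batch_size = batch_index.size(0)

# Option 1: add a batch dimension (a view via expand, no data copy).
batched_items = items_per_category.unsqueeze(0).expand(batch_size, -1, -1)  # [B, C, K]
selected_rows = batched_items[batch_index]  # [num_selected, K], batch elements concatenated

# The mask flattens the batch dimension, so split the rows back per batch element.
counts = batch_index.sum(dim=1).tolist()    # number of selected categories per element
per_batch = torch.split(selected_rows, counts)  # tuple of [count_b, K] tensors

# Option 2: index both dimensions with the mask; this picks single elements, not rows.
elements = items_per_category[batch_index]  # [num_selected]
```

Option 1 gives the rows the question asks for, but per_batch is a tuple of differently sized tensors; padding them to a common width (and removing duplicates) still has to be done separately.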

Let me know if I misunderstood your use case.