Hello, I have the following problem that I would like to solve efficiently with vectorized operations.

In my dataset, I have a discrete set of items, each tagged with a subset of 20 possible categories.

During training, I output a distribution over the set of items in my dataset which I sample from.

When training with batches, I am given the tensor `items_per_category` of shape `[num_categories, num_items]` (the tensor is only rectangular due to padding: each category has a variable number of items tagged with it).

I also have a batch of tensors `binary_categories` of shape `[batch_size, num_categories]`.

Let’s say we have a batch element `index = [1, 0, 1]`. I would then like to perform the following operation: `categories = items_per_category[index]`, which, treating `index` as a boolean mask, would retrieve rows 0 and 2 from `items_per_category`.

For clarity, the following works as intended:

```
>>> items_per_category = torch.tensor([[1, 2, 3], [1, 4, 7], [2, 3, 5]], dtype=torch.int32)
>>> index = torch.tensor([1, 0, 1]).type(torch.BoolTensor)
>>> items_per_category[index, :]
tensor([[1, 2, 3],
        [2, 3, 5]], dtype=torch.int32)
```

However, if I have a batch of indices, it does not.

```
>>> batch_index = torch.tensor([[1, 0, 0],
...                             [1, 0, 1],
...                             [0, 0, 1]])
>>> items_per_category[batch_index.type(torch.BoolTensor), :]
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
IndexError: too many indices for tensor of dimension 2
```
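For what it's worth, if I drop the trailing `:` and pass a mask with the same shape as the tensor, the error goes away, but the result is element-wise selection flattened into a 1-D tensor, not the per-batch row selection I want:

```
import torch

items_per_category = torch.tensor([[1, 2, 3], [1, 4, 7], [2, 3, 5]], dtype=torch.int32)
batch_index = torch.tensor([[1, 0, 0],
                            [1, 0, 1],
                            [0, 0, 1]], dtype=torch.bool)

# A boolean mask matching the tensor's shape selects individual elements
# (the positions where the mask is True) and returns a flat 1-D tensor.
flat = items_per_category[batch_index]
print(flat)  # tensor([1, 1, 7, 5], dtype=torch.int32)
```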

Now, I would like to perform batched indexing. Is it possible to obtain, for each batch element, a flat tensor of the unique selected items, and overall a rectangular tensor (perhaps padded)? Otherwise, since each batch element has a variable number of 1s, I'm not sure what shape to expect from the output of this batched indexing.
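To make the desired output concrete, here is an unvectorized baseline with an explicit Python loop (this is what I'd like to replace with vectorized operations); it takes the unique items per batch element and pads each result with 0 to a common length:

```
import torch
from torch.nn.utils.rnn import pad_sequence

items_per_category = torch.tensor([[1, 2, 3], [1, 4, 7], [2, 3, 5]], dtype=torch.int32)
batch_index = torch.tensor([[1, 0, 0],
                            [1, 0, 1],
                            [0, 0, 1]], dtype=torch.bool)

# For each batch element: select the masked rows, flatten, deduplicate.
selected = [items_per_category[mask].flatten().unique() for mask in batch_index]

# Pad the variable-length results to a rectangular [batch_size, max_len] tensor.
padded = pad_sequence(selected, batch_first=True, padding_value=0)
print(padded)
# tensor([[1, 2, 3, 0],
#         [1, 2, 3, 5],
#         [2, 3, 5, 0]], dtype=torch.int32)
```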

Thanks!

EDIT: as an alternative, would I be better off converting the binary indices to the actual integer positions? E.g. ` x => [0, 2]`. I'm unsure about the approach I'm taking.
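For reference, the conversion from a binary mask to integer positions is straightforward for a single mask via `torch.nonzero`; for a batch, though, it yields ragged (row, column) coordinate pairs rather than a rectangular result:

```
import torch

index = torch.tensor([1, 0, 1], dtype=torch.bool)
positions = torch.nonzero(index, as_tuple=True)[0]
print(positions)  # tensor([0, 2])

# For a batch, the number of 1s varies per row, so nonzero returns
# coordinate pairs rather than one fixed-width row of positions per element.
batch_index = torch.tensor([[1, 0, 0],
                            [1, 0, 1],
                            [0, 0, 1]], dtype=torch.bool)
rows, cols = batch_index.nonzero(as_tuple=True)
print(rows)  # tensor([0, 1, 1, 2])
print(cols)  # tensor([0, 0, 2, 2])
```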