Hello, I have the following problem that I would like to solve efficiently with vectorized operations.

In my dataset, I have a discrete set of items, each tagged with a subset of 20 possible categories.

During training, I output a distribution over the set of items in my dataset which I sample from.

When training with batches, I am given the tensor `items_per_category` of shape `[num_categories, num_items]` (the tensor is only rectangular due to padding: each category has a variable number of items tagged with it).

I also have a batch of tensors `binary_categories` of shape `[batch_size, num_categories]`.

Let’s say we have a batch element `index = [1, 0, 1]`. I would then like to perform the following operation: `categories = items_per_category[index]`, which, treating `index` as a boolean mask, would retrieve rows 0 and 2 from `items_per_category`.

For clarity, the following works as intended:

```
>>> items_per_category = torch.tensor([[1, 2, 3], [1, 4, 7], [2, 3, 5]], dtype=torch.int32)
>>> index = torch.tensor([1, 0, 1]).type(torch.BoolTensor)
>>> items_per_category[index, :]
tensor([[1, 2, 3],
        [2, 3, 5]], dtype=torch.int32)
```

However, if I have a batch of indices, it does not.

```
>>> batch_index = torch.tensor([[1, 0, 0],
...                             [1, 0, 1],
...                             [0, 0, 1]])
>>> items_per_category[batch_index.type(torch.BoolTensor), :]
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
IndexError: too many indices for tensor of dimension 2
```
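For what it's worth, if I drop the trailing `:` and pass a mask with the same shape as the tensor, the error goes away, but the result is element-wise selection flattened into a 1-D tensor, not the per-batch row selection I want:

```
import torch

items_per_category = torch.tensor([[1, 2, 3], [1, 4, 7], [2, 3, 5]], dtype=torch.int32)
batch_index = torch.tensor([[1, 0, 0],
                            [1, 0, 1],
                            [0, 0, 1]], dtype=torch.bool)

# A boolean mask matching the tensor's shape selects individual elements
# (the positions where the mask is True) and returns a flat 1-D tensor.
flat = items_per_category[batch_index]
print(flat)  # tensor([1, 1, 7, 5], dtype=torch.int32)
```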

Now, I would like to perform batched indexing. Is it possible to obtain, for each batch element, a flat tensor of the unique selected items, and overall a rectangular tensor (perhaps padded)? Otherwise, since each batch element has a variable number of 1s, I'm not sure what shape to expect from the output of this batched indexing.
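To make the desired output concrete, here is an unvectorized baseline with an explicit Python loop (this is what I'd like to replace with vectorized operations); it takes the unique items per batch element and pads each result with 0 to a common length:

```
import torch
from torch.nn.utils.rnn import pad_sequence

items_per_category = torch.tensor([[1, 2, 3], [1, 4, 7], [2, 3, 5]], dtype=torch.int32)
batch_index = torch.tensor([[1, 0, 0],
                            [1, 0, 1],
                            [0, 0, 1]], dtype=torch.bool)

# For each batch element: select the masked rows, flatten, deduplicate.
selected = [items_per_category[mask].flatten().unique() for mask in batch_index]

# Pad the variable-length results to a rectangular [batch_size, max_len] tensor.
padded = pad_sequence(selected, batch_first=True, padding_value=0)
print(padded)
# tensor([[1, 2, 3, 0],
#         [1, 2, 3, 5],
#         [2, 3, 5, 0]], dtype=torch.int32)
```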

Thanks!

EDIT: as an alternative, would I be better off converting the binary indices to the actual integer positions? E.g. ` x => [0, 2]`. I'm unsure about the approach I'm taking.
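For reference, the conversion from a binary mask to integer positions is straightforward for a single mask via `torch.nonzero`; for a batch, though, it yields ragged (row, column) coordinate pairs rather than a rectangular result:

```
import torch

index = torch.tensor([1, 0, 1], dtype=torch.bool)
positions = torch.nonzero(index, as_tuple=True)[0]
print(positions)  # tensor([0, 2])

# For a batch, the number of 1s varies per row, so nonzero returns
# coordinate pairs rather than one fixed-width row of positions per element.
batch_index = torch.tensor([[1, 0, 0],
                            [1, 0, 1],
                            [0, 0, 1]], dtype=torch.bool)
rows, cols = batch_index.nonzero(as_tuple=True)
print(rows)  # tensor([0, 1, 1, 2])
print(cols)  # tensor([0, 0, 2, 2])
```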