I have a batched tensor of matrices:
import torch

batch_size = 64
num_elements = 500
channels = 2
a = torch.rand(batch_size, num_elements, channels) # [64, 500, 2]
And a second tensor of indices that describes which rows to select from each matrix in the batch:
indices = 80 # number of indices per batch item
b = (torch.rand(batch_size, indices) * num_elements).long() # [64, 80]
I want to index each matrix in the batch with its corresponding row of indices from `b`, so that the result has shape [64, 80, 2].
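To pin down the intended semantics, here is a slow but unambiguous reference version that loops over the batch. It is a sketch assuming the shapes above; for each batch item `i` it picks the rows of `a[i]` listed in `b[i]`:

```python
import torch

batch_size, num_elements, channels, indices = 64, 500, 2, 80
a = torch.rand(batch_size, num_elements, channels)           # [64, 500, 2]
b = (torch.rand(batch_size, indices) * num_elements).long()  # [64, 80], values in [0, 499]

# a[i] is [500, 2] and b[i] is [80], so a[i][b[i]] is [80, 2]; stacking gives [64, 80, 2]
expected = torch.stack([a[i][b[i]] for i in range(batch_size)])
print(expected.shape)  # torch.Size([64, 80, 2])
```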
How can I do that?
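A minimal vectorized sketch, assuming the same shapes as above, shows two standard PyTorch options. Advanced indexing pairs a broadcast batch index with `b`, while `torch.gather` along dim 1 needs the index expanded over the channel dimension:

```python
import torch

batch_size, num_elements, channels, indices = 64, 500, 2, 80
a = torch.rand(batch_size, num_elements, channels)           # [64, 500, 2]
b = (torch.rand(batch_size, indices) * num_elements).long()  # [64, 80]

# Option 1: advanced indexing. batch_idx [64, 1] broadcasts against b [64, 80],
# so out_adv[i, j] = a[i, b[i, j]] and the trailing channel dim is kept.
batch_idx = torch.arange(batch_size).unsqueeze(1)  # [64, 1]
out_adv = a[batch_idx, b]                          # [64, 80, 2]

# Option 2: torch.gather along dim 1. The index must have the same number of
# dims as a, so each row index is repeated across the channel dim.
idx = b.unsqueeze(-1).expand(-1, -1, channels)     # [64, 80, 2]
out_gather = torch.gather(a, 1, idx)               # [64, 80, 2]

# Both approaches select the same rows
assert torch.equal(out_adv, out_gather)
```

Advanced indexing is usually the more readable of the two; `gather` can be handy when the selection dimension is chosen dynamically.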