I have a batched tensor of matrices:
    import torch
    batch_size = 64
    num_elements = 500
    channels = 2
    a = torch.rand(batch_size, num_elements, channels)  # [64, 500, 2]
and a tensor of indices that says which rows to pick from each of those matrices:
    indices = 80
    b = (torch.rand(batch_size, indices) * num_elements).long()  # [64, 80]
I want to index each matrix in the batch with its corresponding row of indices from b, so that the result has shape [64, 80, 2].
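To pin down the intended semantics, here is a slow reference version with an explicit loop over the batch, so that out[i, j, :] = a[i, b[i, j], :]. This is only a sketch: it keeps the question's variable names but builds b with torch.randint, an assumption swapped in for the rand-and-cast construction above.

```python
import torch

batch_size, num_elements, channels, indices = 64, 500, 2, 80
a = torch.rand(batch_size, num_elements, channels)        # [64, 500, 2]
b = torch.randint(0, num_elements, (batch_size, indices))  # [64, 80]

# Reference semantics: for each batch entry i, select the rows b[i] of a[i].
out = torch.empty(batch_size, indices, channels)
for i in range(batch_size):
    out[i] = a[i, b[i]]  # a[i] is [500, 2], b[i] is [80] -> [80, 2]

print(out.shape)
```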
How can I do that?
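One vectorized way to do this, sketched under the assumption that a and b are built as in the question (again using torch.randint for b): either build a batch index with torch.arange and let advanced indexing broadcast it against b, or expand b along the channel dimension and call torch.gather on dim 1. Both should produce a [64, 80, 2] tensor with out[i, j] = a[i, b[i, j]].

```python
import torch

batch_size, num_elements, channels, indices = 64, 500, 2, 80
a = torch.rand(batch_size, num_elements, channels)        # [64, 500, 2]
b = torch.randint(0, num_elements, (batch_size, indices))  # [64, 80]

# Option 1: advanced indexing. batch_idx is [64, 1] and broadcasts against
# b ([64, 80]), so out_adv[i, j, :] = a[i, b[i, j], :].
batch_idx = torch.arange(batch_size, device=a.device).unsqueeze(1)
out_adv = a[batch_idx, b]  # [64, 80, 2]

# Option 2: torch.gather along dim 1. The index must have as many dims as a,
# so b is expanded (as a view, without copying) to [64, 80, 2]:
# out_gather[i, j, k] = a[i, idx[i, j, k], k] = a[i, b[i, j], k].
idx = b.unsqueeze(-1).expand(-1, -1, channels)
out_gather = torch.gather(a, 1, idx)  # [64, 80, 2]

print(out_adv.shape)
print(torch.equal(out_adv, out_gather))
```

The advanced-indexing version reads closest to the loop definition; the gather version makes the indexed dimension explicit and is handy if you later need a different index per channel.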