I have a source multidimensional tensor of shape `(a, b, c, c, d)` that stores vectors of size `d`, and another tensor of shape `(a, b, e, 2)` that stores `e` index pairs. Each pair indexes the third and fourth dimensions of the data tensor (the two dimensions of size `c`). Both tensors share the same `a` and `b` dimension sizes.

What I want is to use these indices to retrieve `e` rows of size `d` from the first tensor, so the output tensor should have shape `(a, b, e, d)`, i.e. `e` vectors of size `d` for every position along the `a, b` dimensions.
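To pin down the intended semantics, here is a plain-loop reference: a minimal sketch assuming `data` and `inds` have the shapes described above (the helper name `gather_pairs_loop` is hypothetical). It is slow, but it makes explicit that output position `[i, j, k]` should hold `data[i, j, inds[i, j, k, 0], inds[i, j, k, 1]]`.

```python
import torch

def gather_pairs_loop(data: torch.Tensor, inds: torch.Tensor) -> torch.Tensor:
    """Hypothetical slow reference: out[i, j, k] = data[i, j, inds[i, j, k, 0], inds[i, j, k, 1]]."""
    a, b, e, _ = inds.shape
    d = data.shape[-1]
    out = data.new_empty((a, b, e, d))  # same dtype/device as data
    for i in range(a):
        for j in range(b):
            for k in range(e):
                r, s = inds[i, j, k].tolist()  # the two indices into the c-sized dims
                out[i, j, k] = data[i, j, r, s]  # one d-sized vector
    return out

a, b, c, d, e = 3, 5, 7, 9, 11
data = torch.rand(a, b, c, c, d)
inds = torch.randint(0, c, size=(a, b, e, 2))
print(gather_pairs_loop(data, inds).shape)  # torch.Size([3, 5, 11, 9])
```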

Here is what I tried:

```python
import torch

a, b, c, d = 3, 5, 7, 9
e = 11
data = torch.rand(a, b, c, c, d)
inds = torch.randint(0, c, size=(a, b, e, 2))

# Index the two c-sized dims with the two columns of `inds`
res = data[:, :, inds[:, :, :, 0], inds[:, :, :, 1], :]
print(' - Obtained shape:', res.shape)
print(' - Desired shape:', (a, b, e, d))
# - Obtained shape: torch.Size([3, 5, 3, 5, 11, 9])
# - Desired shape: (3, 5, 11, 9)
```
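The extra `3, 5` dimensions come from how advanced (tensor) indexing works in PyTorch, which follows NumPy here: `inds[:, :, :, 0]` and `inds[:, :, :, 1]` each have shape `(a, b, e)`, and they are combined with *full* slices on dims 0 and 1, so every `(i, j)` of the data is paired with the entire `(a, b, e)` index grid, giving `(a, b) + (a, b, e) + (d,)`. One common fix, sketched below under the shapes from the question, is to index dims 0 and 1 with broadcastable `arange` tensors instead of slices, so all four index tensors broadcast to `(a, b, e)`. A `torch.gather` variant over a flattened `c*c` axis is included as a cross-check; both function names are hypothetical.

```python
import torch

def gather_pairs_index(data: torch.Tensor, inds: torch.Tensor) -> torch.Tensor:
    # data: (a, b, c, c, d), inds: (a, b, e, 2) -> (a, b, e, d)
    a, b = inds.shape[:2]
    # (a, 1, 1) and (1, b, 1) broadcast against inds[..., 0], which is (a, b, e)
    ia = torch.arange(a, device=data.device).view(a, 1, 1)
    ib = torch.arange(b, device=data.device).view(1, b, 1)
    # All four advanced indices broadcast to (a, b, e); the trailing d dim is kept
    return data[ia, ib, inds[..., 0], inds[..., 1]]

def gather_pairs_flat(data: torch.Tensor, inds: torch.Tensor) -> torch.Tensor:
    # Alternative: merge the two c-sized dims and gather along the merged axis
    a, b, c, _, d = data.shape
    e = inds.shape[2]
    flat = data.reshape(a, b, c * c, d)
    lin = inds[..., 0] * c + inds[..., 1]       # (a, b, e) row-major linear index
    idx = lin.unsqueeze(-1).expand(a, b, e, d)  # gather's index must have the output's shape
    return torch.gather(flat, 2, idx)

a, b, c, d, e = 3, 5, 7, 9, 11
data = torch.rand(a, b, c, c, d)
inds = torch.randint(0, c, size=(a, b, e, 2))

res = gather_pairs_index(data, inds)
print(' - Obtained shape:', res.shape)  # torch.Size([3, 5, 11, 9])
print(torch.equal(res, gather_pairs_flat(data, inds)))  # True
```

The `arange` version keeps the original five-dimensional layout and reads closest to the attempted one-liner; the `gather` version makes the "pick row `r * c + s` of a flattened grid" interpretation explicit.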