Broadcast torch.gather (4-D array with 2-D index)

Hi,

I think this question is very similar to this post (which explains how to access a 4-D tensor through a 3-D index), but I still can’t figure it out.

I want to access a 4-D tensor through a 2-D index, like the following:

import torch

dim1 = 3
dim2 = 4
dim3 = 5
dim4 = 10
source = torch.randn((dim1,dim2,dim3,dim4))
index = torch.randint(dim3, (dim1, dim2))
ret =  torch.FloatTensor(dim1, dim2, dim3)

for i in range(dim1):
    for j in range(dim2):
        ret[i,j,:] = source[i, j, index[i,j], :]

I assume ret should be initialized with the shape [dim1, dim2, dim4]; otherwise you’ll get a shape mismatch error in the posted loop.
If that’s the case, this code should work:

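import torch

dim1 = 3
dim2 = 4
dim3 = 5
dim4 = 10
source = torch.randn((dim1,dim2,dim3,dim4))
index = torch.randint(dim3, (dim1, dim2))
ret = torch.zeros(dim1, dim2, dim4)

# Reference result computed with explicit loops
for i in range(dim1):
    for j in range(dim2):
        ret[i,j,:] = source[i, j, index[i,j], :]

# Vectorized version: the row index [dim1, 1] and the column index [dim2]
# broadcast against index [dim1, dim2], so each (i, j) selects source[i, j, index[i, j], :]
ret2 = torch.zeros(dim1, dim2, dim4)
ret2[torch.arange(dim1).unsqueeze(1), torch.arange(dim2)] = source[torch.arange(dim1).unsqueeze(1), torch.arange(dim2), index]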

print((ret == ret2).all())
> tensor(True)
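
Since the title asks about torch.gather specifically, here is a minimal sketch of how the same result could be obtained with gather. It assumes the same shapes as above. The names idx, ret_gather and ret_index are only illustrative and are not from the original posts. The idea is to expand the 2-D index to [dim1, dim2, 1, dim4] so that gather picks one slice along dim 2, then squeeze that dimension away:

```python
import torch

dim1, dim2, dim3, dim4 = 3, 4, 5, 10
source = torch.randn(dim1, dim2, dim3, dim4)
index = torch.randint(dim3, (dim1, dim2))

# gather needs an index with the same number of dimensions as source,
# so add two singleton dims and expand the last one to dim4
idx = index[:, :, None, None].expand(dim1, dim2, 1, dim4)

# Pick one entry along dim 2 for every (i, j, :, l), then drop the size-1 dim
ret_gather = source.gather(2, idx).squeeze(2)  # shape [dim1, dim2, dim4]

# Same result via advanced indexing, used here only as a cross-check
i = torch.arange(dim1).unsqueeze(1)
j = torch.arange(dim2)
ret_index = source[i, j, index]

print(torch.equal(ret_gather, ret_index))
```

For this particular access pattern the advanced-indexing version is probably the shorter of the two; gather is mostly useful when you already have an index of the full output shape.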