Gather command for a 4-D tensor

Hi everyone,

I have a lookup/indexing tensor of size (2, 4) and an input tensor of size (2, 4, 2, 4).

For example, the lookup/indexing tensor is the following: index = [ [0, 0, 0, 2], [0, 0, 1, 2] ]

My goal is to generate an output tensor of the same size as the input tensor, (2, 4, 2, 4), by gathering slices along dim=1 according to the lookup indices. To illustrate, in this case I want

output[0] = torch.cat([input[0, index[0,0], :, :], input[0, index[0,1], :, :], input[0, index[0,2], :, :], input[0, index[0,3], :, :]], dim=1)
output[1] = torch.cat([input[1, index[1,0], :, :], input[1, index[1,1], :, :], input[1, index[1,2], :, :], input[1, index[1,3], :, :]], dim=1)

which should yield (for the particular index mentioned above):

output[0] = torch.cat([input[0, 0, :, :], input[0, 0, :, :], input[0, 0, :, :], input[0, 2, :, :]], dim=1)
output[1] = torch.cat([input[1, 0, :, :], input[1, 0, :, :], input[1, 1, :, :], input[1, 2, :, :]], dim=1)
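Written as an explicit loop, the intended semantics are output[b, j] = input[b, index[b, j]]. Below is a minimal reference sketch that assumes random data in place of the real input (the name `inp` is used only to avoid shadowing Python's built-in `input`). It stacks the slices so the result keeps the input's (2, 4, 2, 4) shape. Note that `torch.cat` along dim=1, as written above, would instead produce per-sample slices of shape (2, 16):

```python
import torch

# Example data: index from the post, random input with the stated shape.
index = torch.tensor([[0, 0, 0, 2], [0, 0, 1, 2]])
inp = torch.randn(2, 4, 2, 4)

# Reference loop: output[b, j] = inp[b, index[b, j]]
output = torch.empty_like(inp)
for b in range(index.size(0)):
    for j in range(index.size(1)):
        output[b, j] = inp[b, index[b, j]]

# Concatenating the same four (2, 4) slices along dim=1 joins them side by side.
cat0 = torch.cat([inp[0, index[0, j]] for j in range(4)], dim=1)

print(output.shape)  # torch.Size([2, 4, 2, 4])
print(cat0.shape)    # torch.Size([2, 16])
```

The loop is slow but unambiguous, which makes it a handy ground truth for checking a vectorized version.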

Is there a neat way to do this using torch.gather()?

Thank you:)

You could index the tensor directly without using gather.
I’ve changed the shapes a bit for an easier comparison: I used stack instead of cat and built a final output tensor.

import torch

index = torch.tensor([[0, 0, 0, 2], [0, 0, 1, 2]])
input = torch.randn(2, 3, 2, 3)  # dim 1 only needs size > max(index) = 2

output0 = torch.stack( (input[0, index[0,0], :, :], input[0, index[0,1], :, :], input[0, index[0,2], :, :], input[0, index[0,3], :, :]))
output1 = torch.stack( (input[1, index[1,0], :, :], input[1, index[1,1], :, :], input[1, index[1,2], :, :], input[1, index[1,3], :, :]))

output = torch.stack((output0, output1))

# batch indices of shape (2, 1) broadcast against index of shape (2, 4)
res = input[torch.arange(2).unsqueeze(1), index]
print((res == output).all())
> tensor(True)
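Since the question asked about `torch.gather` specifically, here is a sketch of a gather-based equivalent. It assumes every position in the trailing two dimensions should be copied. `torch.gather` needs an index tensor with the same number of dimensions as the input, so `index` is given two singleton dimensions and expanded over the last two. The sketch also spells out the broadcasting in the indexing line above: `torch.arange(2).unsqueeze(1)` has shape (2, 1) and broadcasts against `index` of shape (2, 4), so `res[b, j] = input[b, index[b, j]]`.

```python
import torch

index = torch.tensor([[0, 0, 0, 2], [0, 0, 1, 2]])
inp = torch.randn(2, 3, 2, 3)  # same shapes as the example above

# Advanced indexing: batch index (2, 1) broadcasts against index (2, 4).
batch = torch.arange(inp.size(0)).unsqueeze(1)
res = inp[batch, index]  # shape (2, 4, 2, 3)

# gather along dim=1: out[b, j, k, l] = inp[b, idx[b, j, k, l], k, l].
# idx must have inp's rank, so add two singleton dims and expand them
# over the trailing dimensions (expand returns a view, no copy).
idx = index[:, :, None, None].expand(-1, -1, inp.size(2), inp.size(3))
res_gather = torch.gather(inp, 1, idx)

print(res.shape)                     # torch.Size([2, 4, 2, 3])
print(torch.equal(res, res_gather))  # True
```

Both give the same result. The direct indexing avoids building the expanded index, so it is usually the simpler choice here.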