Indexing 3D tensor using 2D tensor

New to PyTorch here.

I have a 3D tensor (Batch X features X attributes) and a 2D indexing tensor (Batch X features).
Each entry of the indexing tensor is the index of the desired attribute for that feature in that batch row. For example, indexing = torch.tensor([[1, 2], [2, 5], [5, 2]]) is a tensor of size 3 X 2: for the first row I want attribute 1 for the first feature and attribute 2 for the second feature; for the second row I want attribute 2 for the first feature and attribute 5 for the second feature; and so on.

The output tensor should be of size (Batch X features).

I tried to use .gather, but I am not sure whether it is the right tool for this, or which dim to use it on.

How can I achieve that?

torch.gather(x, 2, index[..., None]).squeeze(-1)
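To make the gather call above concrete, here is a minimal sketch using the indexing tensor from the question. The values in x and the attribute count of 6 are assumptions chosen only so the result is easy to read. Note that gather needs the index to have the same number of dimensions as x, which is why a trailing axis is added, and that its result keeps that axis with size 1, so a squeeze(-1) is needed to get the (Batch X features) shape asked for.

```python
import torch

# Assumed sizes for illustration: 3 batch rows, 2 features, 6 attributes.
batch, features, attributes = 3, 2, 6

# Assumed data: x[b, f, a] == b * 12 + f * 6 + a, so picks are easy to verify.
x = torch.arange(batch * features * attributes).reshape(batch, features, attributes)

# Indexing tensor from the question, shape (batch, features).
index = torch.tensor([[1, 2], [2, 5], [5, 2]])

# index[..., None] has shape (batch, features, 1); gather along dim 2
# picks x[b, f, index[b, f]] and returns shape (batch, features, 1).
out = torch.gather(x, 2, index[..., None]).squeeze(-1)

# Reference result built with explicit loops, for comparison.
expected = torch.stack([
    torch.stack([x[b, f, index[b, f]] for f in range(features)])
    for b in range(batch)
])

print(out.shape)                   # torch.Size([3, 2])
print(out)                         # tensor([[ 1,  8], [14, 23], [29, 32]])
print(torch.equal(out, expected))  # True
```

The index passed to gather must be an integer (int64) tensor, which torch.tensor gives by default for Python ints.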

Thanks a lot!
This worked very well.