How to index a tensor with a tensor in multiple dimensions?

import torch

a = torch.tensor([[1,2,3,4],[5,6,7,8]])
idx = torch.tensor([[0,2,1],[2,3,0]])

# How can this be done in batch?
c_1 = a[0][idx[0]].view(1,-1)
c_2 = a[1][idx[1]].view(1,-1)
c = torch.cat((c_1, c_2), dim=0)
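The row-by-row version generalizes to any number of rows with a plain loop. A minimal sketch (the function name `select_per_row_loop` is hypothetical, not a PyTorch API) that can serve as a reference when checking a batched version:

```python
import torch

def select_per_row_loop(a, idx):
    # Hypothetical helper: for each row i, take a[i][idx[i]], then stack the rows.
    rows = [a[i][idx[i]] for i in range(idx.size(0))]
    return torch.stack(rows, dim=0)

a = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
idx = torch.tensor([[0, 2, 1], [2, 3, 0]])
c = select_per_row_loop(a, idx)
print(c)
# tensor([[1, 3, 2],
#         [7, 8, 5]])
```

This launches one indexing operation per row, which gets slow for large batches; that cost is why a single batched call is preferable.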

The desired output is:

tensor([[1, 3, 2],
        [7, 8, 5]])

I tried a[idx], but it doesn't work.
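For context on why this fails: with a single index tensor, PyTorch's advanced indexing applies idx to dimension 0 only, so each entry selects a whole row rather than a column within a row. A small sketch illustrating this, using a made-up `rows` tensor:

```python
import torch

a = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])

# A single index tensor indexes dimension 0 only: each entry picks a whole row,
# so the result has shape rows.shape + a.shape[1:].
rows = torch.tensor([[0, 1, 0], [1, 1, 0]])  # hypothetical row indices
picked = a[rows]
print(picked.shape)  # torch.Size([2, 3, 4])

# The idx from the question contains 2 and 3, which are out of range for
# dimension 0 (size 2), so a[idx] raises instead of selecting columns.
idx = torch.tensor([[0, 2, 1], [2, 3, 0]])
failed = False
try:
    a[idx]
except (IndexError, RuntimeError) as err:
    failed = True
    print("a[idx] failed:", err)
```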

What are idx and the desired output in the snippet above?

Hi, I made a typo and have just corrected it.

idx holds the indices of the elements to select from each row of the tensor.

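import torch
a = torch.tensor([[1,2,3,4],[5,6,7,8]])
idx = torch.tensor([[0,2,1],[2,3,0]])
# Turn per-row column indices into indices into the flattened tensor:
# row i starts at offset i * a.size(1).
idx2 = idx + torch.arange(idx.size(0)).view(-1, 1) * a.size(1)
# Index the flattened view; the result has the same shape as idx2.
c = a.view(-1)[idx2]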

It works as long as a and idx have only two dimensions, idx.size(0) <= a.size(0), and every entry of idx is less than a.size(1).
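The flattening trick relies on every entry of idx being a valid column index. A small sketch, assuming the same a as above and a deliberately bad index tensor (`bad_idx`, made up for illustration), of what happens otherwise: an out-of-range column index does not raise, it silently reads from the next row.

```python
import torch

a = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
# Column index 4 is out of range for rows of length 4.
bad_idx = torch.tensor([[4, 0, 0], [0, 1, 2]])
flat_idx = bad_idx + torch.arange(bad_idx.size(0)).view(-1, 1) * a.size(1)
c = a.view(-1)[flat_idx]
print(c)
# tensor([[5, 1, 1],
#         [5, 6, 7]])
# Row 0 picked up 5, which actually belongs to row 1.
```

Note also that a.view(-1) requires a to be contiguous; a.reshape(-1) avoids that restriction.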


Thank you for your help, @LeviViana. That certainly solves it, but I wonder whether PyTorch can support a[idx] directly. Is that possible, @albanD?

Hi,

gather is what you want!

c = a.gather(1, idx)
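To spell out what gather computes: with dim=1, the output has the shape of idx and c[i][j] = a[i][idx[i][j]], which is exactly the per-row selection asked for. The sketch below checks that definition and also shows one common higher-dimensional variant; the shapes B, N, D, K and the tensors x and idx3 are made up for illustration.

```python
import torch

a = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
idx = torch.tensor([[0, 2, 1], [2, 3, 0]])

# With dim=1: c[i][j] = a[i][idx[i][j]]; c has the same shape as idx.
c = a.gather(1, idx)
print(c)
# tensor([[1, 3, 2],
#         [7, 8, 5]])

# Check against the element-wise definition.
for i in range(idx.size(0)):
    for j in range(idx.size(1)):
        assert c[i, j] == a[i, idx[i, j]]

# Higher-dimensional case (illustrative shapes): pick K feature vectors per
# batch item from x of shape (B, N, D) using idx3 of shape (B, K).
B, N, D, K = 2, 5, 3, 2
x = torch.arange(B * N * D).view(B, N, D)
idx3 = torch.tensor([[4, 0], [1, 1]])
# gather needs an index with as many dims as x, so expand idx3 over D.
picked = x.gather(1, idx3.unsqueeze(-1).expand(B, K, D))
print(picked.shape)  # torch.Size([2, 2, 3])
```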

a_tensor[idx] is supported; I use it often.
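A bare a[idx] indexes only the first dimension. To get the per-row selection from the question with indexing syntax, one option (a sketch, not necessarily what the poster above meant) is to pair idx with a row index that broadcasts against it:

```python
import torch

a = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
idx = torch.tensor([[0, 2, 1], [2, 3, 0]])

# A row index of shape (2, 1) broadcasts against idx of shape (2, 3), so
# element [i][j] of the result is a[rows[i][0], idx[i][j]] = a[i, idx[i, j]].
rows = torch.arange(a.size(0)).unsqueeze(1)
c = a[rows, idx]
print(c)
# tensor([[1, 3, 2],
#         [7, 8, 5]])
```

This gives the same result as a.gather(1, idx) for this 2-D case.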