Batched index_select / tf.gather_nd

I am trying to replicate TensorFlow's gather_nd() in PyTorch.

import torch

torch.manual_seed(0)

images = torch.rand((2, 2, 2, 3))
indices = torch.randint(0, 2, size=(2, 2, 2, 3)).long()

These are the image values:

tensor([[[[0.4963, 0.7682, 0.0885],
          [0.1320, 0.3074, 0.6341]],

         [[0.4901, 0.8964, 0.4556],
          [0.6323, 0.3489, 0.4017]]],


        [[[0.0223, 0.1689, 0.2939],
          [0.5185, 0.6977, 0.8000]],

         [[0.1610, 0.2823, 0.6816],
          [0.9152, 0.3971, 0.8742]]]])

These are the indices:

tensor([[[[0, 1, 1],
          [1, 1, 0]],

         [[1, 0, 1],
          [0, 1, 1]]],


        [[[0, 1, 1],
          [0, 0, 1]],

         [[0, 1, 1],
          [1, 1, 1]]]])

And this is the result I want to obtain:

tensor([[[[0.6323, 0.3489, 0.4017],
          [0.1610, 0.2823, 0.6816]],

         [[0.5185, 0.6977, 0.8000],
          [0.6323, 0.3489, 0.4017]]],


        [[[0.6323, 0.3489, 0.4017],
          [0.1320, 0.3074, 0.6341]],

         [[0.6323, 0.3489, 0.4017],
          [0.9152, 0.3971, 0.8742]]]])

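Currently, I am using these three nested for loops:

new_images = torch.zeros_like(images)  # works here only because the output shape equals images.shape

for i, batch in enumerate(indices):
    for j, dim in enumerate(batch):
        for k, row in enumerate(dim):
            # row holds the (d0, d1, d2) coordinates of the element to copy from images
            new_images[i][j][k] = images[row[0]][row[1]][row[2]]

but this is really slow.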

Is there a way I could use index_select to achieve the same result?


Would images.gather(0, indices) just work?
EDIT: Based on your example output, no, it doesn't seem to do what you want.
But this seems to work:

idx1, idx2, idx3 = indices.chunk(3, dim=3)
images[idx1, idx2, idx3].squeeze()
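
A variant of the same idea that avoids the squeeze altogether: Tensor.unbind(-1) splits the last dimension of the index tensor into a tuple of index tensors (one per indexed dimension), and indexing with that tuple applies them jointly. The helper name gather_nd below is hypothetical, and it only sketches the basic tf.gather_nd case (no batch_dims), checked against the original loop:

```python
import torch

def gather_nd(params: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper for the basic tf.gather_nd case (no batch_dims).

    The last dim of `indices` (size K) indexes the first K dims of `params`;
    the result has shape indices.shape[:-1] + params.shape[K:].
    """
    # unbind(-1) returns a tuple of K tensors of shape indices.shape[:-1];
    # indexing with that tuple is the same as params[t0, t1, ..., tK-1].
    return params[indices.unbind(-1)]

torch.manual_seed(0)
images = torch.rand((2, 2, 2, 3))
indices = torch.randint(0, 2, size=(2, 2, 2, 3))

out = gather_nd(images, indices)

# Reference result from the original triple loop
ref = torch.zeros_like(images)
for i in range(indices.size(0)):
    for j in range(indices.size(1)):
        for k in range(indices.size(2)):
            ref[i, j, k] = images[tuple(indices[i, j, k].tolist())]

print(out.shape)               # torch.Size([2, 2, 2, 3])
print(torch.equal(out, ref))   # True
```

Because no size-1 dimension is created, this form also works unchanged for single-channel images.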


@ptrblck Thank you for your quick response, it works perfectly! However, for a single-channel image I had to adapt it a little, since squeeze() without an argument also removes the size-1 channel dimension:

idx1, idx2, idx3 = indices.chunk(3, dim=3)
images[idx1, idx2, idx3].squeeze(3)
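
To see why squeeze(3) is needed, here is a small shape check, assuming a hypothetical single-channel batch of shape (N, H, W, C) = (2, 2, 2, 1) and random indices:

```python
import torch

images = torch.rand((2, 2, 2, 1))                 # single-channel images, C = 1
indices = torch.randint(0, 2, size=(2, 2, 2, 3))
idx1, idx2, idx3 = indices.chunk(3, dim=3)        # each has shape (2, 2, 2, 1)

gathered = images[idx1, idx2, idx3]               # shape (2, 2, 2, 1, 1)
print(gathered.squeeze().shape)                   # torch.Size([2, 2, 2]): channel dim lost too
print(gathered.squeeze(3).shape)                  # torch.Size([2, 2, 2, 1]): only the index dim removed
```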

Hi @ptrblck, would you mind sharing your idea about how to do meshgrid-style indexing?
If I have an image I of shape (N, H, W) and two index tensors idx_h and idx_w, both of shape (N, H, W), I want to extract a new image im_new, where

im_new[n][h][w] = I[n][idx_h[n][h][w]][idx_w[n][h][w]]

What would you suggest?

Did you try gather?

I don't think gather can do that, since it only indexes along one dimension.
Currently I use reshape to flatten the index tensors and then reshape the result back to the correct shape, but that seems kind of awkward.
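
One possible form of that flatten-and-reshape approach (a reconstruction for illustration, not necessarily the poster's exact code): combine the two indices into a single linear index h * W + w over each flattened H×W image, so that one gather along dim 1 suffices. Sizes and data below are hypothetical:

```python
import torch

N, H, W = 2, 3, 4
x = torch.rand(N, H, W)
idx_h = torch.randint(0, H, (N, H, W))
idx_w = torch.randint(0, W, (N, H, W))

# Linear index into each flattened image: row * W + column
flat_idx = (idx_h * W + idx_w).view(N, -1)                   # (N, H*W)
im_new = x.reshape(N, -1).gather(1, flat_idx).view(N, H, W)

# Check against the definition im_new[n][h][w] = x[n][idx_h[n][h][w]][idx_w[n][h][w]]
ref = torch.empty_like(x)
for n in range(N):
    for r in range(H):
        for c in range(W):
            ref[n, r, c] = x[n, int(idx_h[n, r, c]), int(idx_w[n, r, c])]
print(torch.equal(im_new, ref))  # True
```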

Yeah, I'm not sure either whether there is one clean solution using .gather.
I tried multiple calls to gather:

x.gather(1, idx_h).gather(2, idx_w)

but the second gather call indexes the intermediate result rather than x, so the output is x[n][idx_h[n][h][idx_w[n][h][w]]][idx_w[n][h][w]], which is not what we want.
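
A quick check of what the chained call actually computes: with y = x.gather(1, idx_h), we have y[n][h][w] = x[n][idx_h[n][h][w]][w], and the second gather then picks column idx_w[n][h][w] of y, so the row index is looked up at the wrong column. Sizes and data are hypothetical:

```python
import torch

N, H, W = 2, 3, 4
x = torch.rand(N, H, W)
idx_h = torch.randint(0, H, (N, H, W))
idx_w = torch.randint(0, W, (N, H, W))

chained = x.gather(1, idx_h).gather(2, idx_w)

# Element-wise formula for the chained call (not the wanted result)
actual = torch.empty_like(x)
for n in range(N):
    for r in range(H):
        for c in range(W):
            col = int(idx_w[n, r, c])
            actual[n, r, c] = x[n, int(idx_h[n, r, col]), col]
print(torch.equal(chained, actual))  # True
```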

There is one approach using strides, but it's not straightforward to understand what is happening.
I’m really not a fan of this code and I think there might be a better approach.
However, for the sake of completeness:

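# assumes x is contiguous with shape (N, H, W); idx0 is the batch index of every output element
idx0 = torch.arange(N).view(N, 1, 1).repeat(1, H, W).view(-1)
x.view(-1)[idx0 * x.stride(0) + idx_h.view(-1) * x.stride(1) + idx_w.view(-1) * x.stride(2)].view(N, H, W)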
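
A possibly simpler alternative (not from the replies above) is plain advanced indexing: a batch index of shape (N, 1, 1) broadcasts against idx_h and idx_w, so neither flattening nor stride arithmetic is needed. A sketch with hypothetical sizes, compared against the stride-based version:

```python
import torch

N, H, W = 2, 3, 4
x = torch.rand(N, H, W)
idx_h = torch.randint(0, H, (N, H, W))
idx_w = torch.randint(0, W, (N, H, W))

# (N, 1, 1) batch index broadcasts to (N, H, W) together with idx_h and idx_w
batch = torch.arange(N).view(N, 1, 1)
im_new = x[batch, idx_h, idx_w]   # im_new[n, h, w] = x[n, idx_h[n, h, w], idx_w[n, h, w]]

# Stride-based version for comparison (x is contiguous here)
idx0 = torch.arange(N).view(N, 1, 1).repeat(1, H, W).view(-1)
offset = idx0 * x.stride(0) + idx_h.view(-1) * x.stride(1) + idx_w.view(-1) * x.stride(2)
im_strided = x.view(-1)[offset].view(N, H, W)

print(torch.equal(im_new, im_strided))  # True
```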

Hi,
How can I use the chunk() function to handle the following situation?

>>> w = torch.randn(3, 4, 5)
>>> w
tensor([[[-0.7057,  0.2028,  0.5647,  0.6096, -0.3521],
         [ 0.6071, -0.1356,  0.4405, -1.0726, -0.0379],
         [ 1.0409, -1.6037, -1.3454, -1.4991, -0.5502],
         [ 0.9706,  1.2075, -0.8726,  0.3872, -1.2042]],

        [[ 1.5871,  1.0355, -0.3684, -0.0823, -0.1831],
         [ 0.7262,  2.4277, -0.1173,  0.4026,  0.8243],
         [ 1.1217, -0.6436,  1.4286,  0.6216,  0.0021],
         [ 1.0578, -0.4555, -0.2754,  0.6451,  0.9100]],

        [[ 0.1190, -0.9509,  0.2875, -0.9116,  0.6941],
         [ 0.4358, -0.6536, -0.5309, -0.5608,  0.6856],
         [ 0.0518, -0.8806,  0.4239,  1.2213,  0.7129],
         [ 0.3308, -1.5465,  1.9394, -1.4332,  1.3698]]])
>>> w.shape
torch.Size([3, 4, 5])   # (batch_size, time_step, hidden)

>>> index = torch.tensor([[0, 1], [1, 2], [0, 3]])   # (batch_size, k)

I want to get this result when selecting along dim=1:

tensor([[[-0.7057,  0.2028,  0.5647,  0.6096, -0.3521],
         [ 0.6071, -0.1356,  0.4405, -1.0726, -0.0379]],

        [[ 0.7262,  2.4277, -0.1173,  0.4026,  0.8243],
         [ 1.1217, -0.6436,  1.4286,  0.6216,  0.0021]],

        [[ 0.1190, -0.9509,  0.2875, -0.9116,  0.6941],
         [ 0.0518, -0.8806,  0.4239,  1.2213,  0.7129]]])
shape: (3, 2, 5)

What is the simplest and most efficient way to do this?
Thanks!

You could just index your tensor w (it seems the last row in your example is wrong):

w[torch.arange(w.size(0)).unsqueeze(1), index]
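
The same selection can also be written with gather, as long as index is expanded along the hidden dimension so that it has as many dimensions as w. A sketch (the random w only stands in for the tensor above):

```python
import torch

w = torch.randn(3, 4, 5)                         # (batch_size, time_step, hidden)
index = torch.tensor([[0, 1], [1, 2], [0, 3]])   # (batch_size, k)

# Advanced indexing: arange(3).unsqueeze(1) has shape (3, 1) and broadcasts with index (3, 2)
out_idx = w[torch.arange(w.size(0)).unsqueeze(1), index]

# gather needs an index with w's number of dims, so expand it to (3, 2, 5)
gather_idx = index.unsqueeze(-1).expand(-1, -1, w.size(2))
out_gather = w.gather(1, gather_idx)

print(out_idx.shape)                      # torch.Size([3, 2, 5])
print(torch.equal(out_idx, out_gather))   # True
```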

Thanks, your method is very useful.
Sorry, the last row in my example was indeed wrong :sweat_smile:. The correct result is:

tensor([[[-0.7057,  0.2028,  0.5647,  0.6096, -0.3521],
         [ 0.6071, -0.1356,  0.4405, -1.0726, -0.0379]],

        [[ 0.7262,  2.4277, -0.1173,  0.4026,  0.8243],
         [ 1.1217, -0.6436,  1.4286,  0.6216,  0.0021]],

        [[ 0.1190, -0.9509,  0.2875, -0.9116,  0.6941],
         [ 0.3308, -1.5465,  1.9394, -1.4332,  1.3698]]])

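Tip: index tensors must have an integer dtype (e.g. long). torch.arange(w.size(0)) already returns an int64 tensor in current PyTorch versions, but an explicit cast also works:
w[torch.arange(w.size(0)).unsqueeze(1).long(), index]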
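
On recent PyTorch versions (1.9 and later), torch.take_along_dim offers yet another way to do this selection: it broadcasts the index against the input along the non-selected dimensions before gathering, similar to NumPy's take_along_axis. A sketch, assuming a PyTorch version that provides this function:

```python
import torch

w = torch.randn(3, 4, 5)
index = torch.tensor([[0, 1], [1, 2], [0, 3]])

# index.unsqueeze(-1) has shape (3, 2, 1) and is broadcast along the hidden dim,
# so the result has shape (3, 2, 5)
out = torch.take_along_dim(w, index.unsqueeze(-1), dim=1)
print(torch.equal(out, w[torch.arange(w.size(0)).unsqueeze(1), index]))  # True
```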