How to do tf.gather_nd in PyTorch?

In TensorFlow we can do the operations below with tf.gather_nd, but how can we do this in PyTorch?
Simple indexing into a matrix:

indices = [[0, 0], [1, 1]]
params = [['a', 'b'], ['c', 'd']]
output = ['a', 'd']

Slice indexing into a matrix:

indices = [[1], [0]]
params = [['a', 'b'], ['c', 'd']]
output = [['c', 'd'], ['a', 'b']]

Indexing into a 3-tensor:

indices = [[1]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = [[['a1', 'b1'], ['c1', 'd1']]]

indices = [[0, 1], [1, 0]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = [['c0', 'd0'], ['a1', 'b1']]

indices = [[0, 0, 1], [1, 0, 1]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = ['b0', 'b1']

Batched indexing into a matrix:

indices = [[[0, 0]], [[0, 1]]]
params = [['a', 'b'], ['c', 'd']]
output = [['a'], ['b']]
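A minimal sketch of how these examples might map to PyTorch advanced indexing. It assumes integer stand-ins for the string values (PyTorch tensors cannot hold strings) and TF-style indices whose last dimension holds one coordinate. The helper name `gather_nd_adv` is hypothetical, and it does not cover TF's `batch_dims` argument.

```python
import torch

def gather_nd_adv(params, indices):
    # Hypothetical helper: indices has shape (..., M) with M <= params.dim().
    # unbind(-1) yields M index tensors, one per leading dim of params.
    # Result shape: indices.shape[:-1] + params.shape[M:]
    return params[tuple(indices.unbind(-1))]

# Matrix stand-ins: 'a'->0, 'b'->1, 'c'->2, 'd'->3
m = torch.tensor([[0, 1], [2, 3]])
simple = gather_nd_adv(m, torch.tensor([[0, 0], [1, 1]]))       # ['a', 'd']
sliced = gather_nd_adv(m, torch.tensor([[1], [0]]))             # [['c','d'],['a','b']]
batched = gather_nd_adv(m, torch.tensor([[[0, 0]], [[0, 1]]]))  # [['a'], ['b']]

# 3-tensor stand-ins: 'a0'->0, 'b0'->1, 'c0'->2, 'd0'->3, 'a1'->4, ..., 'd1'->7
t = torch.arange(8).view(2, 2, 2)
t_slab = gather_nd_adv(t, torch.tensor([[1]]))                  # [[['a1','b1'],['c1','d1']]]
t_rows = gather_nd_adv(t, torch.tensor([[0, 1], [1, 0]]))       # [['c0','d0'],['a1','b1']]
t_elems = gather_nd_adv(t, torch.tensor([[0, 0, 1], [1, 0, 1]]))  # ['b0', 'b1']
print(simple, sliced, batched, t_slab, t_rows, t_elems, sep="\n")
```

Splitting the last index dimension with `unbind(-1)` turns each TF coordinate column into one PyTorch index tensor, so partial coordinates (fewer entries than `params.dim()`) automatically return whole slices.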

Hi. Did you ever discover the best way to deal with this issue?

Yes, there are equivalent operations in PyTorch. Try something like the following:

Simple indexing into matrix:

x = torch.randn(2, 2)
indices = torch.ByteTensor([[0, 0],[1,1]])
x.masked_select(indices)

Slice indexing into matrix:

x = torch.randn(2, 2)
indices = torch.LongTensor([1, 0])
x.index_select(0, indices)

Indexing into a 3-tensor:

x = torch.randn(1, 4, 2)
x[:,:,1]

Batched indexing into a matrix:

x = torch.randn(2, 2)
indices = torch.LongTensor([[0, 0],[1,1]])
[x[i] for i in indices]

Are you sure these snippets are correct? Just comparing to TensorFlow’s example:

indices = torch.ByteTensor([[0, 0], [1, 1]])
params  = torch.Tensor([[1, 2], [3, 4]])
params.masked_select(indices)

The TensorFlow documentation says the output should be tensor([1., 4.]), but your code gives tensor([3., 4.]).

That one should be

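import torch

# Each row of `indices` holds the coordinates for one dimension
# (all row indices, then all column indices), i.e. the transpose of TF's layout
indices = torch.tensor([[0, 1], [0, 1]])
params  = torch.tensor([[1, 2], [3, 4]])
params[tuple(indices)]  # tensor([1, 4])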

It should be:

rows_list = [0, 1]
col_list = [0, 1]
params[rows_list, col_list]

Aren't you missing something in the "Batched indexing into a matrix" block at the end? Done that way, you have to loop over all indices along dim=0. My question is: is there a fast way in PyTorch to do gather_nd when I have one 3D tensor that stores all the indices and another 3D tensor that holds all the values, and I want to create a new 3D tensor where each value is mapped according to the indices in the index tensor?

Could you show me an example of what you were thinking?
torch.gather can do something like this (https://pytorch.org/docs/master/torch.html?highlight=gather#torch.gather) but it might not be exactly what you’re looking for.
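For reference, a small sketch of what torch.gather does: it selects values along a single dimension `dim`, and `index` must have the same number of dimensions as the input, so for `dim=1` the rule is `out[i][j] = x[i][index[i][j]]`. Unlike tf.gather_nd, each index entry addresses one axis only; the second part shows one assumed way to emulate (row, col) lookups by gathering from the flattened tensor.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# dim=1: out[i][j] = x[i][index[i][j]]
index = torch.tensor([[2, 0],
                      [1, 1]])
out = torch.gather(x, 1, index)
print(out)  # tensor([[3, 1], [5, 5]])

# Emulating gather_nd-style (row, col) lookups: convert each pair to a
# row-major linear index and gather from the flattened tensor
rows = torch.tensor([0, 1])
cols = torch.tensor([0, 2])
flat = x.flatten().gather(0, rows * x.size(1) + cols)
print(flat)  # tensor([1, 6])
```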

I have found the solution myself. Thank you anyway!

Hi, how did you solve the problem? Would you mind sharing your code? Thanks! :slightly_smiling_face:

@palimboa Can you share with me how you solved it? Thank you.

Hello, did you solve it? I have the same problem.

Hmm, this might work. I only came up with a naive solution:
the indices must be a 2D tensor with indices.size(1) == len(params.size()).

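import torch

def gather_nd(params, indices, name=None):
    '''
    The input indices must be a 2D tensor of the form [[a, b, ..., c], ...],
    where each row is the location of one element.
    '''
    # `name` is unused; it only mirrors the tf.gather_nd signature
    indices = indices.t().long()  # (ndim, N): one row per dimension
    ndim = indices.size(0)
    idx = torch.zeros_like(indices[0]).long()
    m = 1

    # Accumulate row-major linear indices, starting from the last dimension
    for i in range(ndim)[::-1]:
        idx += indices[i] * m
        m *= params.size(i)

    # torch.take indexes params as if it were flattened to 1-D
    return torch.take(params, idx)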
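The Python loop in `gather_nd` can probably be replaced by a broadcast multiply with the row-major strides of `params`. A sketch under the same assumption (2D `indices` of shape (N, params.dim()) holding full coordinates); `gather_nd_flat` is a hypothetical name, not an existing PyTorch function.

```python
import torch

def gather_nd_flat(params, indices):
    # Hypothetical helper. indices: (N, params.dim()), one full coordinate per row.
    sizes = torch.tensor(params.shape, dtype=torch.long, device=params.device)
    # Row-major strides of a contiguous tensor of this shape,
    # e.g. shape (2, 3, 4) -> strides (12, 4, 1)
    strides = torch.ones_like(sizes)
    strides[:-1] = torch.flip(torch.cumprod(torch.flip(sizes[1:], [0]), 0), [0])
    # Broadcast multiply and sum over the coordinate dim -> (N,) linear indices
    flat_idx = (indices.long().to(params.device) * strides).sum(-1)
    # torch.take indexes params as if it were flattened to 1-D (row-major)
    return torch.take(params, flat_idx)

params = torch.arange(24).view(2, 3, 4)
indices = torch.tensor([[0, 0, 1], [1, 2, 3]])
print(gather_nd_flat(params, indices))  # tensor([ 1, 23])
```

The result should match the loop version; the elementwise product is used instead of a matrix product because integer matmul is not supported on every device.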

Great! Thank you. I think it’s useful.

Could you provide a more general version of tf.gather_nd, not limited to a 2D index tensor of full coordinates?

For example, if I want to gather from a 3D tensor of size CxHxW using indices of size Nx2, where the index values are coordinates in the HxW grid, how can I get a result of size CxN?

In TensorFlow this is easy to do with tf.gather_nd…

I think you can use advanced indexing for this:

import torch

C, H, W = 3, 4, 4
x = torch.arange(C*H*W).view(C, H, W)
print(x)
idx = torch.tensor([[0, 0],
                    [1, 1],
                    [2, 2],
                    [3, 3]])

# arange(C) broadcasts against the two (N, 1) index columns, so the result is N x C
print(x[(torch.arange(x.size(0)), *idx.chunk(2, 1))])
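A possibly simpler variant, assuming `idx` holds (row, col) pairs into the HxW grid: keep the channel dimension as a full slice and index H and W with the two columns of `idx`. Because the two index tensors sit on adjacent dimensions, the channel dimension stays first and the result comes out as CxN directly, without a transpose.

```python
import torch

C, H, W = 3, 4, 4
x = torch.arange(C * H * W).view(C, H, W)
idx = torch.tensor([[0, 0],
                    [1, 1],
                    [2, 2],
                    [3, 3]])

# out[c, n] = x[c, idx[n, 0], idx[n, 1]]
out = x[:, idx[:, 0], idx[:, 1]]
print(out.shape)  # torch.Size([3, 4])
```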

@ptrblck Thanks! I found your previous answer, which solved my problem.

I used the following code for validation:

import torch

batch_size = 2
c, h, w = 256, 38, 65
nb_points = 784
nb_regions = 128

img_feat = torch.randn(batch_size, c, h, w).cuda()
x = torch.empty(batch_size, nb_regions, nb_points, dtype=torch.long).random_(h).cuda()
y = torch.empty(batch_size, nb_regions, nb_points, dtype=torch.long).random_(w).cuda()

# method 1: vectorized advanced indexing, (batch, regions*points, c) before the view
result_1 = img_feat[torch.arange(batch_size)[:, None], :, x.view(batch_size, -1), y.view(batch_size, -1)]
result_1 = result_1.view(batch_size, nb_regions, nb_points, -1)

# method 2: reference Python loop
result_2 = img_feat.new_zeros((batch_size, nb_regions, nb_points, img_feat.size(1)))
for i in range(batch_size):
    for j in range(x.shape[1]):
        for k in range(x.shape[2]):
            result_2[i, j, k] = img_feat[i, :, x[i, j, k].long(), y[i, j, k].long()]

print((result_1 == result_2).all())

Thank you for this solution. How can I do this if my images are batched, N x C x H x W, and the index tensor is also batched, N x H*W x 2, so that my output is N x C x H x W? That would be really great.

Could you post a minimal example, how the index tensor should be used to index which values in your input?

I have been doing this with map; here are the code and output:

import torch

N, C, H, W = 2, 3, 2, 2
img = torch.arange(N*C*H*W).view(N, C, H, W)
idx = torch.randint(0, 2, (N, H*W, 2))

# For one image x (C, H, W) and its idx (H*W, 2): result[k, c] = x[c, idx[k, 0], idx[k, 1]]
maskit = lambda x, idx: x[(torch.arange(x.size(0)), *idx.chunk(2, 1))]

masked = torch.stack([*map(lambda x: maskit(x[0], x[1]), zip(img, idx))])  # (N, H*W, C)
final = masked.expand(1, *masked.shape).permute(1, 3, 2, 0).view(*img.shape)

OUTPUT

print(img)
**out :** tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]],

         [[ 8,  9],
          [10, 11]]],


        [[[12, 13],
          [14, 15]],

         [[16, 17],
          [18, 19]],

         [[20, 21],
          [22, 23]]]])

print(idx)
**out :** tensor([[[0, 1],
         [1, 0],
         [1, 0],
         [1, 1]],

        [[1, 1],
         [0, 1],
         [1, 0],
         [0, 0]]])
print(masked)
**out :** tensor([[[ 1,  5,  9],
         [ 2,  6, 10],
         [ 2,  6, 10],
         [ 3,  7, 11]],

        [[15, 19, 23],
         [13, 17, 21],
         [14, 18, 22],
         [12, 16, 20]]])

print(final)
**out :** tensor([[[[ 1,  2],
          [ 2,  3]],

         [[ 5,  6],
          [ 6,  7]],

         [[ 9, 10],
          [10, 11]]],


        [[[15, 13],
          [14, 12]],

         [[19, 17],
          [18, 16]],

         [[23, 21],
          [22, 20]]]])

print(final.shape)

**out :** torch.Size([2, 3, 2, 2])

Can’t we do this without using map?
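A sketch of a loop-free version, assuming the same layout as above (`img` is N x C x H x W and `idx` is N x H*W x 2 holding (row, col) pairs). A batch index of shape (N, 1) broadcasts against the two (N, H*W) coordinate tensors; since these advanced indices are separated by the `:` on the channel dimension, PyTorch places the broadcast dimensions first, giving N x H*W x C, which is then permuted and reshaped. The `idx` values below are copied from the output printed above.

```python
import torch

N, C, H, W = 2, 3, 2, 2
img = torch.arange(N * C * H * W).view(N, C, H, W)
idx = torch.tensor([[[0, 1], [1, 0], [1, 0], [1, 1]],
                    [[1, 1], [0, 1], [1, 0], [0, 0]]])

batch = torch.arange(N)[:, None]                    # (N, 1)
# gathered[n, k, c] = img[n, c, idx[n, k, 0], idx[n, k, 1]]
gathered = img[batch, :, idx[..., 0], idx[..., 1]]  # (N, H*W, C), same as `masked`
final = gathered.permute(0, 2, 1).reshape(N, C, H, W)
print(final)
```

This should reproduce both `masked` and `final` from the map version, while doing all images in one indexing call.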