# How to do tf.gather_nd in PyTorch?

In TensorFlow we can do the operations below with tf.gather_nd, but how can we do this in PyTorch?
Simple indexing into a matrix:

```
indices = [[0, 0], [1, 1]]
params = [['a', 'b'], ['c', 'd']]
output = ['a', 'd']
```

Slice indexing into a matrix:

```
indices = [[1], [0]]
params = [['a', 'b'], ['c', 'd']]
output = [['c', 'd'], ['a', 'b']]
```

Indexing into a 3-tensor:

```
indices = [[1]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = [[['a1', 'b1'], ['c1', 'd1']]]

indices = [[0, 1], [1, 0]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = [['c0', 'd0'], ['a1', 'b1']]

indices = [[0, 0, 1], [1, 0, 1]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = ['b0', 'b1']
```

Batched indexing into a matrix:

```
indices = [[[0, 0]], [[0, 1]]]
params = [['a', 'b'], ['c', 'd']]
output = [['a'], ['b']]
```
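For reference, here is a minimal sketch of a general `gather_nd` in PyTorch, assuming TensorFlow's semantics: the last dimension of `indices` holds (possibly partial) coordinates into the leading dimensions of `params`, and any remaining dimensions of `params` come back as slices. The function name `gather_nd` is hypothetical, and TF's `batch_dims` argument is not supported. The letters from the examples above are replaced by integers in the usage line.

```python
import torch

def gather_nd(params: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical PyTorch sketch of tf.gather_nd (no batch_dims support).

    indices[..., i] indexes dimension i of params for i < indices.size(-1);
    trailing dimensions of params are returned as slices.
    """
    indices = indices.long()
    k = indices.size(-1)
    # split the last dimension into one index tensor per indexed dimension
    coords = tuple(indices[..., i] for i in range(k))
    # advanced indexing broadcasts the coordinate tensors, so the output shape
    # is indices.shape[:-1] + params.shape[k:]
    return params[coords]

params = torch.tensor([[1, 2], [3, 4]])  # stands in for [['a', 'b'], ['c', 'd']]
print(gather_nd(params, torch.tensor([[0, 0], [1, 1]])).tolist())  # [1, 4]
```

Because the coordinate tensors index the leading dimensions consecutively, the same function covers the simple, slice, 3-tensor and batched cases listed in the question.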

Hi. Did you ever discover the best way to deal with this issue?

Yes, there are equivalent operations in PyTorch. Try something like the following:

Simple indexing into a matrix:

```
import torch

x = torch.randn(2, 2)
indices = torch.ByteTensor([[0, 0], [1, 1]])
x[indices]  # note: a uint8 (Byte) index acts as a boolean mask, not as coordinates
```

Slice indexing into a matrix:

```
import torch

x = torch.randn(2, 2)
indices = torch.LongTensor([1, 0])
x.index_select(0, indices)  # rows 1 and 0, in that order
```
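Plain indexing with a 1-D integer tensor selects the same rows as `index_select` along dim 0. A quick check against TensorFlow's slice example, with the letters replaced by numbers for illustration:

```python
import torch

params = torch.tensor([[1, 2], [3, 4]])  # stands in for [['a', 'b'], ['c', 'd']]
indices = torch.tensor([1, 0])           # TF's [[1], [0]] flattened to 1-D

via_select = params.index_select(0, indices)
via_indexing = params[indices]           # advanced indexing on dim 0

print(via_select.tolist())    # [[3, 4], [1, 2]]
print(via_indexing.tolist())  # [[3, 4], [1, 2]]
```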

Indexing into a 3-tensor:

```
import torch

x = torch.randn(1, 4, 2)
x[:, :, 1]  # fixed slice along the last dimension
```
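The snippet above takes a fixed slice; to reproduce TensorFlow's three 3-tensor examples exactly, the columns of `indices` can be passed as separate advanced indices. A sketch, assuming the letters are replaced by the numbers 0 to 7 (a0=0, b0=1, c0=2, d0=3, a1=4, b1=5, c1=6, d1=7):

```python
import torch

# params[i, j, k] holds a0..d1 encoded as 0..7
params = torch.arange(8).view(2, 2, 2)

# indices = [[1]] -> one (2, 2) slice, kept in a batch of size 1
out1 = params[torch.tensor([1])]

# indices = [[0, 1], [1, 0]] -> rows (0, 1) and (1, 0): [c0, d0], [a1, b1]
idx2 = torch.tensor([[0, 1], [1, 0]])
out2 = params[idx2[:, 0], idx2[:, 1]]

# indices = [[0, 0, 1], [1, 0, 1]] -> single elements b0, b1
idx3 = torch.tensor([[0, 0, 1], [1, 0, 1]])
out3 = params[idx3[:, 0], idx3[:, 1], idx3[:, 2]]

print(out1.tolist())  # [[[4, 5], [6, 7]]]
print(out2.tolist())  # [[2, 3], [4, 5]]
print(out3.tolist())  # [1, 5]
```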

Batched indexing into a matrix:

```
import torch

x = torch.randn(2, 2)
indices = torch.LongTensor([[0, 0], [1, 1]])
[x[i] for i in indices]  # each x[i] returns whole rows (rows 0, 0 and then rows 1, 1)
```
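The list comprehension above returns whole rows and loops in Python. For TensorFlow's batched example, where the last dimension of `indices` holds (row, column) pairs, indexing with the two coordinate planes avoids the loop. This sketch assumes every index addresses both dimensions:

```python
import torch

params = torch.tensor([[1, 2], [3, 4]])       # stands in for [['a', 'b'], ['c', 'd']]
indices = torch.tensor([[[0, 0]], [[0, 1]]])  # shape (2, 1, 2), as in the TF example

# indices[..., 0] are rows and indices[..., 1] are columns; both have shape (2, 1),
# so the result keeps the batch layout of indices
out = params[indices[..., 0], indices[..., 1]]
print(out.tolist())  # [[1], [2]], i.e. [['a'], ['b']]
```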

Are you sure these snippets are correct? Just comparing to TensorFlow's example:

```
import torch

indices = torch.ByteTensor([[0, 0], [1, 1]])
params  = torch.Tensor([[1, 2], [3, 4]])
```

The TensorFlow documentation says the output should be `tensor([ 1., 4.])`, but your code gives `tensor([ 3., 4.])`.

That one should be:

```
import torch

indices = torch.tensor([[0, 1], [0, 1]])  # first row: row indices, second row: column indices
params  = torch.tensor([[1, 2], [3, 4]])
params[tuple(indices)]  # tensor([1, 4])
```

It should be:

```
rows_list = [0, 1]
col_list = [0, 1]
params[rows_list, col_list]  # one index list per dimension
```
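The replies above pass one index sequence per dimension (all rows, then all columns), whereas TensorFlow's `indices` store one (row, col) pair per row. Transposing the TF-style tensor and unpacking it into a tuple bridges the two layouts; a minimal sketch:

```python
import torch

params = torch.tensor([[1, 2], [3, 4]])
tf_indices = torch.tensor([[0, 0], [1, 1]])  # TF layout: one (row, col) pair per row

# .t() gives [[rows...], [cols...]]; tuple() splits it into (rows, cols)
out = params[tuple(tf_indices.t())]
print(out.tolist())  # [1, 4]
```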

Aren't you missing something in the `Batched indexing into a matrix` block at the end? Done that way, you have to loop over all indices along dim=0. My question is: is there a fast way in PyTorch to do `gather_nd` when I have one 3D tensor that stores all the indices and another 3D tensor that holds all the values, and I want to build a new 3D tensor by gathering the values at those indices?

Could you show me an example of what you were thinking?
`torch.gather` can do something like this (https://pytorch.org/docs/master/torch.html?highlight=gather#torch.gather), but it might not be exactly what you're looking for.
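To illustrate the difference: `torch.gather` picks one element per output position along a single dimension, and its `index` must have the same number of dimensions as the input, so it is not a drop-in replacement for `gather_nd`. A small example with illustrative values:

```python
import torch

params = torch.tensor([[1, 2], [3, 4]])
# along dim=1: out[i][j] = params[i][index[i][j]]
index = torch.tensor([[0], [1]])
out = torch.gather(params, 1, index)
print(out.tolist())  # [[1], [4]]
```

Here the row coordinate is implied by the position in `index`, which is why `torch.gather` only fits cases where one dimension is indexed per output element.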

I have found the solution myself. Thank you anyway!

Hi, how did you solve the problem? Would you mind sharing your code? Thanks.

@palimboa Can you share with me how you solved it? Thank you

Hello, did you solve it? I'm running into the same problem.

Hmm…, this might work; I only came up with a naive solution.
The indices must be a 2D tensor with `indices.size(1) == len(params.size())`.

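```
import torch

def gather_nd(params, indices, name=None):
    '''
    the input indices must be a 2d tensor in the form of [[a,b,..,c],...],
    which represents the location of the elements.
    '''
    # transpose so that indices[i] holds the i-th coordinate of every point
    indices = indices.t().long()
    ndim = indices.size(0)
    idx = torch.zeros_like(indices[0]).long()
    m = 1

    # convert N-d coordinates into flat (row-major) offsets
    for i in range(ndim)[::-1]:
        idx += indices[i] * m
        m *= params.size(i)

    # look the offsets up in the flattened tensor
    return params.reshape(-1)[idx]
```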
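The loop above turns each coordinate row into a flat row-major offset. Under the same assumption (`indices` is 2D with `indices.size(1) == params.dim()`), advanced indexing can do that work directly. A sketch comparing the two; both helper names are hypothetical:

```python
import torch

def gather_nd_tuple(params, indices):
    # hypothetical helper: one index tensor per dimension, then advanced indexing
    return params[tuple(indices.t().long())]

def gather_nd_flat(params, indices):
    # hypothetical helper: row-major flat offsets computed from contiguous strides
    strides = torch.tensor(params.contiguous().stride())
    offsets = (indices.long() * strides).sum(dim=1)
    return params.reshape(-1)[offsets]

params = torch.arange(24).view(2, 3, 4)  # params[i, j, k] == 12*i + 4*j + k
indices = torch.tensor([[0, 0, 1], [1, 2, 3], [1, 0, 0]])
print(gather_nd_tuple(params, indices).tolist())  # [1, 23, 12]
print(gather_nd_flat(params, indices).tolist())   # [1, 23, 12]
```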

Great! Thank you. I think it's useful.

Could you provide a more general version of tf.gather_nd, one that is not limited to 2D tensors?

For example, if I want to gather from a 3D tensor of size CxHxW using an index tensor of size Nx2, where the index values are coordinates in the HxW grid, how can I get a result of size CxN?

In TensorFlow, this is easy to do with tf.gather_nd…

I think you can use advanced indexing for this:

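```
import torch

C, H, W = 3, 4, 4
x = torch.arange(C*H*W).view(C, H, W)
print(x)
idx = torch.tensor([[0, 0],
                    [1, 1],
                    [2, 2],
                    [3, 3]])

# arange(C) has shape (C,) and each idx column has shape (N, 1); they broadcast,
# so the result has shape (N, C) (call .t() on it for C x N)
print(x[(torch.arange(x.size(0)), *idx.chunk(2, 1))])
```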
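The snippet above returns an N x C result. If C x N is wanted directly, a slightly simpler variant keeps the channel dimension as a plain slice and only indexes H and W. A sketch with the same sizes:

```python
import torch

C, H, W = 3, 4, 4
x = torch.arange(C * H * W).view(C, H, W)
idx = torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]])  # N x 2 (row, col) pairs

# ':' keeps all channels; the two adjacent index tensors select N positions -> (C, N)
out = x[:, idx[:, 0], idx[:, 1]]
print(out.shape)  # torch.Size([3, 4])
```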

@ptrblck Thanks! I have found your previous answer which solved my problem.

I used the following code to validate it:

```
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

batch_size = 2
c, h, w = 256, 38, 65
nb_points = 784
nb_regions = 128

img_feat = torch.randn(batch_size, c, h, w, device=device)
x = torch.empty(batch_size, nb_regions, nb_points, dtype=torch.long).random_(h).to(device)
y = torch.empty(batch_size, nb_regions, nb_points, dtype=torch.long).random_(w).to(device)

# method 1: vectorized advanced indexing -> (batch_size, nb_regions * nb_points, c)
result_1 = img_feat[torch.arange(batch_size, device=device)[:, None], :, x.view(batch_size, -1), y.view(batch_size, -1)]
result_1 = result_1.view(batch_size, nb_regions, nb_points, -1)

# method 2: explicit Python loops (slow, used only as a reference)
result_2 = img_feat.new_zeros(batch_size, nb_regions, nb_points, img_feat.size(1))
for i in range(batch_size):
    for j in range(x.shape[1]):
        for k in range(x.shape[2]):
            result_2[i, j, k] = img_feat[i, :, x[i, j, k], y[i, j, k]]

print((result_1 == result_2).all())
```

Thank you for this solution. How can I do this when the images come in batches of shape NxCxHxW and the index tensor is also batched, with shape NxH*Wx2, so that my output is NxCxHxW? Any help would be great.

Could you post a minimal example showing how the index tensor should be used, i.e. which values in your input each index should select?

I have been doing this with `map`; here is the code and its output:

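```
import torch

N, C, H, W = 2, 3, 2, 2
img = torch.arange(N*C*H*W).view(N, C, H, W)
idx = torch.randint(0, 2, (N, H*W, 2))

# per sample: x is (C, H, W), idx is (H*W, 2) -> returns (H*W, C)
maskit = lambda x, idx: x[(torch.arange(x.size(0)), *idx.chunk(2, 1))]

# the map/stack/reshape lines below are a reconstruction consistent with the
# printed output; the name `result` is hypothetical
result = torch.stack(list(map(maskit, img, idx)))   # (N, H*W, C)
final = result.transpose(1, 2).reshape(N, C, H, W)  # (N, C, H, W)
```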

OUTPUT

```
print(img)
# out:
tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]],

         [[ 8,  9],
          [10, 11]]],

        [[[12, 13],
          [14, 15]],

         [[16, 17],
          [18, 19]],

         [[20, 21],
          [22, 23]]]])

print(idx)
# out:
tensor([[[0, 1],
         [1, 0],
         [1, 0],
         [1, 1]],

        [[1, 1],
         [0, 1],
         [1, 0],
         [0, 0]]])

# intermediate per-sample gather result, shape (N, H*W, C)
# out:
tensor([[[ 1,  5,  9],
         [ 2,  6, 10],
         [ 2,  6, 10],
         [ 3,  7, 11]],

        [[15, 19, 23],
         [13, 17, 21],
         [14, 18, 22],
         [12, 16, 20]]])

print(final)
# out:
tensor([[[[ 1,  2],
          [ 2,  3]],

         [[ 5,  6],
          [ 6,  7]],

         [[ 9, 10],
          [10, 11]]],

        [[[15, 13],
          [14, 12]],

         [[19, 17],
          [18, 16]],

         [[23, 21],
          [22, 20]]]])

print(final.shape)
# out:
torch.Size([2, 3, 2, 2])
```

Can't we do this without using `map`?
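One way to avoid `map` is to add a batch index that broadcasts against the per-sample coordinates. This is a sketch that assumes the shapes from the example above (`img` is N x C x H x W, `idx` is N x H*W x 2 holding (row, col) pairs) and ends with the same reshape to N x C x H x W:

```python
import torch

N, C, H, W = 2, 3, 2, 2
img = torch.arange(N * C * H * W).view(N, C, H, W)
idx = torch.randint(0, 2, (N, H * W, 2))

# batch index of shape (N, 1) broadcasts against rows/cols of shape (N, H*W)
batch = torch.arange(N)[:, None]
rows, cols = idx[..., 0], idx[..., 1]

# advanced indices separated by ':' move to the front -> (N, H*W, C)
gathered = img[batch, :, rows, cols]

# put channels back in front and restore the spatial layout -> (N, C, H, W)
final = gathered.permute(0, 2, 1).reshape(N, C, H, W)
print(final.shape)  # torch.Size([2, 3, 2, 2])
```

The whole batch is handled by a single indexing call, so no Python-level loop over samples is needed.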