# How to do the tf.gather_nd in pytorch?

Hello, did you solve it? I'm also running into this problem.

Hmm, this might work; I only figured out a naive solution:
the indices must be a 2D tensor with `indices.size(1) == len(params.size())`.

```python
def gather_nd(params, indices, name=None):
    '''
    The input indices must be a 2D tensor of the form [[a, b, ..., c], ...],
    where each row is the location of one element.
    '''
    indices = indices.t().long()
    ndim = indices.size(0)
    idx = torch.zeros_like(indices[0]).long()
    m = 1

    # Accumulate row-major strides to turn each coordinate row into a flat index.
    for i in reversed(range(ndim)):
        idx += indices[i] * m
        m *= params.size(i)

    return torch.take(params, idx)
```
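For concreteness, here is a self-contained sketch of this stride-accumulation idea, finished off with `torch.take` (which reads from the flattened tensor) and checked on a small 2D example; the function name is my own:

```python
import torch

def gather_nd_2d(params, indices):
    # indices: [num_points, params.dim()] tensor of coordinates.
    # Fold each coordinate row into one linear (row-major) index,
    # then read the values with torch.take.
    indices = indices.t().long()
    ndim = indices.size(0)
    idx = torch.zeros_like(indices[0])
    m = 1
    for i in reversed(range(ndim)):
        idx += indices[i] * m
        m *= params.size(i)
    return torch.take(params, idx)

params = torch.arange(12.).view(3, 4)
indices = torch.tensor([[0, 1], [2, 3]])
print(gather_nd_2d(params, indices))  # tensor([ 1., 11.])
```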

Great! Thank you. I think it’s useful.

Could you provide a more general version that fully implements tf.gather_nd, not limited to 2D tensors?

For example, if I want to gather 3d tensor with size CxHxW by giving the indices with size Nx2, where the indices values correspond to the coordinates in HxW grid, how can I get the results with size CxN?

Actually in tensorflow, this can be easily implemented by tf.gather_nd…

I think you can use advanced indexing for this:

```python
C, H, W = 3, 4, 4
x = torch.arange(C*H*W).view(C, H, W)
print(x)
idx = torch.tensor([[0, 0],
                    [1, 1],
                    [2, 2],
                    [3, 3]])

print(x[list((torch.arange(x.size(0)), *idx.chunk(2, 1)))])
```
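As a sanity check (a sketch with my own example values): note the chunk-based expression yields an N×C result, while indexing the channel dimension with a full slice gives the C×N layout asked for directly.

```python
import torch

C, H, W = 3, 4, 4
x = torch.arange(C * H * W).view(C, H, W)
idx = torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]])

# rows/cols are the H and W coordinates taken from each row of idx.
rows, cols = idx[:, 0], idx[:, 1]
out = x[:, rows, cols]          # shape (C, N): every channel at each (h, w) location
# Loop-based reference for comparison.
loop = torch.stack([x[:, h, w] for h, w in idx.tolist()], dim=1)
print(out.shape)                # torch.Size([3, 4])
```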

@ptrblck Thanks! I have found your previous answer which solved my problem.

The following codes are used for validation:

```python
import torch

batch_size = 2
c, h, w = 256, 38, 65
nb_points = 784
nb_regions = 128

img_feat = torch.randn(batch_size, c, h, w).cuda()
x = torch.empty(batch_size, nb_regions, nb_points, dtype=torch.long).random_(h).cuda()
y = torch.empty(batch_size, nb_regions, nb_points, dtype=torch.long).random_(w).cuda()

# method 1
result_1 = img_feat[torch.arange(batch_size)[:, None], :, x.view(batch_size, -1), y.view(batch_size, -1)]
result_1 = result_1.view(batch_size, nb_regions, nb_points, -1)

# method 2
result_2 = img_feat.new(batch_size, nb_regions, nb_points, img_feat.size(1)).zero_()
for i in range(batch_size):
    for j in range(x.shape[1]):
        for k in range(x.shape[2]):
            result_2[i, j, k] = img_feat[i, :, x[i, j, k], y[i, j, k]]

print((result_1 == result_2).all())
```

Thank you for this solution. How can I do this if I have images in batches (N×C×H×W) and the index tensor is also batched (N×H*W×2), so that my output is N×C×H×W? It would be really great.

Could you post a minimal example, how the index tensor should be used to index which values in your input?

I have been achieving this using `map`; the following is the code and output:

```python
N, C, H, W = 2, 3, 2, 2
img = torch.arange(N*C*H*W).view(N, C, H, W)
idx = torch.randint(0, 2, (N, H*W, 2))

maskit = lambda x, idx: x[list((torch.arange(x.size(0)), *idx.chunk(2, 1)))]
```

OUTPUT

```
print(img)
out: tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]],

         [[ 8,  9],
          [10, 11]]],


        [[[12, 13],
          [14, 15]],

         [[16, 17],
          [18, 19]],

         [[20, 21],
          [22, 23]]]])

print(idx)
out: tensor([[[0, 1],
         [1, 0],
         [1, 0],
         [1, 1]],

        [[1, 1],
         [0, 1],
         [1, 0],
         [0, 0]]])

out: tensor([[[ 1,  5,  9],
         [ 2,  6, 10],
         [ 2,  6, 10],
         [ 3,  7, 11]],

        [[15, 19, 23],
         [13, 17, 21],
         [14, 18, 22],
         [12, 16, 20]]])

print(final)
out: tensor([[[[ 1,  2],
          [ 2,  3]],

         [[ 5,  6],
          [ 6,  7]],

         [[ 9, 10],
          [10, 11]]],


        [[[15, 13],
          [14, 12]],

         [[19, 17],
          [18, 16]],

         [[23, 21],
          [22, 20]]]])

print(final.shape)
out: torch.Size([2, 3, 2, 2])
```
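The snippet defines `maskit` but omits the step that produces `final`. Judging from the printed shapes, the mapped per-image results of shape (H*W, C) are stacked, transposed to (N, C, H*W), and reshaped back to (N, C, H, W); a self-contained reconstruction (my `maskit` below is an equivalent tuple-indexing rewrite of the one above):

```python
import torch

N, C, H, W = 2, 3, 2, 2
img = torch.arange(N * C * H * W).view(N, C, H, W)
idx = torch.randint(0, 2, (N, H * W, 2))

# For one image x (C, H, W) and indices i (H*W, 2): gather -> (H*W, C).
# The (H*W, 1) coordinate columns broadcast against the (C,) channel range.
maskit = lambda x, i: x[torch.arange(x.size(0)), i[:, :1], i[:, 1:]]

gathered = torch.stack(list(map(maskit, img, idx)))      # (N, H*W, C)
final = gathered.transpose(1, 2).reshape(N, C, H, W)     # (N, C, H, W)
print(final.shape)  # torch.Size([2, 3, 2, 2])
```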

Can’t we do this without using `map`?

That’s quite a tough one to crack.
So far I only came up with this approach:

```python
idx_h, idx_w = idx.chunk(2, 2)
masked = img[torch.arange(N).view(N, 1), :, idx_h.squeeze(-1), idx_w.squeeze(-1)]
```

And as you can see, I kept your last expand and reshaping.

A quick `%timeit` gives a speedup of approx. 2x:

• your approach: 53.4 µs ± 461 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
• the new one: 26.4 µs ± 167 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
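Put together, the chunk-based indexing can be sketched end to end like this (a self-contained version with my own variable names, the final reshape included, and a loop-based check):

```python
import torch

N, C, H, W = 2, 3, 2, 2
img = torch.arange(N * C * H * W).view(N, C, H, W)
idx = torch.randint(0, H, (N, H * W, 2))

h_idx, w_idx = idx.chunk(2, 2)            # each (N, H*W, 1)
# The (N, 1) batch index broadcasts against the (N, H*W) coordinate tensors;
# with the channel slice in between, the advanced dims land in front: (N, H*W, C).
masked = img[torch.arange(N).view(N, 1), :, h_idx.squeeze(-1), w_idx.squeeze(-1)]
final = masked.transpose(1, 2).reshape(N, C, H, W)
print(final.shape)  # torch.Size([2, 3, 2, 2])
```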

Thank you very much. Now I need to see the speed-up in my real task. I hope it will speed things up even more in my case, as there are huge calculations to perform.

I know this is a very dirty workaround, but it would give you the exact tf.gather_nd behavior.

It basically wraps tf.gather_nd into a pytorch function and then calls it.


After spending some more time with the topic, I think @Sau_Tak was pretty close to what I needed. The only thing I had to add was some reshaping before and after to make it work for my use-case in arbitrary dimensions.

Note that the for loop might be a bit slow, which I think has been addressed by other posts; but it works fine for me now.

```python
import torch
import skimage
import matplotlib.pyplot as plt

def gather_nd(params, indices):
    '''
    4D example:
    params: tensor shaped [n_1, n_2, n_3, n_4] --> 4-dimensional
    indices: tensor shaped [m_1, m_2, m_3, m_4, 4] --> multidimensional list of 4D indices

    returns: tensor shaped [m_1, m_2, m_3, m_4]

    ND example:
    params: tensor shaped [n_1, ..., n_d] --> d-dimensional tensor
    indices: tensor shaped [m_1, ..., m_i, d] --> multidimensional list of d-dimensional indices

    returns: tensor shaped [m_1, ..., m_i]
    '''
    out_shape = indices.shape[:-1]
    indices = indices.unsqueeze(0).transpose(0, -1)  # roll last axis to front
    ndim = indices.shape[0]
    indices = indices.long()
    idx = torch.zeros_like(indices[0], device=indices.device)
    m = 1

    # Fold the d coordinate planes into a single flat (row-major) index.
    for i in reversed(range(ndim)):
        idx += indices[i] * m
        m *= params.size(i)
    out = torch.take(params, idx)
    return out.view(out_shape)

## Example

image_t = torch.from_numpy(skimage.data.astronaut())
# torch_tile_nd is my own helper; it tiles like Tensor.repeat
params = ptcompat.torch_tile_nd(image_t.view((1, 1, 512, 512, 3)), [4, 16, 1, 1, 1])  # batch of stacks of images
indices = torch.stack(torch.meshgrid(torch.arange(4), torch.arange(16), torch.arange(128), torch.arange(128), torch.arange(3)), dim=-1)  # get a 128 x 128 image slice from each item

out = gather_nd(params, indices)
print(out.shape)  # >> [4, 16, 128, 128, 3]
plt.imshow(out[0, 0, ...])
```

Building on top of @xinwei_he's answer, it can be done like this:

```python
m = torch.rand(3, 3)
idx = torch.LongTensor([[2, 2], [1, 2]])
m[list(idx.T)]  # will do the trick
```
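A small check of this trick; note that `m[list(idx.T)]` is equivalent to `m[idx[:, 0], idx[:, 1]]`, which avoids the deprecation warning about indexing with a list of tensors in recent PyTorch versions (the example values are my own):

```python
import torch

m = torch.arange(9.).view(3, 3)
idx = torch.tensor([[2, 2], [1, 2]])

# Each column of idx supplies one coordinate per gathered element.
out = m[idx[:, 0], idx[:, 1]]   # picks m[2, 2] and m[1, 2]
print(out)                      # tensor([8., 5.])
```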

I think @Mudit_Bachhawat’s solution is great - it is simple and effective. Not sure why it does not receive many upvotes. Did I miss anything? Thanks!

Thank you, your solution works perfectly for me! I was stuck porting my code from TF to PyTorch. This function is awesome.

```python
def gather_nd(x, indices):
    # Output shape per tf.gather_nd: indices.shape[:-1] + x.shape[indices.shape[-1]:]
    newshape = indices.shape[:-1] + x.shape[indices.shape[-1]:]
    indices = indices.view(-1, indices.shape[-1]).tolist()
    # stack (not cat) also handles 0-dim results when index rows are full-rank.
    out = torch.stack([x[tuple(i)] for i in indices])
    return out.reshape(newshape)
```
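A quick check of the tf.gather_nd semantics this version implements (full-rank index rows return scalars, shorter rows return slices); the example values are my own:

```python
import torch

def gather_nd(x, indices):
    newshape = indices.shape[:-1] + x.shape[indices.shape[-1]:]
    indices = indices.view(-1, indices.shape[-1]).tolist()
    out = torch.stack([x[tuple(i)] for i in indices])
    return out.reshape(newshape)

x = torch.arange(24.).view(2, 3, 4)
full = gather_nd(x, torch.tensor([[0, 1, 2], [1, 2, 3]]))   # scalar per row
part = gather_nd(x, torch.tensor([[0, 1], [1, 2]]))         # last-dim slice per row
print(full)        # tensor([ 6., 23.])
print(part.shape)  # torch.Size([2, 4])
```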

I created a version that works on 4D indices, 3D params and returns a 4D output.

Example:

```
params  is a float32 Tensor of shape [BS, seq_len, emb_size]
indices is a int64   Tensor of shape [BS, seq_len, 10, 2]
output  is a float32 Tensor of shape [BS, seq_len, 10, emb_size]
```

This implementation is very slow and ugly. But maybe someone can think of a clever trick to speed up the `for` loop?

```python
import itertools
import torch


def torch_gather_nd(params: torch.Tensor,
                    indices: torch.Tensor) -> torch.Tensor:
    """
    Perform tf.gather_nd on torch.Tensor. Although working, this implementation is
    quite slow and 'ugly'. You should not care too much about performance when using
    this function. I encourage you to think about how to optimize this.

    This function has not been tested properly. It has only been tested empirically
    and shown to yield identical results compared to tf.gather_nd. Still, use at your
    own risk.

    Does not support the `batch_dims` argument that tf.gather_nd does support. This is
    something for future work.

    :param params: (Tensor) - the source Tensor
    :param indices: (LongTensor) - the indices of elements to gather
    :return output: (Tensor) - the destination tensor
    """
    assert indices.dtype == torch.int64, f"indices must be torch.LongTensor, got {indices.dtype}"
    assert indices.shape[-1] <= len(params.shape), f'The last dimension of indices can be at most the rank ' \
                                                   f'of params ({len(params.shape)})'

    # Define the output shape. According to the documentation of tf.gather_nd, this is:
    # "indices.shape[:-1] + params.shape[indices.shape[-1]:]"
    output_shape = indices.shape[:-1] + params.shape[indices.shape[-1]:]

    # Initialize the output Tensor as an empty one.
    output = torch.zeros(size=output_shape, device=params.device, dtype=params.dtype)

    # indices_to_fill is a list of tuples containing the indices to fill in `output`
    indices_to_fill = list(itertools.product(*[range(x) for x in output_shape[:-1]]))

    # Loop over indices_to_fill and fill the `output` Tensor
    for idx in indices_to_fill:
        index_value = indices[idx]

        if len(index_value.shape) == 0:
            index_value = torch.tensor([0, index_value.item()])

        value = params[index_value.view(-1, 1).tolist()].view(-1)
        output[idx] = value

    return output
```
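Regarding the request for a clever trick to speed up the `for` loop: one possible vectorized sketch (my own variant, only lightly checked) folds the indexed leading dimensions into linear row offsets and gathers everything in one shot. It covers the partial-index case where `indices.shape[-1] <= params.dim()`:

```python
import torch

def torch_gather_nd_fast(params, indices):
    # Vectorized gather_nd sketch: flatten the k indexed leading dims of params,
    # turn each multi-index into one linear offset via row-major strides,
    # then gather the rows with a single advanced-indexing call.
    k = indices.shape[-1]
    out_shape = indices.shape[:-1] + params.shape[k:]
    strides, s = [], 1
    for dim in reversed(params.shape[:k]):   # row-major strides over the indexed dims
        strides.append(s)
        s *= dim
    strides = torch.tensor(list(reversed(strides)), device=indices.device)
    flat_idx = (indices.long() * strides).sum(-1).reshape(-1)
    flat_params = params.reshape(s, *params.shape[k:])  # s == prod of indexed dims
    return flat_params[flat_idx].reshape(out_shape)

params = torch.arange(24.).view(2, 3, 4)            # e.g. [BS, seq_len, emb_size]
indices = torch.tensor([[[0, 1], [1, 2]],
                        [[1, 0], [0, 2]]])          # [BS, 2, 2] partial indices
out = torch_gather_nd_fast(params, indices)
print(out.shape)  # torch.Size([2, 2, 4])
```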

The solution is good; now I'm trying to test it on a batch of data.

It doesn’t work on a 5D tensor.
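For what it's worth, the stride-accumulation approach used earlier in this thread does extend to 5D inputs when the index rows are full-rank; a minimal sketch (function name and example values are my own):

```python
import torch

def gather_nd_full(params, indices):
    # indices: (..., d) with d == params.dim(); returns shape indices.shape[:-1].
    out_shape = indices.shape[:-1]
    flat = indices.long().reshape(-1, indices.shape[-1])
    strides, s = [], 1
    for dim in reversed(params.shape):       # row-major strides over all dims
        strides.append(s)
        s *= dim
    strides = torch.tensor(list(reversed(strides)))
    linear = (flat * strides).sum(-1)
    # torch.take reads from params as if it were flattened (row-major).
    return torch.take(params, linear).reshape(out_shape)

params = torch.randn(2, 3, 4, 5, 6)          # a 5D tensor
indices = torch.tensor([[0, 1, 2, 3, 4],
                        [1, 2, 3, 4, 5]])
out = gather_nd_full(params, indices)
print(out.shape)  # torch.Size([2])
```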