Implement tf.gather_nd in PyTorch

There have been some discussions about the topic, but none of them solved my problem.

I am not familiar with TF, and I want to translate some TF code to PyTorch.

The TF code goes as follows:

sample_idx = tf.cast(tf.round(sample_idx), 'int32')
g_val = tf.gather_nd(sample_grid, sample_idx) 

where sample_idx is of size [3211264, 4] and sample_grid is of size [1, 1, 32, 32, 32].

Moreover, there are negative values in sample_idx. I printed it as follows:

tensor([[  0., -18.,  16.,  21.],
        [  0., -18.,  16.,  20.],
        [  0., -18.,  16.,  20.],
        ...,
        [  0.,  56.,  -0.,  -5.],
        [  0.,  56.,  -0.,  -5.],
        [  0.,  56.,  -0.,  -5.]], device='cuda:0')

And g_val is of size [3211264] in TF.

How can I do this operation in PyTorch?

Thanks for your help.

According to TF’s API:

Note that on CPU, if an out of bound index is found, an error is returned. On GPU, if an out of bound index is found, a 0 is stored in the corresponding output value.

Therefore, I simply added a bounds check on the index values, and now it is working.

import functools
import operator

import torch


def gather_nd(params, indices):
    '''
    the input indices must be a 2d tensor in the form of [[a,b,..,c],...],
    which represents the location of the elements.
    '''
    # Largest valid flat index into params
    max_value = functools.reduce(operator.mul, list(params.size())) - 1
    indices = indices.t().long()
    ndim = indices.size(0)
    idx = torch.zeros_like(indices[0]).long()
    m = 1

    # Convert the multi-dimensional indices into flat (row-major) indices
    for i in range(ndim)[::-1]:
        idx += indices[i] * m
        m *= params.size(i)

    # Redirect out-of-bound flat indices to element 0
    idx[idx < 0] = 0
    idx[idx > max_value] = 0
    return torch.take(params, idx)
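
A quick sanity check with made-up values (the expected output assumes the redirect-to-element-0 behavior described above):

params = torch.arange(24.).reshape(2, 3, 4)
indices = torch.tensor([[0, 1, 2], [1, 2, 3], [0, -1, 0]])
print(gather_nd(params, indices))
# tensor([ 6., 23.,  0.]) -- the negative index is redirected to element 0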

Thanks for debugging this issue!

That seems like a weird strategy to me.

I found that the results are not the same as TF when batch size is larger than 1.

For example, the sizes of sample_idx and sample_grid are [57802752, 4] and [6, 32, 32, 32], respectively.

sample_grid = torch.from_numpy(np.load('sample_grid.npy')).cuda()
sample_idx = torch.from_numpy(np.load('sample_idx.npy')).cuda()

g_val = gather_nd(sample_grid, sample_idx)
print(torch.sum(g_val[0, 0, 0, ...]))   # tensor(52197., device='cuda:0')

However, in TF:

_sample_grid = np.load('sample_grid.npy')
_sample_idx = np.load('sample_idx.npy')

g_val = tf.gather_nd(sample_grid, sample_idx)
_g_val = sess.run(g_val, {
    sample_grid: _sample_grid,
    sample_idx: _sample_idx
})
print(np.sum(_g_val[0, 0, 0, ...]))    # 43380.0

You can download the sample_grid.npy and sample_idx.npy from here.

Can anyone help me out :frowning:

I wrote a small program to test the behavior of TF and the self-implemented gather_nd in PyTorch.

import tensorflow as tf
import torch
import numpy as np

import functools
import operator


def gather_nd(params, indices):
    '''
    the input indices must be a 2d tensor in the form of [[a,b,..,c],...], 
    which represents the location of the elements.
    '''
    max_value = functools.reduce(operator.mul, list(params.size())) - 1
    indices = indices.t().long()
    ndim = indices.size(0)
    idx = torch.zeros_like(indices[0]).long()
    m = 1
    
    for i in range(ndim)[::-1]:
        idx += indices[i] * m 
        m *= params.size(i)

    idx[idx < 0] = 0
    idx[idx > max_value] = 0
    return torch.take(params, idx)


idx = np.array([
    [0, 0, 0, 0],
    [0, 0, 1, 0],
    [0, 1, 1, 0],
    [2, 1, 1, 0],
    [1, 1, 4, 0],
])
mtx = np.reshape(range(120), (2, 3, 4, 5))
print(mtx)

tf_result = tf.gather_nd(mtx, idx)
with tf.Session() as sess:
    print('TF', sess.run(tf_result))

mtx = torch.from_numpy(mtx)
idx = torch.from_numpy(idx)
torch_result = gather_nd(mtx, idx)
print('Torch', torch_result)

And the output is as follows:

TF [ 0  5 25  0  0]
Torch tensor([  0,   5,  25,   0, 100])

It seems that something is wrong with the self-implemented one. For [1, 1, 4, 0], the index 4 is out of bounds in the third dimension (of size 4), yet the resulting flat index 1*60 + 1*20 + 4*5 + 0 = 100 is still within [0, 119], so the max_value check does not catch it.

I preprocessed the indices before torch.take to make sure that the index value in each dimension is within bounds.

The modified gather_nd is listed as follows:

def gather_nd(params, indices):
    '''
    the input indices must be a 2d tensor in the form of [[a,b,..,c],...],
    which represents the location of the elements.
    '''
    # Normalize indices values: check each dimension against its own size
    params_size = list(params.size())
    assert len(params_size) == indices.size(1)

    indices[indices < 0] = 0
    for idx, ps in enumerate(params_size):
        # note: this resets the whole index row to zeros, not just the offending entry
        indices[indices[:, idx] >= ps] = 0

    # Generate flat (row-major) indices
    indices = indices.t().long()
    ndim = indices.size(0)
    idx = torch.zeros_like(indices[0]).long()
    m = 1

    for i in range(ndim)[::-1]:
        idx += indices[i] * m
        m *= params.size(i)

    return torch.take(params, idx)
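
Re-running the earlier test case with this version (a quick check, reusing the mtx and idx from the test program above), the two now agree:

mtx = np.reshape(range(120), (2, 3, 4, 5))
idx = np.array([[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 1, 0], [2, 1, 1, 0], [1, 1, 4, 0]])
print(gather_nd(torch.from_numpy(mtx), torch.from_numpy(idx)))
# tensor([ 0,  5, 25,  0,  0]) -- matches the TF output above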

Hi, I modified @hzxie's code as below.
It supports the case where indices.size(1) is not equal to len(params_size):

def gather_nd(params, indices):
    '''
    the input indices must be a 2d tensor in the form of [[a,b,..,c],...],
    which represents the location of the elements.
    '''
    params_size = list(params.size())

    assert len(indices.size()) == 2
    assert len(params_size) >= indices.size(1)

    # Generate flat indices over the first `ndim` dimensions of params
    indices = indices.t().long()
    ndim = indices.size(0)
    idx = torch.zeros_like(indices[0]).long()
    m = 1

    for i in range(ndim)[::-1]:
        idx += indices[i] * m
        m *= params.size(i)

    # Collapse the indexed dimensions and keep the trailing ones as slices
    params = params.reshape(-1, *params.size()[ndim:])
    return params[idx]
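
For example (a quick check along the lines of the TensorFlow samples further down, where a 2-element index into a 3D tensor returns a row):

params = torch.arange(8).reshape(2, 2, 2)
indices = torch.tensor([[0, 1], [1, 0]])
print(gather_nd(params, indices))
# tensor([[2, 3],
#         [4, 5]])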

Hope that code helps someone :smiley:

I was also quite confused about the torch implementation of tf.gather_nd at first, but I found that tf.gather_nd is essentially a variant of slicing. I have my implementation below and hope it will help someone :smile:

def gather_nd(params, indices):
    """params is of "n" dimensions and has size [x1, x2, x3, ..., xn], indices is of 2 dimensions and has size [num_samples, m] (m <= n)"""
    assert type(indices) == torch.Tensor
    # Index each of the first m dimensions of params with one column of indices
    # (note: .numpy() requires indices to be a CPU tensor)
    return params[indices.transpose(0, 1).long().numpy().tolist()]

Testing:

TensorFlow samples:

indices = [[1]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = [[['a1', 'b1'], ['c1', 'd1']]]


indices = [[0, 1], [1, 0]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = [['c0', 'd0'], ['a1', 'b1']]


indices = [[0, 0, 1], [1, 0, 1]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = ['b0', 'b1']

For torch:

>>> params = torch.rand(2,2,2)
>>> params
tensor([[[0.4685, 0.2514],
         [0.0624, 0.0797]],

        [[0.4989, 0.1414],
         [0.6970, 0.6825]]])
>>> indices = torch.Tensor([[1]])
>>> gather_nd(params, indices)
tensor([[[0.4989, 0.1414],
         [0.6970, 0.6825]]])
>>> indices = torch.Tensor([[0, 1], [1, 0]])
>>> gather_nd(params, indices)
tensor([[0.0624, 0.0797],
        [0.4989, 0.1414]])
>>> indices = torch.Tensor([[0, 0, 1], [1, 0, 1]])
>>> gather_nd(params, indices)
tensor([0.2514, 0.1414])

I created a version that works on 4D indices, 3D params and returns a 4D output.

Example:

params is a float32 Tensor of shape [BS, seq_len, emb_size]
indices is a int64 Tensor of shape [BS, seq_len, 10, 2]
output is a float32 Tensor of shape [BS, seq_len, 10, emb_size]

This implementation is very slow and ugly. But maybe someone can think of a clever trick to speed up the for loop?

import itertools

import torch


def torch_gather_nd(params: torch.Tensor,
                    indices: torch.Tensor) -> torch.Tensor:
    """
    Perform the tf.gather_nd operation on torch.Tensor. Although working, this
    implementation is quite slow and 'ugly'. You should not care too much about
    performance when using this function. I encourage you to think about how to
    optimize this.

    This function has not been tested properly. It has only been tested empirically
    and shown to yield identical results compared to tf.gather_nd. Still, use at your
    own risk.

    Does not support the `batch_dims` argument that tf.gather_nd does support. This is
    something for future work.

    :param params: (Tensor) - the source Tensor
    :param indices: (LongTensor) - the indices of elements to gather
    :return output: (Tensor) - the destination tensor
    """
    assert indices.dtype == torch.int64, f"indices must be torch.LongTensor, got {indices.dtype}"
    assert indices.shape[-1] <= len(params.shape), f'The last dimension of indices can be at most the rank ' \
                                                   f'of params ({len(params.shape)})'

    # Define the output shape. According to the documentation of tf.gather_nd, this is:
    # "indices.shape[:-1] + params.shape[indices.shape[-1]:]"
    output_shape = indices.shape[:-1] + params.shape[indices.shape[-1]:]

    # Initialize the output Tensor as an empty one.
    output = torch.zeros(size=output_shape, device=params.device, dtype=params.dtype)

    # indices_to_fill is a list of tuples containing the indices to fill in `output`
    indices_to_fill = list(itertools.product(*[range(x) for x in output_shape[:-1]]))

    # Loop over indices_to_fill and fill the `output` Tensor one entry at a time
    for idx in indices_to_fill:
        index_value = indices[idx]

        if len(index_value.shape) == 0:
            index_value = torch.Tensor([0, index_value.item()])

        value = params[index_value.view(-1, 1).tolist()].view(-1)
        output[idx] = value

    return output
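
One way to avoid the Python loop entirely is to flatten the dimensions addressed by the index tuples and do a single advanced-indexing lookup. Below is a sketch of that idea (gather_nd_flat is a hypothetical name; it assumes all indices are in bounds, does not handle batch_dims, and has only been reasoned through for shapes like the [BS, seq_len, 10, 2] example above):

import torch

def gather_nd_flat(params: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Vectorized sketch: indices has shape [..., m] with m <= params.dim()."""
    m = indices.shape[-1]
    lead_shape = params.shape[:m]   # dimensions addressed by the index tuples
    tail_shape = params.shape[m:]   # dimensions returned as slices

    # Row-major strides over the leading dimensions
    strides = []
    s = 1
    for size in reversed(lead_shape):
        strides.append(s)
        s *= size
    strides = torch.tensor(list(reversed(strides)), device=indices.device)

    # Flatten the leading dims of params and convert index tuples to flat offsets
    flat_params = params.reshape(-1, *tail_shape)
    flat_idx = (indices.long() * strides).sum(dim=-1).reshape(-1)

    out = flat_params[flat_idx]
    return out.reshape(*indices.shape[:-1], *tail_shape)

For the example shapes above, this performs one indexing operation instead of BS * seq_len * 10 Python-level iterations.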