How to do tf.gather_nd in PyTorch?

That’s quite a tough one to crack.
So far I have only come up with this approach:

idx_chunked = idx.chunk(2, 2)  # split the two index columns (idx is assumed to have shape [N, K, 2])
masked = img[torch.arange(N).view(N, 1), :, idx_chunked[0].squeeze(), idx_chunked[1].squeeze()]
final = masked.expand(1, *masked.shape).permute(1, 3, 2, 0).view(*img.shape)

And as you can see, I kept your last expand and reshaping.

A quick %timeit gives a speedup of approx. 2x:

  • your approach: 53.4 µs ± 461 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
  • the new one: 26.4 µs ± 167 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
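
For reference, a sketch of how such a comparison could be reproduced with torch.utils.benchmark; gather_v2 is a hypothetical wrapper around one of the two implementations, and img, idx and N are assumed to be defined as in the snippet above:

from torch.utils import benchmark

t = benchmark.Timer(
    stmt="gather_v2(img, idx, N)",
    globals={"gather_v2": gather_v2, "img": img, "idx": idx, "N": N},
)
print(t.timeit(10000))  # reports the mean time per run over 10000 runs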
1 Like

Thank you very much. Now I need to see the speed-up in my real task. I hope it will speed things up even more in my case, as there are huge calculations to perform. :slight_smile:

1 Like

I know this is a very dirty workaround, but it would give you the exact tf.gather_nd behavior.

It basically wraps tf.gather_nd into a pytorch function and then calls it.
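
For reference, a minimal sketch of such a wrapper (forward pass only, no autograd support; it assumes TensorFlow is installed and simply round-trips through NumPy):

import tensorflow as tf
import torch

def gather_nd_via_tf(params: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Convert to NumPy, let tf.gather_nd do the actual work, then convert back.
    out = tf.gather_nd(tf.convert_to_tensor(params.detach().cpu().numpy()),
                       tf.convert_to_tensor(indices.detach().cpu().numpy()))
    return torch.from_numpy(out.numpy()).to(params.device)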

1 Like

After spending some more time with the topic, I think @Sau_Tak was pretty close to what I needed. The only thing I had to add was some reshaping before and after to make it work for my use-case in arbitrary dimensions.

Note that the for loop might be a little slow, which I think has been addressed in other posts, but for me it works fine now.

import torch
import skimage.data
import matplotlib.pyplot as plt

def gather_nd(params, indices):
    '''
    4D example
    params: tensor shaped [n_1, n_2, n_3, n_4] --> 4 dimensional
    indices: tensor shaped [m_1, m_2, m_3, m_4, 4] --> multidimensional list of 4D indices
    
    returns: tensor shaped [m_1, m_2, m_3, m_4]
    
    ND_example
    params: tensor shaped [n_1, ..., n_d] --> d-dimensional tensor
    indices: tensor shaped [m_1, ..., m_i, d] --> multidimensional list of d-dimensional indices
    
    returns: tensor shaped [m_1, ..., m_i]
    '''

    out_shape = indices.shape[:-1]
    indices = indices.unsqueeze(0).transpose(0, -1)  # roll last axis to the front
    ndim = indices.shape[0]
    indices = indices.long()
    idx = torch.zeros_like(indices[0], device=indices.device).long()
    m = 1
    
    for i in range(ndim)[::-1]:
        idx += indices[i] * m 
        m *= params.size(i)
    out = torch.take(params, idx)
    return out.view(out_shape)

## Example

image_t = torch.from_numpy(skimage.data.astronaut())
params = image_t.view((1, 1, 512, 512, 3)).repeat(4, 16, 1, 1, 1)  # batch of stacks of images (torch.Tensor.repeat tiles the image like the original ptcompat.torch_tile_nd helper did)
indices = torch.stack(torch.meshgrid(torch.arange(4), torch.arange(16), torch.arange(128), torch.arange(128), torch.arange(3)), dim=-1) # get 128 x 128 image slice from each item

out = gather_nd(params, indices)
print(out.shape) # >> [4, 16, 128, 128, 3]
plt.imshow(out[0, 0, ...])
2 Likes

Building on top of @xinwei_he’s answer, it can be done like this:

m = torch.rand(3,3)
idx = torch.LongTensor([[2,2], [1,2]])
m[list(idx.T)] # will do the trick
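
This works because idx.T turns the [num_indices, 2] index matrix into one row per target dimension, and list(...) unbinds those rows so PyTorch’s advanced indexing treats them as per-dimension index tensors. A quick check of the example above:

out = m[list(idx.T)]                                      # gathers m[2, 2] and m[1, 2]
assert torch.equal(out, torch.stack([m[2, 2], m[1, 2]]))  # element-wise lookup agrees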
5 Likes

I think @Mudit_Bachhawat’s solution is great - it is simple and effective. I’m not sure why it did not receive more upvotes. Did I miss anything? Thanks!

Man, you’ve made my day!
Thank you, your solution works perfectly for me! I was stuck porting my code from TF to PyTorch. This function is awesome.

1 Like
def gather_nd(x, indices):
    newshape = indices.shape[:-1] + x.shape[indices.shape[-1]:]
    indices = indices.view(-1, indices.shape[-1]).tolist()
    # torch.stack (instead of torch.cat) also handles full indices, where each
    # gathered element is zero-dimensional
    out = torch.stack([x.__getitem__(tuple(i)) for i in indices])
    return out.reshape(newshape)

I created a version that works on 4D indices, 3D params and returns a 4D output.

Example:

params is a float32 Tensor of shape [BS, seq_len, emb_size]
indices is a int64 Tensor of shape [BS, seq_len, 10, 2]
output is a float32 Tensor of shape [BS, seq_len, 10, emb_size]

This implementation is very slow and ugly. But maybe someone can think of a clever trick to speed up the for loop?

import itertools

import torch


def torch_gather_nd(params: torch.Tensor,
                    indices: torch.Tensor) -> torch.Tensor:
    """
    Perform the tf.gather_nd on torch.Tensor. Although working, this implementation is
    quite slow and 'ugly'. You should not care to much about performance when using
    this function. I encourage you to think about how to optimize this.

    This function has not been tested properly. It has only been tested empirically
    and shown to yield identical results compared to tf.gather_nd. Still, use at your
    own risk.

    Does not support the `batch_dims` argument that tf.gather_nd does support. This is
    something for future work.

    :param params: (Tensor) - the source Tensor
    :param indices: (LongTensor) - the indices of elements to gather
    :return output: (Tensor) – the destination tensor
    """
    assert indices.dtype == torch.int64, f"indices must be torch.LongTensor, got {indices.dtype}"
    assert indices.shape[-1] <= len(params.shape), f'The last dimension of indices can be at most the rank ' \
                                                   f'of params ({len(params.shape)})'

    # Define the output shape. According to the  documentation of tf.gather_nd, this is:
    # "indices.shape[:-1] + params.shape[indices.shape[-1]:]"
    output_shape = indices.shape[:-1] + params.shape[indices.shape[-1]:]

    # Initialize the output Tensor as an empty one.
    output = torch.zeros(size=output_shape, device=params.device, dtype=params.dtype)

    # indices_to_fill is a list of tuple containing the indices to fill in `output`
    indices_to_fill = list(itertools.product(*[range(x) for x in output_shape[:-1]]))

    # Loop over the indices_to_fill and fill the `output` Tensor
    for idx in indices_to_fill:
        index_value = indices[idx]

        if len(index_value.shape) == 0:
            index_value = torch.Tensor([0, index_value.item()])

        value = params[index_value.view(-1, 1).tolist()].view(-1)
        output[idx] = value

    return output
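
A quick usage sketch for the shapes described above, with hypothetical small sizes (BS=2, seq_len=3, emb_size=4, 10 index pairs per position):

params = torch.randn(2, 3, 4)
indices = torch.stack([torch.randint(0, 2, (2, 3, 10)),   # batch index
                       torch.randint(0, 3, (2, 3, 10))],  # sequence index
                      dim=-1)                             # -> shape [2, 3, 10, 2]
out = torch_gather_nd(params, indices)
print(out.shape)  # torch.Size([2, 3, 10, 4])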

The solution is good; now I’m trying to test it on a batch of data.

It can’t work on a 5D tensor, though.

It works, but it has yet to be tested in a real training step.

Adapting @Mudit_Bachhawat’s solution for 4D params and 4D indices:

def gather_nd_torch(params, indices):
    params = torch.moveaxis(params, (0, 1, 2, 3), (0, 3, 1, 2))
    indices = torch.moveaxis(indices, (0, 1, 2, 3), (0, 3, 1, 2))
    indices = indices.type(torch.int64)
    gathered = params[list(indices.T)]
    gathered = torch.moveaxis(gathered, (0, 1, 2, 3), (3, 2, 0, 1))

    return gathered

Works for arbitrary shapes.

See tests on Colab: Google Colaboratory

import numpy as np
import torch


def gather_nd(params, indices):
    """ The same as tf.gather_nd but batched gather is not supported yet.
    indices is a k-dimensional integer tensor, best thought of as a (k-1)-dimensional tensor of indices into params, where each element defines a slice of params:

    output[i_0, ..., i_{k-2}] = params[indices[i_0, ..., i_{k-2}]]

    Args:
        params (Tensor): "n" dimensions. shape: [x_0, x_1, x_2, ..., x_{n-1}]
        indices (Tensor): "k" dimensions. shape: [y_0, y_1, ..., y_{k-2}, m]. m <= n.

    Returns: gathered Tensor.
        shape: [y_0, y_1, ..., y_{k-2}] + params.shape[m:]

    """
    orig_shape = list(indices.shape)
    num_samples = np.prod(orig_shape[:-1])
    m = orig_shape[-1]
    n = len(params.shape)

    if m <= n:
        out_shape = orig_shape[:-1] + list(params.shape)[m:]
    else:
        raise ValueError(
            f'the last dimension of indices must be less than or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}'
        )

    indices = indices.reshape((num_samples, m)).transpose(0, 1).tolist()
    output = params[indices]    # (num_samples, ...)
    return output.reshape(out_shape).contiguous()
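
A quick, hypothetical sanity check of the function above (two 2-D indices into a 3-D tensor, so slices of the last dimension are gathered):

params = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
indices = torch.tensor([[0, 1], [1, 2]])  # two index tuples, m = 2 <= n = 3
out = gather_nd(params, indices)
print(out.shape)  # torch.Size([2, 4]) == indices.shape[:-1] + params.shape[2:]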
3 Likes

Might be too late for the OP but here it is…

The majority of this implementation is from Michael Jungo.
I just adapted it to support a leading batch dimension.

import numpy as np
import torch


def gather_nd_torch(params, indices, batch_dim=1):
    """ A PyTorch porting of tensorflow.gather_nd
    This implementation can handle leading batch dimensions in params, see below for detailed explanation.

    The majority of this implementation is from Michael Jungo @ https://stackoverflow.com/a/61810047/6670143
    I just adapted it to support leading batch dimensions.

    Args:
      params: a tensor of dimension [b1, ..., bn, g1, ..., gm, c].
      indices: a tensor of dimension [b1, ..., bn, x, m]
      batch_dim: indicate how many batch dimension you have, in the above example, batch_dim = n.

    Returns:
      gathered: a tensor of dimension [b1, ..., bn, x, c].

    Example:
    >>> batch_size = 5
    >>> inputs = torch.randn(batch_size, batch_size, batch_size, 4, 4, 4, 32)
    >>> pos = torch.randint(4, (batch_size, batch_size, batch_size, 12, 3))
    >>> gathered = gather_nd_torch(inputs, pos, batch_dim=3)
    >>> gathered.shape
    torch.Size([5, 5, 5, 12, 32])

    >>> inputs_tf = tf.convert_to_tensor(inputs.numpy())
    >>> pos_tf = tf.convert_to_tensor(pos.numpy())
    >>> gathered_tf = tf.gather_nd(inputs_tf, pos_tf, batch_dims=3)
    >>> gathered_tf.shape
    TensorShape([5, 5, 5, 12, 32])

    >>> gathered_tf = torch.from_numpy(gathered_tf.numpy())
    >>> torch.equal(gathered_tf, gathered)
    True
    """
    batch_dims = params.size()[:batch_dim]  # [b1, ..., bn]
    batch_size = np.cumprod(list(batch_dims))[-1]  # b1 * ... * bn
    c_dim = params.size()[-1]  # c
    grid_dims = params.size()[batch_dim:-1]  # [g1, ..., gm]
    n_indices = indices.size(-2)  # x
    n_pos = indices.size(-1)  # m

    # reshape leading batch dims to a single batch dim
    params = params.reshape(batch_size, *grid_dims, c_dim)
    indices = indices.reshape(batch_size, n_indices, n_pos)

    # build gather indices
    # gather for each of the data point in this "batch"
    batch_enumeration = torch.arange(batch_size).unsqueeze(1)
    gather_dims = [indices[:, :, i] for i in range(len(grid_dims))]
    gather_dims.insert(0, batch_enumeration)
    gathered = params[gather_dims]

    # reshape back to the shape with leading batch dims
    gathered = gathered.reshape(*batch_dims, n_indices, c_dim)
    return gathered

I tested it for my use cases; please let me know if this implementation can be improved…

I have also done a very simple speed test (just running each version 100 or 1000 times and comparing); this implementation is faster than TF’s original gather_nd.

Anyway, I hope this can help someone.

I had a use case where I needed multiple batch dimensions and multiple channels, so I came up with this GitHub gist: A pytorch implementation of torch_gather_nd with multiple batch dim and multiple channel dim support.

I’m not sure how performant it is, since it requires flattening out the channel dimensions with one index entry per channel element.