How to do the tf.gather_nd in pytorch?

After spending some more time with the topic, I think @Sau_Tak was pretty close to what I needed. The only thing I had to add was some reshaping before and after to make it work for my use-case in arbitrary dimensions.

Note that the for loop might be a little slow, which I think has been addressed in other posts, but it works fine for my purposes.

import torch
import skimage.data
import matplotlib.pyplot as plt

def gather_nd(params, indices):
    '''
    4D example
    params: tensor shaped [n_1, n_2, n_3, n_4] --> 4 dimensional
    indices: tensor shaped [m_1, m_2, m_3, m_4, 4] --> multidimensional list of 4D indices
    
    returns: tensor shaped [m_1, m_2, m_3, m_4]
    
    ND_example
    params: tensor shaped [n_1, ..., n_d] --> d-dimensional tensor
    indices: tensor shaped [m_1, ..., m_i, d] --> multidimensional list of d-dimensional indices
    
    returns: tensor shaped [m_1, ..., m_i]
    '''

    out_shape = indices.shape[:-1]
    indices = indices.unsqueeze(0).transpose(0, -1) # roll last axis to front
    ndim = indices.shape[0]
    indices = indices.long()
    idx = torch.zeros_like(indices[0]).long() # accumulates the flat indices
    m = 1

    # convert d-dimensional indices to flat indices into params,
    # working from the last dimension backwards (row-major strides)
    for i in reversed(range(ndim)):
        idx += indices[i] * m
        m *= params.size(i)
    out = torch.take(params, idx)
    return out.view(out_shape)
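As a quick sanity check of the flat-index trick the function relies on (this snippet is my own illustration, not part of the original answer): for a 2D tensor the flat index is `row * n_cols + col`, so `torch.take` on that flat index should match plain advanced indexing.

```python
import torch

# For a 2D tensor stored row-major, element (row, col) sits at
# flat position row * n_cols + col in the flattened tensor.
params = torch.arange(12.).view(3, 4)
rows = torch.tensor([0, 2, 1])
cols = torch.tensor([3, 0, 2])

flat = rows * params.size(1) + cols   # same accumulation gather_nd does
out = torch.take(params, flat)

# matches direct advanced indexing
assert torch.equal(out, params[rows, cols])
```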

## Example

image_t = torch.from_numpy(skimage.data.astronaut())
params = image_t.view((1, 1, 512, 512, 3)).repeat(4, 16, 1, 1, 1) # batch of stacks of images
indices = torch.stack(torch.meshgrid(torch.arange(4), torch.arange(16), torch.arange(128), torch.arange(128), torch.arange(3), indexing="ij"), dim=-1) # get a 128 x 128 image slice from each item

out = gather_nd(params, indices)
print(out.shape) # >> [4, 16, 128, 128, 3]
plt.imshow(out[0, 0, ...])
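For completeness, here is an alternative that skips the flat-index computation entirely (my own sketch, not from the original answer): when the index depth in the last axis equals `params.dim()`, PyTorch advanced indexing reproduces `tf.gather_nd` by splitting the last axis of `indices` into per-dimension index tensors.

```python
import torch

def gather_nd_via_indexing(params, indices):
    # Split the [..., d] index tensor into d tensors of shape [...]
    # and use them as a tuple for advanced indexing; the result has
    # shape indices.shape[:-1], like tf.gather_nd with full index depth.
    return params[tuple(indices.long().unbind(dim=-1))]

params = torch.arange(24.).view(2, 3, 4)
indices = torch.tensor([[0, 1, 2],
                        [1, 2, 3],
                        [0, 0, 0]])
out = gather_nd_via_indexing(params, indices)
# each out[k] equals params[indices[k, 0], indices[k, 1], indices[k, 2]]
```

This avoids the Python loop over dimensions, though for small `ndim` the difference is negligible.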