How to do tf.gather_nd in PyTorch?
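For reference, with a 2D indices of shape [N, K], tf.gather_nd returns one entry per row n, namely params[indices[n]]. That entry is a single element when K equals params.dim() and a slice over the trailing dimensions when K is smaller. The loop below is a slow sketch of those semantics only. The helper name gather_nd_reference is hypothetical, and the sketch ignores TF's batch_dims argument and higher-rank indices:

```python
import torch

def gather_nd_reference(params, indices):
    # Hypothetical reference helper: an explicit (slow) loop spelling out
    # tf.gather_nd semantics for a 2D `indices` of shape [N, K].
    # Each row is used as a coordinate into the first K dims of `params`,
    # so a full-length row picks one element and a shorter row picks a slice.
    return torch.stack([params[tuple(row.tolist())] for row in indices])

params = torch.arange(12).view(3, 4)
indices = torch.tensor([[0, 1], [2, 3]])
print(gather_nd_reference(params, indices).tolist())  # [1, 11]
```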

Hmm… this might work, though it's only a rather naive solution I came up with.
The indices must be a 2D tensor with indices.size(1) == len(params.size()), i.e. each row is the full coordinate of one element.
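Because every row is a full coordinate, it identifies exactly one element, and the function below converts it into that element's offset in the flattened tensor. For params of shape (d_0, ..., d_{n-1}) and a row (i_0, ..., i_{n-1}), the row-major offset is:

$$
\mathrm{idx} = \sum_{k=0}^{n-1} i_k \prod_{j=k+1}^{n-1} d_j
$$

The loop accumulates this sum from the last dimension backwards, with m holding the running product of the trailing sizes. For example, with shape (3, 4) the coordinate (2, 3) maps to 2 · 4 + 3 = 11.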

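import torch

def gather_nd(params, indices, name=None):
    '''
    Gather elements of `params` at the given coordinates (like tf.gather_nd).

    The input indices must be a 2D tensor of shape [N, params.dim()] in the
    form [[a, b, ..., c], ...], where each row is the location of one element.
    `name` is unused; it only mirrors the TensorFlow signature.
    '''
    indices = indices.t().long()         # shape [ndim, N]: one row per dimension
    ndim = indices.size(0)
    idx = torch.zeros_like(indices[0])   # flat (row-major) offset of each element
    m = 1                                # stride of dimension i in a contiguous layout

    # Walk the dimensions from last to first: offset += coordinate * stride.
    for i in reversed(range(ndim)):
        idx += indices[i] * m
        m *= params.size(i)

    # torch.take indexes `params` as if it were flattened to 1D (row-major).
    return torch.take(params, idx)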
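A possibly simpler alternative is PyTorch's advanced indexing: split the last dimension of indices into one index tensor per dimension and index params with that tuple. In this sketch (the helper name gather_nd_indexing is hypothetical), indices of shape [..., K] yields a result of shape [...] + params.shape[K:], so it also covers K < params.dim() (slices, as tf.gather_nd returns) and indices with extra leading dimensions. It does not implement TF's batch_dims argument.

```python
import torch

def gather_nd_indexing(params, indices):
    # Hypothetical helper. indices has shape [..., K]; unbind(-1) yields a tuple
    # of K index tensors of shape [...], one per leading dimension of params.
    # Advanced indexing then returns shape [...] + params.shape[K:].
    return params[indices.long().unbind(-1)]

params = torch.arange(24).view(2, 3, 4)   # params[a, b, c] == 12*a + 4*b + c

# Full coordinates (K == params.dim()): one element per row.
full = torch.tensor([[0, 1, 2], [1, 2, 3]])
print(gather_nd_indexing(params, full).tolist())     # [6, 23]

# Partial coordinates (K < params.dim()): one length-4 slice per row.
partial = torch.tensor([[0, 1], [1, 0]])
print(gather_nd_indexing(params, partial).tolist())  # [[4, 5, 6, 7], [12, 13, 14, 15]]
```

Another difference from the flat-offset version: advanced indexing checks each coordinate against its own dimension and raises an IndexError, whereas a coordinate like (0, 5) on a (3, 4) tensor gives flat offset 5 and would silently return params[1, 1] through torch.take.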