Extract batchwise multiple coordinates with pytorch functions (similar to tf.gather_nd)

Oh I’m sorry. I completely misunderstood your use case.

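This should work and is much easier. Index the batch dimension with `torch.arange(batch_size)[:, None]` so that it broadcasts against the `(batch_size, nb_points)` coordinate tensors `x` and `y`, and keep the channel dimension as a plain slice. Because the slice separates the indexed dimensions, the broadcast `(batch_size, nb_points)` dimensions come first and the result has shape `(batch_size, nb_points, c)`: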

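```python
import torch

batch_size, c, h, w, nb_points = 2, 3, 5, 7, 4  # example sizes
img_feat = torch.randn(batch_size, c, h, w)
# random row (x) and column (y) indices, each of shape (batch_size, nb_points)
x = torch.empty(batch_size, nb_points, dtype=torch.long).random_(h)
y = torch.empty(batch_size, nb_points, dtype=torch.long).random_(w)
# vectorized lookup: points[b, n] = img_feat[b, :, x[b, n], y[b, n]], shape (batch_size, nb_points, c)
points = img_feat[torch.arange(batch_size)[:, None], :, x, y]
# reference result built with explicit Python loops, for comparison
feats = torch.stack([torch.stack([img_feat[i, :, x[i, j], y[i, j]] for j in range(nb_points)]) for i in range(batch_size)])
print((points == feats).all())  # tensor(True)
```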
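Since the question is framed in terms of `tf.gather_nd`, here is an alternative sketch (not the approach above, just an equivalent formulation) that uses `torch.gather` instead of advanced indexing. It flattens the spatial dimensions into one axis and turns each `(x, y)` pair into the linear index `x * w + y`. `gather_points` is a hypothetical helper name, and it assumes, as in the snippet above, that `x` indexes the height dimension, `y` indexes the width dimension, and all indices are in range.

```python
import torch

def gather_points(img_feat: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: returns out[b, n] = img_feat[b, :, x[b, n], y[b, n]].

    img_feat: (B, C, H, W); x, y: (B, N) long tensors of row/column indices.
    Returns a (B, N, C) tensor, the same layout as the advanced-indexing version.
    """
    B, C, H, W = img_feat.shape
    flat = img_feat.flatten(2)                          # (B, C, H*W), row-major: index = h * W + w
    lin = x * W + y                                     # (B, N) linear index into the H*W axis
    idx = lin.unsqueeze(1).expand(B, C, lin.shape[1])   # (B, C, N), same index for every channel
    out = torch.gather(flat, 2, idx)                    # (B, C, N)
    return out.permute(0, 2, 1)                         # (B, N, C)

torch.manual_seed(0)
batch_size, c, h, w, nb_points = 2, 3, 5, 7, 4
img_feat = torch.randn(batch_size, c, h, w)
x = torch.randint(0, h, (batch_size, nb_points))
y = torch.randint(0, w, (batch_size, nb_points))
points = img_feat[torch.arange(batch_size)[:, None], :, x, y]
print(torch.equal(gather_points(img_feat, x, y), points))  # True
```

Both forms should pick out the same values; the advanced-indexing one-liner is shorter, while the `gather` version makes the flat-index arithmetic explicit, which may be easier to compare against `tf.gather_nd`-style code.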