How to do `tf.gather_nd` in PyTorch?

@ptrblck Thanks! I found your previous answer, which solved my problem.

The following code is used for validation (it compares the vectorised indexing against an explicit loop):

import torch

batch_size = 2
c, h, w = 256, 38, 65
nb_points = 784
nb_regions = 128


def gather_points(feat, rows, cols):
    """Gather per-point channel vectors from a batched feature map.

    Vectorised PyTorch analogue of ``tf.gather_nd`` for this layout.

    Args:
        feat: tensor of shape (B, C, H, W).
        rows: integer tensor of shape (B, *idx_shape) holding H indices.
        cols: integer tensor with the same shape as ``rows`` holding W indices.

    Returns:
        Tensor of shape (B, *idx_shape, C) with
        ``out[b, ..., :] == feat[b, :, rows[b, ...], cols[b, ...]]``.
    """
    batch = feat.size(0)
    # Build the batch index on the same device as the feature map.
    batch_idx = torch.arange(batch, device=feat.device)[:, None]
    # reshape (not view) so non-contiguous index tensors also work.
    flat_rows = rows.reshape(batch, -1)
    flat_cols = cols.reshape(batch, -1)
    # The advanced indices are separated by a slice, so the broadcast index
    # dims (B, N) move to the front and the result is (B, N, C), not (B, C, N).
    out = feat[batch_idx, :, flat_rows, flat_cols]
    # Spell out C explicitly: "-1" cannot be inferred when N == 0.
    return out.reshape(*rows.shape, feat.size(1))


def gather_points_loop(feat, rows, cols):
    """Reference version of ``gather_points`` using explicit Python loops.

    Only supports 3-D index tensors of shape (B, R, P). It is slow (one
    indexing op per point) but obviously correct, so it is used to validate
    the vectorised version.
    """
    out = feat.new_zeros(*rows.shape, feat.size(1))
    # Move the indices to Python ints once instead of reading each element
    # as a separate (possibly on-device) tensor.
    for b, (rows_b, cols_b) in enumerate(zip(rows.tolist(), cols.tolist())):
        for j, (rows_bj, cols_bj) in enumerate(zip(rows_b, cols_b)):
            for k, (r, col) in enumerate(zip(rows_bj, cols_bj)):
                out[b, j, k] = feat[b, :, r, col]
    return out


def main():
    """Compare both methods on random data and print whether they agree."""
    # Fall back to the CPU so the check also runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    img_feat = torch.randn(batch_size, c, h, w, device=device)
    index_shape = (batch_size, nb_regions, nb_points)
    x = torch.randint(h, index_shape, device=device)
    y = torch.randint(w, index_shape, device=device)

    # method 1: vectorised advanced indexing
    result_1 = gather_points(img_feat, x, y)
    # method 2: explicit loops (slow: ~200k iterations at these sizes)
    result_2 = gather_points_loop(img_feat, x, y)

    # torch.equal also checks that the shapes match.
    print(torch.equal(result_1, result_2))


if __name__ == "__main__":
    main()