I am trying to replicate TensorFlow's gather_nd() in PyTorch.
import torch

torch.manual_seed(0)
images = torch.rand((2, 2, 2, 3))
# each indices[i, j, k] is a full (d0, d1, d2) coordinate into images
indices = torch.randint(0, 2, size=(2, 2, 2, 3)).long()
These are the image values:
tensor([[[[0.4963, 0.7682, 0.0885],
          [0.1320, 0.3074, 0.6341]],
         [[0.4901, 0.8964, 0.4556],
          [0.6323, 0.3489, 0.4017]]],
        [[[0.0223, 0.1689, 0.2939],
          [0.5185, 0.6977, 0.8000]],
         [[0.1610, 0.2823, 0.6816],
          [0.9152, 0.3971, 0.8742]]]])
These are the indices:
tensor([[[[0, 1, 1],
          [1, 1, 0]],
         [[1, 0, 1],
          [0, 1, 1]]],
        [[[0, 1, 1],
          [0, 0, 1]],
         [[0, 1, 1],
          [1, 1, 1]]]])
And this is the result I want to obtain (for example, indices[0][0][0] is [0, 1, 1], so the first output vector is images[0][1][1] = [0.6323, 0.3489, 0.4017]):
tensor([[[[0.6323, 0.3489, 0.4017],
          [0.1610, 0.2823, 0.6816]],
         [[0.5185, 0.6977, 0.8000],
          [0.6323, 0.3489, 0.4017]]],
        [[[0.6323, 0.3489, 0.4017],
          [0.1320, 0.3074, 0.6341]],
         [[0.6323, 0.3489, 0.4017],
          [0.9152, 0.3971, 0.8742]]]])
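For context, in TensorFlow this whole mapping is a single call. A sketch, where images_tf and indices_tf are hypothetical tf.constant copies of the tensors above:

import tensorflow as tf

# gather_nd treats the last axis of indices as (d0, d1, d2) coordinates
# into images, so the output keeps the leading (2, 2, 2) index shape
# plus the trailing channel dimension of size 3
images_tf = tf.constant(images.numpy())
indices_tf = tf.constant(indices.numpy())
result = tf.gather_nd(images_tf, indices_tf)  # shape (2, 2, 2, 3)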
Currently, I am using these 3 nested for-loops:
new_images = torch.zeros_like(images)
for i, batch in enumerate(indices):
    for j, dim in enumerate(batch):
        for k, row in enumerate(dim):
            # row holds the full (d0, d1, d2) coordinate into images
            new_images[i][j][k] = images[row[0]][row[1]][row[2]]
but this is really slow.
Is there a way I could use index_select to achieve the same result?
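For reference, this is the closest vectorized sketch I could come up with using plain advanced indexing (not index_select; I am assuming the last dimension of indices holds full (d0, d1, d2) coordinates, exactly as in the loop above), though I have not verified that it matches gather_nd in general:

# unpack the last dimension of indices into three coordinate tensors,
# each of shape (2, 2, 2), and index images with them in one shot
new_images = images[indices[..., 0], indices[..., 1], indices[..., 2]]

If that interpretation is right, this should reproduce new_images[i][j][k] = images[row[0]][row[1]][row[2]] for every (i, j, k) without any Python-level loops.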