Use argmax indices to select values from a tensor

I have two tensors, A and B, both of shape 2 x 10 x 5 x 2.

I can get the max values and their indices with values, indices = A.max(dim=-2)
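For reference, here is a quick sketch of what A.max(dim=-2) returns for a tensor of this shape. The tensor contents are random stand-ins (an assumption); only the shape comes from the question. The key point is that the reduced axis disappears unless keepdim=True is passed.

```python
import torch

# Random stand-in for A with the shape from the question: 2 x 10 x 5 x 2
A = torch.randn(2, 10, 5, 2)

# Reducing over dim -2 (the size-5 axis) drops that axis from both outputs
values, indices = A.max(dim=-2)
print(values.shape)   # torch.Size([2, 10, 2])
print(indices.shape)  # torch.Size([2, 10, 2])

# With keepdim=True the reduced axis is kept with size 1, so the index
# tensor has the same number of dimensions as A (what torch.gather expects)
values_k, indices_k = A.max(dim=-2, keepdim=True)
print(indices_k.shape)  # torch.Size([2, 10, 1, 2])
```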

How do I use the indices to select the corresponding values in tensor B?

I tried it with a smaller tensor, but it should work for yours as well:

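```python
import torch

input1 = torch.randn(2, 2, 3)
input2 = torch.randn(2, 2, 3)
print(input2)

# keepdim=True keeps the reduced dim (size 1), so indices1 has shape (2, 2, 1)
values1, indices1 = input1.max(dim=-1, keepdim=True)
print('\n', indices1, '\n')

# For each position, take the element of input2 at input1's argmax along dim -1
print(torch.gather(input2, -1, indices1))
```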
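Applied to the shapes in the original question, the same torch.gather approach would look roughly like the sketch below. A and B are random stand-ins here (an assumption); the reduction is over dim -2 as in the question. If you already have indices from a call without keepdim=True, indices.unsqueeze(-2) gives the same index shape. Newer PyTorch releases also provide torch.take_along_dim(B, indices, dim=-2), which does the same selection.

```python
import torch

# Random stand-ins for A and B, both of shape 2 x 10 x 5 x 2
A = torch.randn(2, 10, 5, 2)
B = torch.randn(2, 10, 5, 2)

# keepdim=True gives indices of shape (2, 10, 1, 2), matching A and B in
# every dimension except the one being indexed
values, indices = A.max(dim=-2, keepdim=True)

# For every position, pick the element of B at A's argmax along dim -2
selected = torch.gather(B, -2, indices)  # shape (2, 10, 1, 2)
selected = selected.squeeze(-2)          # shape (2, 10, 2)

# Sanity check: gathering from A itself recovers A's max values
assert torch.equal(torch.gather(A, -2, indices), values)
```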

Legend, thank you very much, Karthik.
