Use argmax indices to select values from a tensor

I have two tensors, A and B, both with shape 2 x 10 x 5 x 2.

I can get the max values and their indices with values, indices = A.max(dim=-2)

How do I use the indices to select the corresponding values in tensor B?

I tried it with a smaller tensor, but it should work for you too:

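import torch

input1 = torch.randn(2, 2, 3)
input2 = torch.randn(2, 2, 3)
print(input2)

# keepdim=True keeps the reduced dimension (as size 1), so indices1 has the
# same number of dimensions as input2, which torch.gather requires
values1, indices1 = input1.max(dim=-1, keepdim=True)
print('\n', indices1, '\n')

# Take the entries of input2 at the positions where input1 is largest
print(torch.gather(input2, -1, indices1))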
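For the shapes in the original question (2 x 10 x 5 x 2, max over dim=-2), the same pattern might look like the sketch below. The random tensors and the loop-based check are only there for illustration: the loop confirms that gather picks B[i, j, indices[i, j, 0, k], k] for every i, j, k.

```python
import torch

torch.manual_seed(0)

# Shapes from the question: A and B are both 2 x 10 x 5 x 2
A = torch.randn(2, 10, 5, 2)
B = torch.randn(2, 10, 5, 2)

# keepdim=True keeps dim -2 as size 1, so indices has the same rank as B
values, indices = A.max(dim=-2, keepdim=True)   # both 2 x 10 x 1 x 2

# Pick B's entries at the positions where A is maximal along dim -2
selected = torch.gather(B, -2, indices)          # 2 x 10 x 1 x 2
selected = selected.squeeze(-2)                  # 2 x 10 x 2

# Sanity check against an explicit loop over the non-reduced dims
expected = torch.empty(2, 10, 2)
for i in range(2):
    for j in range(10):
        for k in range(2):
            expected[i, j, k] = B[i, j, indices[i, j, 0, k].item(), k]

print(torch.equal(selected, expected))  # True
```

Without keepdim=True, indices would be 2 x 10 x 2 and gather would reject it because its rank differs from B's. On PyTorch 1.9 or newer, torch.take_along_dim(B, indices, dim=-2) should give the same result as the gather call.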

Legend, thank you very much Karthik.

I am having a similar problem. Rather than pulling the values out of A, I want to edit the values at the locations given by the indices. Any suggestions?
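One possible approach is Tensor.scatter_, which is roughly the write counterpart of gather: it writes into the positions given by the same indices tensor along the same dim. The sketch below reuses the shapes from the original question and assumes "edit" means either setting those entries to a constant or computing new values from the old ones; the constant 0.0 and the doubling are placeholders for whatever edit you actually need.

```python
import torch

torch.manual_seed(0)

A = torch.randn(2, 10, 5, 2)
B = torch.randn(2, 10, 5, 2)

# keepdim=True so the index tensor has the same rank as B (scatter_ needs this)
_, indices = A.max(dim=-2, keepdim=True)   # 2 x 10 x 1 x 2

# Case 1: overwrite B at A's argmax positions with a constant
B_const = B.clone()                        # clone so B itself is left untouched
B_const.scatter_(-2, indices, 0.0)

# Case 2: read the old values, compute new ones, and write them back
# (doubling is just a placeholder edit)
old = torch.gather(B, -2, indices)         # 2 x 10 x 1 x 2
B_edit = B.clone()
B_edit.scatter_(-2, indices, old * 2)
```

scatter_ works in place, which is why the sketch clones B first; torch.scatter(B, -2, indices, src) is the out-of-place version if you would rather get a new tensor back.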