I have two tensors, `A` and `B`, both with shape 2 x 10 x 5 x 2. I can get the max values and their indices with `values, indices = A.max(dim=-2)`. How do I use `indices` to select the corresponding values in tensor `B`?
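For reference, here is a quick sketch of the shapes involved, using random data as a stand-in for the real `A` and `B`. By default, `max(dim=-2)` removes the reduced dimension, and `keepdim=True` keeps it with size 1. This matters whenever the indices are later used to index another tensor of the original rank.

```python
import torch

# Random stand-ins for the A and B from the question
A = torch.randn(2, 10, 5, 2)
B = torch.randn(2, 10, 5, 2)

# Reducing over dim -2 (size 5) drops that dimension by default
values, indices = A.max(dim=-2)
print(values.shape, indices.shape)  # torch.Size([2, 10, 2]) for both

# With keepdim=True the reduced dimension is kept with size 1
values_k, indices_k = A.max(dim=-2, keepdim=True)
print(indices_k.shape)  # torch.Size([2, 10, 1, 2])
```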


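I tried it with a smaller tensor, but it should work for you:

```python
import torch

input1 = torch.randn(2, 2, 3)
input2 = torch.randn(2, 2, 3)
print(input2)

# keepdim=True keeps the reduced dim (shape (2, 2, 1)),
# so the indices can be passed straight to gather
values1, indices1 = input1.max(dim=-1, keepdim=True)
print('\n', indices1, '\n')

# Pick the elements of input2 at the positions of input1's maxima
print(torch.gather(input2, -1, indices1))
```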
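The next sketch applies the same idea to the shapes in the original question, again with random stand-ins for `A` and `B`. Because the max is taken over `dim=-2`, `gather` has to use that same dimension, and the indices need the reduced dimension restored. This version uses `unsqueeze(-2)`, though passing `keepdim=True` to `max` works just as well. The sketch also shows `torch.take_along_dim` (available in PyTorch 1.9 and later) as an equivalent alternative.

```python
import torch

# Random stand-ins for the A and B from the question
A = torch.randn(2, 10, 5, 2)
B = torch.randn(2, 10, 5, 2)

values, indices = A.max(dim=-2)  # both (2, 10, 2)

# gather needs an index tensor with the same number of dims as B:
# reinsert the reduced dim, gather along it, then drop it again
selected = torch.gather(B, -2, indices.unsqueeze(-2)).squeeze(-2)  # (2, 10, 2)

# Equivalent alternative (PyTorch >= 1.9)
selected_alt = torch.take_along_dim(B, indices.unsqueeze(-2), dim=-2).squeeze(-2)

print(selected.shape)
print(torch.equal(selected, selected_alt))
```

Each `selected[i, j, k]` is `B[i, j, indices[i, j, k], k]`, which is the element of `B` at the position where `A` has its maximum along dim -2.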


Legend, thank you very much Karthik.


I am having a similar problem. Rather than pulling the values from `A`, I want to edit the values at the indices' locations. Any suggestions?
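One possible approach, which this thread does not confirm, is `scatter`/`scatter_`, the write counterpart of `gather`. It writes values into a tensor at the positions given by `index` along `dim`. The sketch below uses random stand-ins for `A` and `B`, and the replacement values (zeros, and an arbitrary `100.0`) are placeholders. To edit `A` itself, use `A` in place of `B`.

```python
import torch

# Random stand-ins for the A and B from the question
A = torch.randn(2, 10, 5, 2)
B = torch.randn(2, 10, 5, 2)

values, indices = A.max(dim=-2)
idx = indices.unsqueeze(-2)  # (2, 10, 1, 2), same ndim as B

# Out-of-place: a copy of B with the argmax positions overwritten
# by a tensor of new values shaped like the index (placeholder zeros)
new_vals = torch.zeros_like(idx, dtype=B.dtype)
B_edited = B.scatter(-2, idx, new_vals)

# In-place: write a single scalar (placeholder 100.0) at those positions
B.scatter_(-2, idx, 100.0)
```

`scatter` returns a new tensor, while the trailing-underscore `scatter_` modifies the tensor in place. Choose based on whether the original values are still needed.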