Apply the index of the maximum element in one tensor to another tensor

I want to apply the indices returned by torch.max to another tensor.

For example,

import torch as tr

a = tr.FloatTensor([[4, 1], [3, 10]])
b = tr.FloatTensor([[1, 2], [3, 4]])

_, idx_a = tr.max(a, 1) # [0, 1]

b[idx_a] # expected result is [1, 4]
         # actual result is [[1, 2], [3, 4]]
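The "actual result" happens because a single 1-D index tensor only indexes the first dimension, so `b[idx_a]` selects whole rows. A minimal sketch that reproduces this with the same values, using `torch.index_select` only as a cross-check:

```python
import torch

a = torch.tensor([[4., 1.], [3., 10.]])
b = torch.tensor([[1., 2.], [3., 4.]])

# Column index of the max in each row of a
_, idx_a = torch.max(a, dim=1)  # tensor([0, 1])

# One index tensor -> indexes dim 0 only, so rows 0 and 1 are returned whole
rows = b[idx_a]

# Equivalent to selecting those rows along dim 0
assert torch.equal(rows, torch.index_select(b, 0, idx_a))
print(rows)  # tensor([[1., 2.], [3., 4.]])
```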

I can do this with a bit of dirty code, like:

b = tr.FloatTensor([b[i, idx].item() for i, idx in enumerate(idx_a)])  # b[i, idx_a[i]] for each row i

but I'd like to use a built-in PyTorch function or method instead.

Any elegant suggestions?

BTW, I'm using PyTorch 2.0.

I just figured it out myself.
The answer is:

b = b[tr.arange(b.size(0)), idx_a]  # pairs row i with column idx_a[i] -> [1, 4]
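The same per-row selection can also be written with `torch.gather`. A minimal sketch comparing the two, assuming a PyTorch version where `torch.max(x, dim=...)` returns `(values, indices)`:

```python
import torch

a = torch.tensor([[4., 1.], [3., 10.]])
b = torch.tensor([[1., 2.], [3., 4.]])

# Column index of the max in each row of a: tensor([0, 1])
_, idx_a = torch.max(a, dim=1)

# Option 1: advanced indexing, pairing row i with column idx_a[i]
row_ids = torch.arange(b.size(0))
picked_indexing = b[row_ids, idx_a]

# Option 2: gather along dim 1; the index must have as many dims as b,
# so add a trailing dim of size 1 and squeeze it away afterwards
picked_gather = b.gather(1, idx_a.unsqueeze(1)).squeeze(1)

print(picked_indexing)  # tensor([1., 4.])
print(picked_gather)    # tensor([1., 4.])
```

With `gather`, the `unsqueeze(1)` step can be skipped by calling `torch.max(a, dim=1, keepdim=True)`, whose indices already have shape `(2, 1)`.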