I want to use the indices returned by `torch.max` to index another tensor.
```python
import torch as tr

a = tr.FloatTensor([[4, 1], [3, 10]])
b = tr.FloatTensor([[1, 2], [3, 4]])
_, idx_a = tr.max(a, 1)  # [0, 1]
b[idx_a]
# expected result: [1, 4]
# actual result:   [[1, 2], [3, 4]]
```
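The "actual result" follows from how PyTorch indexing works: a 1-D integer tensor used on its own indexes dimension 0, so `b[idx_a]` picks whole rows of `b`, not one element per row. The sketch below (using `torch` directly, with `idx_a` hard-coded to the `[0, 1]` shown above) illustrates that this matches `torch.index_select` along dim 0.

```python
import torch

b = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
idx_a = torch.tensor([0, 1])  # hard-coded here; the same as torch.max(a, 1) returns above

# A lone 1-D index tensor indexes dim 0, so this selects rows 0 and 1 of b.
rows = b[idx_a]

# The same row selection written explicitly.
same_rows = torch.index_select(b, 0, idx_a)
```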
I can do this with a bit of dirty code, like:

```python
b = tr.stack([b[i, idx] for i, idx in enumerate(idx_a)])
```
but I would rather use a built-in PyTorch function or method. Any elegant suggestions?
BTW, I use PyTorch 2.0.
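A minimal sketch of two built-in ways to pick `b[i, idx_a[i]]` for every row `i`, assuming the same `a` and `b` as above (written with `torch.tensor` instead of the legacy `FloatTensor` constructor). `torch.gather` along dim 1 needs an index tensor with as many dimensions as `b`, hence the `unsqueeze`/`squeeze`; the second option pairs each column index with an explicit row index from `torch.arange`.

```python
import torch

a = torch.tensor([[4.0, 1.0], [3.0, 10.0]])
b = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

# Column index of the max in each row of a: tensor([0, 1])
_, idx_a = torch.max(a, dim=1)

# Option 1: gather b[i, idx_a[i]] along dim 1.
# The index must be 2-D like b, so add a trailing dim, then drop it again.
picked_gather = torch.gather(b, 1, idx_a.unsqueeze(1)).squeeze(1)

# Option 2: advanced indexing with an explicit row index for each column index.
rows = torch.arange(b.size(0))
picked_index = b[rows, idx_a]

print(picked_gather)  # tensor([1., 4.])
print(picked_index)   # tensor([1., 4.])
```

Passing `keepdim=True` to `torch.max` would return the indices with shape `(2, 1)` directly, which avoids the `unsqueeze` before `gather`.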