I want to apply the indices returned by torch.max to another tensor.
For example,
import torch as tr
a = tr.FloatTensor([[4, 1], [3, 10]])
b = tr.FloatTensor([[1, 2], [3, 4]])
_, idx_a = tr.max(a, 1)  # [0, 1]
b[idx_a]  # expected result is [1, 4]
          # actual result is [[1, 2], [3, 4]]
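Why `b[idx_a]` returns whole rows: a single integer tensor used as an index only indexes the first dimension, so each entry of `idx_a` selects an entire row of `b`. Pairing `idx_a` with a tensor of row numbers selects one element per row instead. This is a minimal sketch, assuming `import torch as tr` as above; the tensors are built with `tr.tensor` rather than the legacy `FloatTensor` constructor, which should make no difference here.

```python
import torch as tr

a = tr.tensor([[4., 1.], [3., 10.]])
b = tr.tensor([[1., 2.], [3., 4.]])

# Column index of the maximum in each row of a
_, idx_a = tr.max(a, 1)            # tensor([0, 1])

# Indexing only dim 0: each entry of idx_a selects a whole row of b
rows = b[idx_a]                    # tensor([[1., 2.], [3., 4.]])

# Pairing row numbers with idx_a selects b[0, 0] and b[1, 1]
picked = b[tr.arange(b.size(0)), idx_a]
print(picked)                      # tensor([1., 4.])
```

`tr.arange(b.size(0))` supplies one row number per entry of `idx_a`, so the two index tensors are read pairwise as (row, column) coordinates.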
I can do this with a little dirty code, like:
b = tr.FloatTensor([b[i, idx].item() for i, idx in enumerate(idx_a)])
but I would prefer to use a built-in PyTorch function or method.
Any elegant suggestions?
BTW, I am using PyTorch 2.0.
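`torch.gather` is a built-in that picks one element per row along a chosen dimension, which matches the "one index per row" shape of the `torch.max` output. A minimal sketch, assuming the same `a` and `b` as in the question; passing `keepdim=True` to `tr.max` keeps the indices with shape (2, 1), the layout `gather` expects for `dim=1`.

```python
import torch as tr

a = tr.tensor([[4., 1.], [3., 10.]])
b = tr.tensor([[1., 2.], [3., 4.]])

# keepdim=True keeps idx_a with shape (2, 1): tensor([[0], [1]])
_, idx_a = tr.max(a, 1, keepdim=True)

# gather along dim 1: out[i][0] = b[i][idx_a[i][0]]
picked = tr.gather(b, 1, idx_a).squeeze(1)
print(picked)   # tensor([1., 4.])
```

The final `squeeze(1)` drops the kept dimension to get a flat result; `tr.take_along_dim(b, idx_a, dim=1)` should behave like `gather` here in recent PyTorch versions.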