How to find input corresponding to argmin of output?

I have a basic scenario like this:

import torch

def someFn(w):
    return w.sum(dim=1)

w = torch.rand(10, 3, 2)
fW = someFn(w)                  # shape (10, 2)

fMinW, fIndex = fW.min(dim=1)   # both of shape (10,)

# Select wmin of shape (10, 3)

Now I want to select the 10 (1x3) vectors of w that correspond to the fMinW values. How can I achieve this?
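A quick shape check of the code above shows which axis fIndex refers to. `someFn` removes dim 1 of w, so the last axis of fW is the last axis (dim 2) of w, and fIndex[i] picks one of those 2 entries for row i. This is a minimal sketch; `rows` is a helper name introduced here for illustration.

```python
import torch

def someFn(w):
    # Summing over dim 1 collapses the size-3 axis: (10, 3, 2) -> (10, 2)
    return w.sum(dim=1)

w = torch.rand(10, 3, 2)
fW = someFn(w)
fMinW, fIndex = fW.min(dim=1)  # both have shape (10,)

print(tuple(w.shape), tuple(fW.shape), tuple(fIndex.shape))
# (10, 3, 2) (10, 2) (10,)

# fIndex[i] indexes the last axis of fW[i], which is also dim 2 of w[i]
rows = torch.arange(fW.size(0))  # hypothetical helper: row numbers 0..9
assert torch.equal(fW[rows, fIndex], fMinW)
```

So the wanted vectors are w[i, :, fIndex[i]] for each i.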

Assuming the desired output can be created with this loop:

res2 = []
for idx in range(w.size(0)):
    res2.append(w[idx, :, fIndex[idx]])
res2 = torch.stack(res2)

then this single advanced-indexing expression should produce the same result without the loop:

res1 = w[torch.arange(w.size(0)), :, fIndex]

# compare
print((res2 == res1).all())
> tensor(True)
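The indexing expression works because torch.arange(w.size(0)) and fIndex are both index tensors of shape (10,), so they are paired elementwise: row i takes entry fIndex[i] of the last axis, while `:` keeps all 3 entries of dim 1, giving a (10, 3) result. As an alternative, `torch.gather` can do the same selection; the sketch below is one way to set it up, assuming the same shapes as in the question. `gather` needs an index tensor with as many dims as w, so fIndex is reshaped and expanded first.

```python
import torch

w = torch.rand(10, 3, 2)
fMinW, fIndex = w.sum(dim=1).min(dim=1)

# Advanced indexing, as in the answer
res1 = w[torch.arange(w.size(0)), :, fIndex]

# gather along dim 2: reshape fIndex to (10, 1, 1),
# then expand it across dim 1 to (10, 3, 1)
index = fIndex.view(-1, 1, 1).expand(-1, w.size(1), 1)
res3 = w.gather(2, index).squeeze(2)  # (10, 3, 1) -> (10, 3)

print(res3.shape)              # torch.Size([10, 3])
print(torch.equal(res1, res3)) # True
```

The indexing version is shorter; `gather` can be handy when the index tensor already has the same layout as the input.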