samarendra109
(samarendra chandan bindu Dash)
I have a basic scenario like this:
```python
import torch

def someFn(w):
    return w.sum(dim=1)

w = torch.rand(10,3,2)
fW = someFn(w)
fMinW, fIndex = fW.min(dim=1)
# Select wmin of size 10,3
```
Now I want to select the 10 (1x3) vectors of w that correspond to the fMinW output. How can I achieve this?
Assuming the desired output can be created via:
```python
res2 = []
for idx in range(w.size(0)):
    res2.append(w[idx, :, fIndex[idx]])
res2 = torch.stack(res2)
```
then this should work:
```python
res1 = w[torch.arange(w.size(0)), :, fIndex]
# compare
print((res2 == res1).all())
> tensor(True)
```
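A minimal sketch for comparison, assuming the same shapes as in the question. It shows why the `torch.arange` index is needed and gives an equivalent formulation with `torch.gather`. The seed, the names `wrong`, `idx` and `res3`, and the `gather` variant itself are illustrative additions, not part of the original answer.

```python
import torch

def someFn(w):
    return w.sum(dim=1)

torch.manual_seed(0)  # seed only to make the run reproducible
w = torch.rand(10, 3, 2)
fW = someFn(w)                    # shape (10, 2)
fMinW, fIndex = fW.min(dim=1)     # both shape (10,); fIndex is int64

# Advanced indexing: row i of dim 0 is paired with fIndex[i] in dim 2.
res1 = w[torch.arange(w.size(0)), :, fIndex]    # shape (10, 3)

# Without the arange, every index in fIndex is applied to every row,
# producing all 10x10 combinations instead of 10 matched pairs.
wrong = w[:, :, fIndex]                          # shape (10, 3, 10)

# Alternative: torch.gather along dim 2. The index must have as many
# dims as w, so fIndex is reshaped to (10, 1, 1) and expanded to (10, 3, 1).
idx = fIndex.view(-1, 1, 1).expand(-1, w.size(1), 1)
res3 = w.gather(2, idx).squeeze(2)               # shape (10, 3)

print(res1.shape, wrong.shape, res3.shape)
print(torch.equal(res1, res3))
```

The indexing version is usually the more readable of the two; `gather` may be more convenient when the index tensor already has the same number of dimensions as the input.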