As a continuation of this question,
I have a scenario like this:
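import torch

def someFn(w):
    # Sum over dim 1: (10, 3, 6) -> (10, 6)
    return w.sum(dim=1)

m = 2
w = torch.rand(10, 3, 6)
fW = someFn(w)                  # shape (10, 6)
fIndex = fW.argsort(dim=-1)     # shape (10, 6), ascending order
# Goal: select wmin of shape (10, 3, 2)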
For each of the 10 batch entries, I want to select the m = 2 slices of w along the last dimension that correspond to the two smallest values of fW.
I tried:
wmin = w[torch.arange(w.size(0)), :, fIndex[:, :m]]
But it throws an error.
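
For reference, the two index tensors in the attempt above have shapes (10,) and (10, 2), which cannot be broadcast against each other. Below is a minimal sketch of one possible way to get the intended (10, 3, 2) result, assuming `torch.gather` along the last dimension is acceptable. The name `idx` is introduced here only for illustration and is not part of the original code.

```python
import torch

def someFn(w):
    return w.sum(dim=1)

m = 2
w = torch.rand(10, 3, 6)
fIndex = someFn(w).argsort(dim=-1)      # (10, 6)

# Take the m smallest positions per batch row, then repeat them across
# dim 1 so the index has the same rank as w: (10, m) -> (10, 1, m) -> (10, 3, m)
idx = fIndex[:, :m].unsqueeze(1).expand(-1, w.size(1), -1)

# Pick the selected positions along the last dimension of w
wmin = torch.gather(w, 2, idx)          # (10, 3, m)
print(wmin.shape)
```

With this, `wmin[b, :, k]` is the column `w[b, :, fIndex[b, k]]`, so the columns come out ordered from smallest to second-smallest sum.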