As a continuation of this question, I have a scenario like this:
```python
import torch

def someFn(w):
    return w.sum(dim=1)

w = torch.rand(10, 3, 6)
m = 2
fW = someFn(w)
fIndex = fW.argsort(dim=-1)
# Select wmin of size 10,3,2
```
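A quick shape check on this setup helps pin down which dimension the selection has to index. `someFn` sums over dim 1 (size 3), so `fW` and `fIndex` both have shape (10, 6). In other words, `fIndex` ranks, for each of the 10 batch entries, the 6 positions along the *last* dimension of `w`.

```python
import torch

def someFn(w):
    # Summing over dim 1 (size 3) leaves shape (10, 6)
    return w.sum(dim=1)

w = torch.rand(10, 3, 6)
fW = someFn(w)
fIndex = fW.argsort(dim=-1)  # ascending order along the last dim, per batch row

print(fW.shape)      # torch.Size([10, 6])
print(fIndex.shape)  # torch.Size([10, 6])
```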
I want to select the m = 2 slices of `w` (along its last dimension) that correspond to the two smallest values of the output `fW`:
```python
wmin = w[torch.arange(w.size(0)), :, fIndex[:, :m]]
```
But it throws an error.
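The failure comes from broadcasting: PyTorch broadcasts all index tensors together, and `torch.arange(w.size(0))` has shape (10,) while `fIndex[:, :m]` has shape (10, 2). Their trailing sizes (10 and 2) differ, so they cannot broadcast. Below is a sketch of two possible fixes (not necessarily the only ones). The first uses `torch.gather` along dim 2 with the index expanded over dim 1. The second keeps advanced indexing but reshapes the row index to (10, 1). Because a slice `:` sits between the two index tensors, the indexed dimensions move to the front, giving (10, 2, 3), so a `permute` is needed. A plain loop is included only as a reference to check both results.

```python
import torch

def someFn(w):
    return w.sum(dim=1)

torch.manual_seed(0)
w = torch.rand(10, 3, 6)
m = 2
fW = someFn(w)               # (10, 6)
fIndex = fW.argsort(dim=-1)  # (10, 6), ascending
top = fIndex[:, :m]          # (10, m): positions of the m smallest fW per row

# Option 1: gather along dim 2; the index needs as many dims as w,
# so expand it across dim 1 (size 3).
idx = top.unsqueeze(1).expand(-1, w.size(1), -1)  # (10, 3, m)
wmin_gather = torch.gather(w, 2, idx)             # (10, 3, m)

# Option 2: advanced indexing with a row index that broadcasts with top.
rows = torch.arange(w.size(0)).unsqueeze(1)       # (10, 1) broadcasts with (10, m)
wmin_adv = w[rows, :, top]                        # (10, m, 3): indexed dims come first
wmin_adv = wmin_adv.permute(0, 2, 1)              # (10, 3, m)

# Reference loop, only for checking the vectorized versions.
wmin_loop = torch.stack([w[i][:, top[i]] for i in range(w.size(0))])  # (10, 3, m)

print(wmin_gather.shape)  # torch.Size([10, 3, 2])
print(torch.equal(wmin_gather, wmin_adv), torch.equal(wmin_gather, wmin_loop))
```

Barring ties, `fW.topk(m, dim=-1, largest=False).indices` should give the same `top` without sorting all 6 entries.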