How to find the inputs corresponding to the m smallest outputs?

As a continuation of this question, I have the following scenario:

import torch

def someFn(w):
    return w.sum(dim=1)

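w = torch.rand(10, 3, 6)     # shape (10, 3, 6)
fW = someFn(w)               # shape (10, 6): w summed over dim 1

fIndex = fW.argsort(dim=-1)  # shape (10, 6): indices that sort each row of fW in ascending order

m = 2
# Goal: select wmin of shape (10, 3, m) = (10, 3, 2)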

I want to select the m = 2 slices of w (along its last dimension) that correspond to the two smallest outputs in each row of fW.
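To pin down exactly what "select" means here, below is a minimal explicit-loop sketch of the intended result, assuming the shapes above. It is slow but unambiguous; `select_smallest_loop` is a hypothetical helper name used only for illustration.

```python
import torch

def select_smallest_loop(w, fW, m):
    # Hypothetical reference helper. w: (B, C, N), fW: (B, N); returns (B, C, m)
    idx = fW.argsort(dim=-1)[:, :m]              # (B, m): positions of the m smallest fW per row
    out = w.new_empty(w.size(0), w.size(1), m)
    for i in range(w.size(0)):
        # For batch i, keep the columns of w[i] whose fW values are smallest,
        # in ascending order of fW
        out[i] = w[i][:, idx[i]]                 # (C, m)
    return out

torch.manual_seed(0)
w = torch.rand(10, 3, 6)
fW = w.sum(dim=1)
wmin = select_smallest_loop(w, fW, 2)
print(wmin.shape)  # torch.Size([10, 3, 2])
```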

I tried

wmin = w[torch.arange(w.size(0)), :, fIndex[:, :m]]

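But it throws an error: the two index tensors, torch.arange(w.size(0)) and fIndex[:, :m], have shapes (10,) and (10, m), which cannot be broadcast together.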
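One possible way to get the vectorized result, sketched under the shapes above (not necessarily the only or canonical fix). For advanced indexing, the row index needs a trailing dimension so it broadcasts against `fIndex[:, :m]`; because the two advanced indices are separated by a slice, PyTorch (like NumPy) places the broadcast index dimensions first, so the result comes out as (10, m, 3) and has to be permuted. `torch.gather` along the last dimension gives (10, 3, m) directly. The names `top`, `rows`, `wmin_idx` and `gather_idx` are illustrative.

```python
import torch

torch.manual_seed(0)

def someFn(w):
    return w.sum(dim=1)

w = torch.rand(10, 3, 6)
fW = someFn(w)                     # (10, 6)
fIndex = fW.argsort(dim=-1)        # (10, 6)
m = 2
top = fIndex[:, :m]                # (10, m): indices of the m smallest outputs per row

# Option 1: advanced indexing with a broadcastable row index.
rows = torch.arange(w.size(0))[:, None]        # (10, 1) broadcasts with (10, m)
# Advanced indices separated by a slice -> broadcast dims go first: (10, m, 3)
wmin_idx = w[rows, :, top].permute(0, 2, 1)    # -> (10, 3, m)

# Option 2: torch.gather along the last dim, with the index expanded over dim 1.
gather_idx = top.unsqueeze(1).expand(-1, w.size(1), -1)  # (10, 3, m)
wmin = torch.gather(w, 2, gather_idx)                    # (10, 3, m)

print(wmin.shape)                   # torch.Size([10, 3, 2])
print(torch.equal(wmin, wmin_idx))  # True
```

The gather variant avoids the permute and makes the output layout explicit in the index shape.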