I would like to take all nonzero elements along the last dimension of a tensor and pack them into a different, narrower tensor. I know an upper bound on the number of nonzero elements in each row (3 in the example below), and rows with fewer nonzeros should be padded with zeros. For example:

Input = torch.tensor([[0, 2, 0, 0, 0],
                      [1, 2, 0, 0, 3],
                      [0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 2]])

Output = torch.tensor([[0, 2, 0],
                       [1, 2, 3],
                       [0, 0, 0],
                       [0, 0, 2]])

The order of the elements along the last dim is unimportant.
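To pin down the intended semantics, here is a minimal, deliberately slow per-row reference. The function name `compact_nonzero_reference` and the parameter `k` (the known upper bound) are hypothetical, and it assumes a 2-D input where no row has more than `k` nonzeros. It is only useful as a correctness check for faster versions.

```python
import torch

def compact_nonzero_reference(x: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: left-pack each row's nonzeros into a zero-padded row of length k.

    Assumes x is 2-D and no row has more than k nonzero elements.
    """
    out = torch.zeros(x.shape[0], k, dtype=x.dtype, device=x.device)
    for i, row in enumerate(x):
        nz = row[row != 0]          # nonzero values of this row, in original order
        out[i, : nz.numel()] = nz   # copy them to the front; the rest stay zero
    return out
```

On the example above this gives `[[2, 0, 0], [1, 2, 3], [0, 0, 0], [2, 0, 0]]`, which matches the desired output up to the (unimportant) order within each row.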

Currently, I do it using sort():

Output = Input.sort(descending=True, dim=-1)[0][:,:3]
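Note that the descending sort works here only because every value in the example is nonnegative; a negative entry would be sorted behind the zeros and cut off by `[:, :3]`. A minimal sketch of a variant that avoids this, assuming at most `k` nonzeros per row, replaces the full sort with a partial selection via `torch.topk` on a 0/1 "is nonzero" score and then gathers the original values. The name `compact_nonzero_topk` is hypothetical, and whether `topk` beats `sort` in practice depends on the tensor size and device; this has not been benchmarked.

```python
import torch

def compact_nonzero_topk(x: torch.Tensor, k: int) -> torch.Tensor:
    # Score each entry 1.0 if nonzero, else 0.0. If a row has at most k nonzeros,
    # all of them are among the k highest-scoring positions (the rest are zeros).
    score = (x != 0).float()
    idx = score.topk(k, dim=-1).indices
    # Gather the original, signed values at the selected positions.
    return x.gather(-1, idx)
```

As with the sort-based version, the order of values within each output row is not guaranteed to follow the input order.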

I suspect there is a more time-efficient way than sorting every row. Is there?
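One sort-free approach, sketched here under the assumption that no row has more than `k` nonzeros, is a single linear pass: a cumulative sum over the nonzero mask gives each nonzero element its target column, and `scatter_add_` writes the values there. Zeros are routed to an extra dump column that is sliced off at the end. The name `compact_nonzero_scatter` is hypothetical, and while this avoids the O(n log n) sort, actual speed has not been measured and may depend on the device and tensor shape.

```python
import torch

def compact_nonzero_scatter(x: torch.Tensor, k: int) -> torch.Tensor:
    # Assumes every row of x has at most k nonzero elements.
    mask = x != 0
    # Running count of nonzeros: the j-th nonzero in a row gets target column j - 1.
    pos = mask.cumsum(dim=-1) - 1
    # Route zeros to an extra dump column k so they never land on a real slot.
    pos = torch.where(mask, pos, torch.full_like(pos, k))
    out = torch.zeros(*x.shape[:-1], k + 1, dtype=x.dtype, device=x.device)
    # Each real target column receives exactly one value; the dump column only sums zeros.
    out.scatter_add_(-1, pos, x)
    return out[..., :k]
```

Unlike the sort and `topk` variants, this keeps the nonzeros in their original left-to-right order and handles negative values; on the example it returns `[[2, 0, 0], [1, 2, 3], [0, 0, 0], [2, 0, 0]]`.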