Efficiently get a tensor's nonzero elements

I would like to collect all nonzero elements along the last dim of a tensor and place them in a new tensor. I know an upper bound on the number of nonzero elements in each row, so the output can have that fixed width, with zeros filling any unused slots. For example:
Input = torch.tensor([[0,2,0,0,0],

Output = torch.tensor([[0, 2, 0],
[1, 2, 3],
[0, 0, 0],
[0, 0, 2]])
The order of the elements along the last dim is unimportant.

Currently, I do it using sort():
Output = Input.sort(descending=True, dim=-1)[0][:,:3]
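For comparison, here is a minimal sketch of one alternative I could imagine: take topk() over a nonzero mask and gather() the values at those indices. I have not benchmarked it, so it may or may not be faster than sort(). The helper name nonzero_padded and the sample tensor are made up for illustration. It assumes k is at least the largest nonzero count in any row and at most the row length. Unlike the descending sort, it should also keep negative nonzero values.

```python
import torch

def nonzero_padded(x: torch.Tensor, k: int) -> torch.Tensor:
    # Hypothetical helper, not part of PyTorch.
    # 1.0 where x is nonzero, 0.0 elsewhere.
    mask = (x != 0).float()
    # Positions of the k largest mask entries per row: nonzero
    # positions come first, zero positions fill the remainder.
    idx = mask.topk(k, dim=-1).indices
    # Pick the original values at those positions.
    return x.gather(-1, idx)

# Made-up sample input (not the tensor from the question).
x = torch.tensor([[0, 2, 0, 0, 0],
                  [1, 0, 2, 3, 0],
                  [0, 0, 0, 0, 0],
                  [0, 0, -2, 0, 0]])
out = nonzero_padded(x, 3)  # shape (4, 3); order within a row is not guaranteed
```

If all values are known to be nonnegative, Input.topk(3, dim=-1).values would be a shorter variant of the same idea.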

I suspect there is a more time-efficient way to do this than a full sort. Is there one?