Using torch.nonzero on a multidimensional tensor for a faster GPU implementation

I am trying to get the indices of the nonzero elements along a certain dimension of a multidimensional tensor. I tried calling torch.nonzero on slices taken along the required dimension while iterating over the remaining dimensions of the tensor.
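A minimal sketch of the loop-based approach described above, assuming the dimension of interest is the last one. The tensor values and the helper name `nonzero_per_slice` are made up for illustration.

```python
import torch

def nonzero_per_slice(x: torch.Tensor) -> list:
    """Call torch.nonzero once per 1-D slice taken along the last dimension."""
    # Collapse all leading dimensions so each row is one slice (hypothetical layout).
    flat = x.reshape(-1, x.shape[-1])
    results = []
    for row in flat:  # Python-level loop: one torch.nonzero call per slice
        # torch.nonzero on a 1-D tensor returns shape (n, 1); drop the extra dim.
        results.append(torch.nonzero(row).squeeze(1))
    return results

# Illustrative tensor of shape (2, 2, 4)
x = torch.tensor([[[0, 1, 0, 2], [3, 0, 0, 0]],
                  [[0, 0, 0, 0], [4, 5, 6, 0]]])
per_slice = nonzero_per_slice(x)
```

Each call to torch.nonzero produces an output whose size depends on the data, so on a GPU every iteration is likely to force a host-device synchronization, which may explain part of the slowdown.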

This takes quite a long time to execute. Is there a way to compute the nonzero indices in a single vectorized call, without looping over the slices?
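For comparison, torch.nonzero can also be called once on the full tensor. The sketch below uses a made-up tensor and again assumes the dimension of interest is the last one. Whether this is faster for a given shape and device has not been measured here.

```python
import torch

# Illustrative tensor of shape (2, 2, 4)
x = torch.tensor([[[0, 1, 0, 2], [3, 0, 0, 0]],
                  [[0, 0, 0, 0], [4, 5, 6, 0]]])

# One call on the whole tensor: each row of idx is the full index of one nonzero element.
idx = torch.nonzero(x)      # shape (N, x.dim())

# Leading columns identify the slice; the last column is the position within it.
slice_ids = idx[:, :-1]     # shape (N, x.dim() - 1)
positions = idx[:, -1]      # shape (N,)
```

Here `slice_ids` says which slice each nonzero element belongs to and `positions` gives its index along the last dimension, so the per-slice results can be recovered by grouping on `slice_ids`.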