How to (efficiently) apply a function without a "dim" argument to each row of a 2D tensor?

Hello everyone.

Long story short, I have a 2D matrix of ones and zeros and I need to retrieve, for each row, the indices of the elements set to one. The "standard" way to do so would be torch.nonzero, but that function is well known to be 1) a real bottleneck, since the size of the output is not known in advance, and 2) inapplicable to each row of a 2D tensor in one shot, since different rows may contain different numbers of ones.
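
A minimal example of what I mean (the mask and expected output are just an illustration):

```python
import torch

mask = torch.tensor([[1, 0, 1],
                     [0, 1, 1]])

# What I want, per row: [[0, 2], [1, 2]]

# What torch.nonzero gives: one (row, col) coordinate pair per nonzero
# element, flattened across the whole matrix.
print(torch.nonzero(mask))
# tensor([[0, 0],
#         [0, 2],
#         [1, 1],
#         [1, 2]])
```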

Recently, at::nonzero_static was introduced, which solves the first point by letting the caller specify the maximum number of nonzero elements up front (which is fine for my application). However, it lacks a "dim" argument, so it cannot be applied to each row/column individually. In my opinion this is an odd omission, since fixing the output size guarantees that every row yields the same number of indices, which would make the output a regular tensor.
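
For context, here is how I understand nonzero_static on a 1D tensor: you fix the output size, and the result is padded or truncated to that size (I'm not sure which devices support it in every PyTorch version, so take this as a sketch):

```python
import torch

x = torch.tensor([1.0, 0.0, 1.0, 0.0])

# The output size is fixed ahead of time; missing entries are padded
# with fill_value (default -1), and surplus nonzeros are truncated.
# Output shape is (size, x.dim()).
idx = torch.nonzero_static(x, size=3)
print(idx)
# tensor([[ 0],
#         [ 2],
#         [-1]])
```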

Using a for loop would obviously solve my issue, but that would mean launching the function once per row, which is not GPU-efficient (see the sketch below). Does anyone know a way to apply nonzero_static to each row efficiently, returning a tensor where each row holds the result for the corresponding slice of the input?
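
For concreteness, this is the kind of per-row loop I would like to avoid (max_ones is a bound I know in advance for my data):

```python
import torch

mask = torch.tensor([[1.0, 0.0, 1.0],
                     [0.0, 1.0, 1.0]])
max_ones = 2  # known upper bound on ones per row in my application

# One nonzero_static call per row: correct, but slow on GPU because
# each iteration is a separate kernel launch.
rows = [torch.nonzero_static(mask[i], size=max_ones).squeeze(1)
        for i in range(mask.shape[0])]
result = torch.stack(rows)  # shape: (num_rows, max_ones)
print(result)
# tensor([[0, 2],
#         [1, 2]])
```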

I would appreciate any help you can provide.
Regards