Mitigating the effect of host-device synchronization (sync points)

Hi all,
I am currently seeing increased runtime, most likely due to sync points, and would like some input from the community.
Unfortunately, I cannot avoid them, since I am working with sparse data and need quite a few operations like masked_select, nonzero, etc.
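The sync in ops like masked_select comes from their data-dependent output size: on CUDA the host has to read the element count back before it can allocate the result. Where the selected values only feed a reduction or an elementwise op, one way to avoid the sync is to keep the shape fixed and mask with torch.where instead. A minimal sketch, assuming the downstream computation is a masked sum (both function names are hypothetical). torch.cuda.set_sync_debug_mode is used to flag any sync that remains in the region.

```python
import torch

def masked_sum_with_sync(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # masked_select's output length depends on the mask, so on CUDA the host
    # has to read the element count back before it can allocate the result.
    return torch.masked_select(x, mask).sum()

def masked_sum_sync_free(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # torch.where keeps the input shape, so nothing has to be read back:
    # masked-out entries become 0 and contribute nothing to the sum.
    return torch.where(mask, x, torch.zeros_like(x)).sum()

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.arange(10, dtype=torch.float32, device=device)
mask = x % 2 == 0

if torch.cuda.is_available():
    # Ask PyTorch to warn about every host-device sync inside this region.
    torch.cuda.set_sync_debug_mode("warn")
    total = masked_sum_sync_free(x, mask)
    torch.cuda.set_sync_debug_mode("default")
else:
    total = masked_sum_sync_free(x, mask)

print(total.item())  # → 20.0 (0 + 2 + 4 + 6 + 8); .item() itself syncs
```

This only helps when the code after the selection can work on fixed-size, zero-filled (or otherwise padded) data; when the compacted tensor is genuinely needed, the sync stays.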

But are there ways to mitigate the effect and keep the GPU busy? I was thinking of running operations in parallel that do not depend on the result being synchronized.
For example, I have a loop with 4 iterations, each independent and ending with a masked_select; technically, they could easily run in parallel. However, because of the sync point, each iteration waits for the previous one. Is there something I can do to make this more efficient?
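One way to cut the number of waits for independent iterations like these is to batch them: compute every iteration's mask first, then do a single masked_select over the concatenated data and split the packed result afterwards. A minimal sketch, assuming each iteration produces a tensor and a same-shaped boolean mask (the helper names are hypothetical). Four iterations then cost two syncs (the selection plus reading back the counts) instead of four.

```python
import torch

def select_per_iteration(xs, masks):
    # Baseline: one masked_select per iteration, i.e. one host sync each.
    return [torch.masked_select(x, m) for x, m in zip(xs, masks)]

def select_batched(xs, masks):
    # Count the kept elements of every iteration on the device (no sync yet).
    counts = torch.stack([m.sum() for m in masks])
    # Flatten in row-major order (the same order masked_select uses) and
    # concatenate, so a single masked_select covers all iterations.
    flat_x = torch.cat([x.reshape(-1) for x in xs])
    flat_m = torch.cat([m.reshape(-1) for m in masks])
    values = torch.masked_select(flat_x, flat_m)  # one data-dependent sync
    # Reading the counts back to split the packed result is one more sync.
    return list(torch.split(values, counts.tolist()))

device = "cuda" if torch.cuda.is_available() else "cpu"
xs = [torch.arange(12, dtype=torch.float32, device=device).reshape(3, 4) * (i + 1)
      for i in range(4)]
masks = [x % 3 == 0 for x in xs]
results = select_batched(xs, masks)
```

The concatenation costs an extra copy of the inputs. If the later ops run on the GPU anyway, another option is to keep values and counts (or their cumulative sum as offsets) on the device and skip the split, which saves the second sync.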

I also thought that separate CUDA streams might be a solution, but since the Python side blocks and waits at each sync, I cannot queue work on the remaining streams. Would that work, or does PyTorch synchronize the entire device rather than just the current stream?
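As far as I know, the sync inside masked_select/nonzero copies the count to the host and then waits on the current stream rather than on the whole device; that is worth confirming in a profiler such as Nsight Systems. If it holds, streams can still help when the order of operations is changed: enqueue all sync-free work of every iteration on its own stream first, and only then issue the syncing ops. While the host waits on one stream, the kernels already queued on the others keep the GPU busy. A minimal sketch, where `torch.sin(x) * 2.0` and the `thresholds` argument are hypothetical stand-ins for the real per-iteration work; it falls back to plain sequential execution on CPU.

```python
import contextlib
import torch

def select_with_streams(xs, thresholds):
    use_cuda = xs[0].is_cuda
    if use_cuda:
        main = torch.cuda.current_stream()
        streams = [torch.cuda.Stream() for _ in xs]
    else:
        streams = [None] * len(xs)

    def on(stream):
        # Run on a side stream on GPU; no-op context on CPU.
        return torch.cuda.stream(stream) if stream is not None else contextlib.nullcontext()

    # Phase 1: enqueue all sync-free work first, one stream per iteration.
    prepared = []
    for x, t, s in zip(xs, thresholds, streams):
        if s is not None:
            s.wait_stream(main)  # x was produced on the main stream
            x.record_stream(s)   # keep x's memory from being reused too early
        with on(s):
            y = torch.sin(x) * 2.0  # placeholder for the expensive sync-free part
            prepared.append((y, y > t))

    # Phase 2: now issue the syncing ops. Each one blocks the host only until
    # its own stream is done; the other streams already have work queued.
    results = []
    for (y, m), s in zip(prepared, streams):
        with on(s):
            results.append(torch.masked_select(y, m))

    if use_cuda:
        for r, s in zip(results, streams):
            main.wait_stream(s)    # later main-stream ops see the results
            r.record_stream(main)  # r was allocated on s but is used on main
    return results

device = "cuda" if torch.cuda.is_available() else "cpu"
xs = [torch.randn(1_000_000, device=device) for _ in range(4)]
results = select_with_streams(xs, [0.0, 0.5, 1.0, 1.5])
```

The wait_stream/record_stream calls follow the cross-stream rules from the PyTorch CUDA semantics notes; without them, the side streams could read inputs before they are ready, or memory could be reused while another stream still uses it.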

I am also open to more experimental ideas to solve the problem :slight_smile:

Thanks a lot in advance!