Reducing compiled call overhead

Hello, I have been optimizing my Soft Actor-Critic implementation with torch.compile, and I feel like I am hitting a limit on how much I can reduce CUDA call overhead. This seems to be because the PyTorch code that launches the optimized code is itself quite complex, so a long time passes between my Python function call and the launch of my CUDA graphs. I have attached an Nsight screenshot for reference.


The image shows that a long time passes between the call to the compiled function (approximately at the start of the _compute_actor_and_alpha_loss range), the point where the CudaTrees run is called, and the point where the actual CUDA graph is finally launched. I suspect this time is spent on the CPU running Python code. Is this normal behaviour? Is there anything I can do to reduce it, apart from merging as much as possible into a single CUDA graph? Does PyTorch plan to reduce this kind of inefficiency? In use cases like mine it slows things down quite a lot. Thanks!
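
In case it helps to quantify the gap outside Nsight, below is a minimal sketch of how the host-side time per call could be measured. It assumes the compiled function returns as soon as the GPU work has been queued, so the wall-clock time of the call approximates the CPU-side launch overhead. The helper name `host_call_time_us` and its parameters are hypothetical and are not part of PyTorch.

```python
import time
from statistics import median


def host_call_time_us(fn, *args, sync=None, warmup=10, iters=100):
    """Return the median wall-clock time (in microseconds) for fn(*args) to return.

    `sync` is an optional zero-argument callable that blocks until the device
    is idle (for example torch.cuda.synchronize). It is called before each
    timed call so that previously queued GPU work is not attributed to it.
    """
    # Warm-up calls so that compilation and graph capture are not timed.
    for _ in range(warmup):
        fn(*args)
    if sync is not None:
        sync()

    samples = []
    for _ in range(iters):
        if sync is not None:
            sync()  # drain the device before starting the timer
        start = time.perf_counter()
        fn(*args)  # time only the host side: no synchronization after the call
        samples.append((time.perf_counter() - start) * 1e6)

    if sync is not None:
        sync()  # leave the device idle when returning
    return median(samples)
```

With PyTorch this could be called as `host_call_time_us(compiled_fn, batch, sync=torch.cuda.synchronize)`, where `compiled_fn` and `batch` are placeholders for the compiled function and its input. Comparing the result with the GPU duration of the graph in Nsight would give a rough idea of how much of each step is host-bound.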