Profiling PyTorch: pred.cpu() is responsible for 95% of runtime?

Hi,
I think the problem here is that all CUDA calls made with PyTorch are asynchronous: kernels are queued on the GPU and control returns to Python immediately. The only point where PyTorch actually waits for the computation to finish is when you use the result, e.g. when you send it back to the CPU with `.cpu()`. So the transfer itself isn't slow; it's absorbing the wait for all the queued GPU work.
To get accurate timings, call torch.cuda.synchronize() before starting and stopping your timer so the measured interval includes the GPU work.
You can also set the CUDA_LAUNCH_BLOCKING=1 environment variable to force the CUDA API to become synchronous. Note that this will most certainly slow down your code, so use it only for debugging and profiling.
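A minimal sketch of the timing pattern described above (the tensor shapes are arbitrary, chosen just for illustration; the code falls back to CPU when no GPU is present):

```python
import time

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(64, 1024, device=device)
w = torch.randn(1024, 1024, device=device)

# Naive timing: on a GPU this mostly measures the kernel *launch*,
# because the matmul runs asynchronously.
start = time.perf_counter()
y = x @ w
launch_time = time.perf_counter() - start

# Correct timing: synchronize before starting and after stopping the
# timer so the interval covers the actual GPU computation.
if device == "cuda":
    torch.cuda.synchronize()
start = time.perf_counter()
y = x @ w
if device == "cuda":
    torch.cuda.synchronize()
compute_time = time.perf_counter() - start

print(f"launch-only: {launch_time:.6f}s, synchronized: {compute_time:.6f}s")
```

With synchronization in place, the time attributed to `pred.cpu()` should shrink to just the cost of the device-to-host copy.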