Profile your code to check where the bottleneck is, using either the native PyTorch profiler or an external tool such as Nsight Systems.
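A minimal sketch of the native profiler route is below. The toy model, input shapes, and the `"model_inference"` label are hypothetical placeholders for your own workload. The sketch records CPU ops, also records CUDA kernels if a GPU is available, and then prints the ops with the highest total time, which are the first candidates for the bottleneck. If you run on a GPU, sorting the table by CUDA/device time instead of CPU time may be more informative.

```python
import torch
from torch.profiler import profile, record_function, ProfilerActivity

# Hypothetical toy model and input, standing in for your own workload.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 10),
)
inputs = torch.randn(64, 512)

# Always record CPU ops; also record CUDA kernels when a GPU is present.
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    model = model.cuda()
    inputs = inputs.cuda()
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=True) as prof:
    # record_function labels this region so it appears as its own row in the table.
    with record_function("model_inference"):
        with torch.no_grad():
            output = model(inputs)

# Show the 10 ops with the highest total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```

For Nsight Systems, you would instead launch the script under the `nsys profile` CLI (e.g. `nsys profile python your_script.py`, where the script name is a placeholder) and inspect the resulting timeline.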