1st minibatch takes an abnormally long time - NVIDIA Jetson AGX Orin


I have been training and running inference with various vision models using PyTorch on the AGX Orin. Across all three models (MobileNet, ResNet, YOLOv8 (Ultralytics)), the 1st minibatch always takes much longer than the rest. I have also noticed that GPU utilization stays at 0% until the 1st minibatch completes and only rises from the 2nd minibatch onward. Likewise, power usage jumps significantly after the 1st minibatch and then stays fairly stable. For example, for MobileNet inference the 1st minibatch takes ~2000 ms, whereas the rest take ~30 ms. I am using a PyTorch DataLoader to feed images and a wrapper around jtop to log minibatch times and other metrics. Versions: PyTorch 1.12, CUDA 11.4, jtop 4.1.2.

Does anyone know why this is the case?
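One way to separate the one-time cost from steady-state latency is to time every minibatch on the host. CUDA kernels launch asynchronously, so without a device synchronize the time of queued GPU work can be attributed to a later batch. The sketch below uses only the standard library; `time_minibatches`, `summarize`, and `step` are hypothetical names, and passing `torch.cuda.synchronize` as `sync` is an assumption about how you would plug it into a PyTorch loop.

```python
import statistics
import time
from typing import Any, Callable, Dict, Iterable, List, Optional


def time_minibatches(
    step: Callable[[Any], Any],
    batches: Iterable[Any],
    sync: Optional[Callable[[], None]] = None,
) -> List[float]:
    """Return the wall-clock time of each step(batch) call in milliseconds.

    If `sync` is given (for example torch.cuda.synchronize), it is called
    before and after each step so that asynchronously queued GPU work is
    charged to the batch that launched it.
    """
    times_ms: List[float] = []
    for batch in batches:  # fetching the batch itself is NOT timed
        if sync is not None:
            sync()
        start = time.perf_counter()
        step(batch)
        if sync is not None:
            sync()
        times_ms.append((time.perf_counter() - start) * 1000.0)
    return times_ms


def summarize(times_ms: List[float]) -> Dict[str, Optional[float]]:
    """Compare the first batch with the median of the rest (needs >= 1 timing)."""
    rest = times_ms[1:]
    return {
        "first_ms": times_ms[0],
        "rest_median_ms": statistics.median(rest) if rest else None,
    }


if __name__ == "__main__":
    # Dummy step standing in for model(batch): only the first call is slow.
    calls = {"n": 0}

    def fake_step(batch: Any) -> None:
        calls["n"] += 1
        time.sleep(0.2 if calls["n"] == 1 else 0.01)

    print(summarize(time_minibatches(fake_step, range(5))))
```

Note that each batch is fetched from the iterable outside the timed region, so DataLoader worker startup on the first iteration is not counted; timing `next(iter(loader))` separately (with your own `loader`) would show whether data loading contributes to the first-batch delay.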

I would imagine that this is related to CUDA and PyTorch CUDA subsystem initialization, which may involve compilation. One particular aspect might be whether your GPU's compute architecture is not targeted by the PyTorch build. As far as I understand, in that case the CUDA driver JIT-compiles the PTX embedded in the binary on the fly. Others will know more.

Best regards
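A note on the mechanism: when a binary carries no SASS (GPU machine code) usable by the device, the CUDA driver JIT-compiles the embedded PTX at load time and stores the result in a per-user JIT cache, so later runs normally skip that step. The sketch below applies the compatibility rules from the CUDA C++ Programming Guide (SASS for sm_Xy runs on devices X.z with z >= y; PTX for compute_XY can be JIT-compiled for any equal or newer device) to the strings returned by `torch.cuda.get_arch_list()`. `jit_status` is a hypothetical helper, and architecture-specific targets such as `sm_90a` are simply ignored.

```python
import re
from typing import Iterable, List, Tuple


def _parse(arch_list: Iterable[str], prefix: str) -> List[Tuple[int, int]]:
    """Extract (major, minor) pairs for entries like 'sm_87' or 'compute_87'.

    Suffixed targets such as 'sm_90a' do not match and are skipped.
    """
    out = []
    for arch in arch_list:
        m = re.fullmatch(prefix + r"_(\d+)", arch)
        if m:
            n = int(m.group(1))
            out.append((n // 10, n % 10))  # '87' -> (8, 7), '100' -> (10, 0)
    return out


def jit_status(arch_list: Iterable[str], capability: Tuple[int, int]) -> str:
    """Return 'native', 'ptx-jit' or 'unsupported' for a GPU of `capability`."""
    arch_list = list(arch_list)
    major, minor = capability
    # SASS for sm_Xy runs on devices X.z with z >= y (same major only).
    if any(a == major and b <= minor for a, b in _parse(arch_list, "sm")):
        return "native"
    # PTX for compute_XY can be JIT-compiled for any device >= X.Y.
    if any((a, b) <= (major, minor) for a, b in _parse(arch_list, "compute")):
        return "ptx-jit"
    return "unsupported"


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and torch.cuda.is_available():
        archs = torch.cuda.get_arch_list()
        cap = torch.cuda.get_device_capability(0)
        print(archs, cap, jit_status(archs, cap))
```

On Orin (compute capability 8.7), a build whose list contains `sm_87` should report `native`, meaning no PTX JIT compilation is involved.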


I have built and installed the NVIDIA-recommended versions of PyTorch and torchvision for my Jetson's JetPack version (PyTorch for Jetson - Announcements - NVIDIA Developer Forums). My L4T release is R34 (release), REVISION: 1.1, and the exact PyTorch version is 1.12.0a0+2c916ef.nv22.3.

So I think the GPU's compute architecture is supported by this PyTorch build and no PTX JIT compilation should be needed. But in any case, would PTX JIT compilation really take that long?
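Whether PTX JIT compilation contributes can be checked empirically. The CUDA driver honors the environment variable `CUDA_CACHE_DISABLE=1`, which turns off its JIT cache, so if JIT were a significant part of startup, runs with the cache disabled would stay slow while repeated runs with the cache enabled would get faster. The stdlib sketch below times fresh interpreters; `time_fresh_process` and the `PROBE` one-liner are hypothetical, and the probe touches only a couple of kernels, so substituting the actual inference script would give a more representative answer.

```python
import os
import subprocess
import sys
import time
from typing import Dict, Optional

# Hypothetical probe: CUDA init plus one tiny kernel. Swap in your own
# inference script to exercise all of the model's kernels.
PROBE = (
    "import torch; "
    "torch.ones(1, device='cuda').add_(1); "
    "torch.cuda.synchronize()"
)


def time_fresh_process(code: str, env_overrides: Optional[Dict[str, str]] = None) -> float:
    """Run `python -c code` in a new interpreter; return wall time in seconds."""
    env = dict(os.environ)
    env.update(env_overrides or {})
    start = time.perf_counter()
    subprocess.run([sys.executable, "-c", code], env=env, check=True)
    return time.perf_counter() - start


if __name__ == "__main__":
    try:
        import torch
        has_cuda = torch.cuda.is_available()
    except ImportError:
        has_cuda = False
    if has_cuda:
        runs = [
            ("JIT cache enabled, run 1", {}),  # may populate the cache
            ("JIT cache enabled, run 2", {}),  # should reuse it
            ("JIT cache disabled", {"CUDA_CACHE_DISABLE": "1"}),
        ]
        for label, env in runs:
            print(f"{label}: {time_fresh_process(PROBE, env):.2f} s")
```

If all three runs take roughly the same time, PTX JIT compilation is unlikely to be the source of the delay.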

Oh right, you'd be using the JetPack build, which should have all of that figured out.
Can you see what processes are taking the CPU during the wait?
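To see where the time goes inside the workload process itself, the first minibatch can be run under Python's built-in `cProfile`. CUDA driver work such as context creation or library loading does not appear as its own entry; it is charged to whichever torch call triggered it, which is often enough to tell initialization apart from computation. `profile_call` is a hypothetical helper.

```python
import cProfile
import io
import pstats
from typing import Any, Callable, Tuple


def profile_call(fn: Callable[..., Any], *args: Any, top: int = 15) -> Tuple[Any, str]:
    """Run fn(*args) under cProfile and return (result, report).

    The report lists the `top` entries sorted by cumulative time, so the
    slowest call chains (e.g. a torch call that triggers CUDA setup) come first.
    """
    profiler = cProfile.Profile()
    result = profiler.runcall(fn, *args)
    buf = io.StringIO()
    stats = pstats.Stats(profiler, stream=buf)
    stats.sort_stats("cumulative").print_stats(top)
    return result, buf.getvalue()


if __name__ == "__main__":
    # Stand-in for model(first_batch): replace with your own model and batch.
    def fake_first_batch(n: int) -> int:
        return sum(i * i for i in range(n))

    value, report = profile_call(fake_first_batch, 100_000, top=5)
    print(value)
    print(report)
```

In a real loop this would be `_, report = profile_call(model, first_batch)` with your own `model` and `first_batch` (hypothetical names), followed by `print(report)`.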

There aren't any processes running other than the workload itself. Also, I'm not worried about CPU utilization; I'm not sure why the GPU is held back until the first minibatch has completed.

For example, in the case of MobileNet inferencing, the 1st minibatch takes ~2000 ms, whereas the rest of them take ~30 ms.

I don't think ~2 s to initialize PyTorch and CUDA and run the first batch is abnormally long; I would expect this warmup time.
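Since some warmup cost is expected, a common practice is to pay it before the measured loop by running the model on a dummy input with the same shape as the real batches. The sketch repeats the call until two consecutive latencies agree within a tolerance; `warm_up`, the 20% tolerance, and the MobileNetV2 model with an 8x3x224x224 dummy input in the main block are assumptions for illustration.

```python
import time
from typing import Any, Callable, List, Optional


def warm_up(
    step: Callable[[Any], Any],
    example: Any,
    max_iters: int = 10,
    rel_tol: float = 0.2,
    sync: Optional[Callable[[], None]] = None,
) -> List[float]:
    """Call step(example) until two consecutive latencies agree within
    rel_tol (or max_iters is reached). Returns the latencies in seconds."""
    latencies: List[float] = []
    for _ in range(max_iters):
        if sync is not None:
            sync()  # drain any earlier GPU work first
        start = time.perf_counter()
        step(example)
        if sync is not None:
            sync()  # wait for this call's GPU work to finish
        latencies.append(time.perf_counter() - start)
        if len(latencies) >= 2:
            prev, cur = latencies[-2], latencies[-1]
            if abs(cur - prev) <= rel_tol * max(prev, cur):
                break  # latency has stabilized
    return latencies


if __name__ == "__main__":
    try:
        import torch
        import torchvision
    except ImportError:
        torch = None
    if torch is not None and torch.cuda.is_available():
        model = torchvision.models.mobilenet_v2().cuda().eval()
        dummy = torch.randn(8, 3, 224, 224, device="cuda")
        with torch.inference_mode():
            lat = warm_up(model, dummy, sync=torch.cuda.synchronize)
        print([f"{t * 1000:.0f} ms" for t in lat])
```

If `torch.backends.cudnn.benchmark` is enabled, cuDNN selects algorithms per input shape, which is one more reason to warm up with the real batch shape.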

But that is in MAXN mode on the AGX Orin. If we lower the number of cores or the memory frequency, I have seen the first minibatch take as long as 15 seconds.

Still, 2 s is fast, and I doubt you can lower it much further, since this startup cost is not specific to Jetson devices.
On an x86 workstation:

time python -c "import torch; x = torch.randn(1)"

real	0m0.881s
user	0m1.590s
sys	0m1.665s

time python -c "import torch; x = torch.randn(1).cuda()"

real	0m0.976s
user	0m1.398s
sys	0m1.913s
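The shell `time` figures above include interpreter startup. A sketch that splits the phases inside one process using `time.perf_counter`; `timed` is a hypothetical helper and the phase labels are illustrative. `torch.cuda.is_available()` is timed on its own because it may already initialize the CUDA driver before the first tensor is created.

```python
import importlib
import time
from typing import Callable, Dict, TypeVar

T = TypeVar("T")


def timed(label: str, fn: Callable[[], T], results: Dict[str, float]) -> T:
    """Run fn(), store its wall time in seconds under `label`, return its result."""
    start = time.perf_counter()
    out = fn()
    results[label] = time.perf_counter() - start
    return out


if __name__ == "__main__":
    results: Dict[str, float] = {}
    try:
        torch = timed("import torch", lambda: importlib.import_module("torch"), results)
    except ImportError:
        torch = None

    if torch is not None and timed("torch.cuda.is_available()", torch.cuda.is_available, results):

        def cuda_op() -> None:
            # Allocate a tiny tensor on the GPU and wait for it to finish.
            torch.randn(1, device="cuda")
            torch.cuda.synchronize()

        timed("first CUDA op (context creation)", cuda_op, results)
        timed("second CUDA op", cuda_op, results)

    for label, secs in results.items():
        print(f"{label}: {secs * 1000:.0f} ms")
```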