I have been trying to train and run inference with various vision models using PyTorch on an AGX Orin device. Across all three vision models (MobileNet, ResNet, YOLOv8 (Ultralytics)), I have seen that the first minibatch always takes much longer than the rest. Another thing I have noticed is that GPU utilization stays at 0 until the first minibatch completes and only goes up from the second minibatch onward. Likewise, I see a significant spike (then fairly stable) in power usage after the first minibatch. For example, in the case of MobileNet inference, the first minibatch takes ~2000 ms, whereas the rest take ~30 ms. I am using a PyTorch DataLoader to feed images and a wrapper over jtop for logging minibatch times and other metrics. The PyTorch version I am using is 1.12, CUDA version is 11.4, and jtop is 4.1.2.
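For reference, this is roughly how I measure each minibatch; a minimal sketch with torch.cuda.synchronize() around the forward pass so the GPU work is actually included in the timing (the model and the random tensors stand in for my real DataLoader setup):

import time
import torch
import torchvision

device = torch.device("cuda")
# MobileNet stands in for the real model; random tensors stand in for my DataLoader
model = torchvision.models.mobilenet_v2().to(device).eval()
batches = [torch.randn(32, 3, 224, 224) for _ in range(5)]

with torch.no_grad():
    for i, batch in enumerate(batches):
        torch.cuda.synchronize()   # wait for any pending GPU work
        start = time.perf_counter()
        model(batch.to(device))
        torch.cuda.synchronize()   # wait for this batch's kernels to finish
        print(f"minibatch {i}: {(time.perf_counter() - start) * 1000:.1f} ms")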
I would imagine that this is related to CUDA and PyTorch's CUDA subsystem initialization, which may involve compilation. One particular aspect might be whether your compute architecture is one that is not covered by the PyTorch build; as far as I understand, that causes the CUDA driver to JIT-compile PTX on the fly. Others will know more.
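If you want to rule that out, a quick check (just a sketch, nothing Jetson-specific) is to compare the architectures the installed wheel was built for with the device's compute capability; the AGX Orin GPU is sm_87, i.e. capability (8, 7):

import torch

# Architectures the installed PyTorch wheel ships binaries/PTX for
print(torch.cuda.get_arch_list())           # e.g. ['sm_72', 'sm_87', ...]

# Compute capability of the local GPU (AGX Orin reports (8, 7))
print(torch.cuda.get_device_capability(0))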
I have built and installed the NVIDIA-recommended versions of PyTorch and torchvision for my Jetson's JetPack version (PyTorch for Jetson - Announcements - NVIDIA Developer Forums). My JetPack release is R34, revision 1.1, and the exact PyTorch version is 1.12.0a0+2c916ef.nv22.3.
So I think the GPU's compute architecture is covered by the PyTorch build and PyTorch should not need to generate PTX. But in any case, does PTX generation really take that long?
There aren't any other processes running except the workload processes. Also, I am not worried about CPU utilization; I'm just not sure why the GPU is held back until the first minibatch completes.
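For now I can at least hide the cost by running one dummy warm-up pass before the timed loop, something like this (the shapes are just an example from my MobileNet case):

import torch
import torchvision

model = torchvision.models.mobilenet_v2().cuda().eval()

# Warm-up: one throwaway forward pass so CUDA context creation, cuDNN
# initialization, etc. happen before the measured minibatches
with torch.no_grad():
    model(torch.randn(32, 3, 224, 224, device="cuda"))
torch.cuda.synchronize()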
Still, ~2 s is fast, and I doubt you can lower it much further; this startup cost is also not specific to Jetson devices.
On an x86 workstation:
time python -c "import torch; x = torch.randn(1)"
real 0m0.881s
user 0m1.590s
sys 0m1.665s
time python -c "import torch; x = torch.randn(1).cuda()"
real 0m0.976s
user 0m1.398s
sys 0m1.913s
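If you want to see how much of the ~2 s on the Orin is the first real kernel launch (cuDNN setup and such) rather than just importing torch and creating the CUDA context, you could extend the same experiment with a small convolution; a rough sketch:

time python -c "import torch, torch.nn as nn; m = nn.Conv2d(3, 8, 3).cuda(); x = torch.randn(1, 3, 224, 224).cuda(); m(x); torch.cuda.synchronize()"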