First torch import and .to(device) very slow when using a cached conda package

I’m running into this rather odd behaviour.

As per the title: the first torch import and .to(device) are very slow (upwards of 2 minutes) when using an environment where pytorch was installed from the conda cache (i.e., it wasn't downloaded).
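A minimal sketch for putting numbers on the symptom, assuming a Python 3 environment: it times the first import of a module and, if torch and a GPU are present, the first and second `.to(device)` calls. The helper names `timed_import` and `time_first_transfer` are hypothetical, not part of PyTorch.

```python
import importlib
import time


def timed_import(module_name):
    """Import a module and return (module, seconds taken).

    If the module is already in sys.modules the time is close to zero,
    so run this in a fresh interpreter to measure the cold import.
    """
    start = time.perf_counter()
    module = importlib.import_module(module_name)
    return module, time.perf_counter() - start


def time_first_transfer(torch, device="cuda"):
    """Time moving a small tensor to `device`.

    The first call includes lazy CUDA initialization; later calls do not.
    """
    start = time.perf_counter()
    torch.zeros(1).to(device)
    if device.startswith("cuda"):
        torch.cuda.synchronize()  # make sure the copy has finished
    return time.perf_counter() - start


if __name__ == "__main__":
    try:
        torch, import_s = timed_import("torch")
    except ImportError:
        print("torch is not installed in this environment")
    else:
        print(f"import torch: {import_s:.2f} s")
        if torch.cuda.is_available():
            print(f"first .to('cuda'):  {time_first_transfer(torch):.2f} s")
            print(f"second .to('cuda'): {time_first_transfer(torch):.2f} s")
```

Comparing the first and second transfer times separates the one-time startup cost (CUDA context creation plus any JIT compilation) from the copy itself.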

This does not happen if conda has to download pytorch (e.g. when I’m using a new version that hadn’t been installed yet).

This seems to happen regardless of pytorch version or GPU type. I'm using AWS images, both my own and the AWS Deep Learning Base AMI, and it happens with both.

In fact, the AWS Deep Learning Base AMI already comes with the environments ready to activate, and the first import still takes a long time.

Any idea as to why freshly downloaded pytorch runs quickly on the first call, whereas cached/preinstalled pytorch doesn't?

Could you post the installed PyTorch version, how you’ve installed it, which CUDA toolkit (or runtime in the binary) and OS you were using, and which GPU?
The issue description would point towards a call into the CUDA JIT compiler, which would be triggered to compile the kernels for a missing GPU architecture.

This has happened with conda-installed pytorch=<1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1> and cudatoolkit=<10.1, 10.2, 11.0, 11.1> on Ubuntu 18. nvidia-smi shows driver version 460.

This happens with K80, T4, and V100 GPUs.

The puzzling bit is the fact that if conda has to download the package, it imports fast; if it uses the cached package, it takes a long time.

Thanks for the information. This sounds rather like an installation issue with the conda cache, and I don't know what might be causing it.

I had the same problem.
I think the reason is that conda does not see the torch location (when I run "conda list", there is no torch package, even though torchvision and torchaudio are there).
I reinstalled torch using pip instead of conda install, and the problem disappeared. I don't know if it will cause any incompatibility, but it solved my "slow import torch" problem.

Did you by chance check the torch.__path__ after importing it in the “slow” approach?
Based on your description, it seems that your conda environment didn't even have torch as a package in its list, so it would be interesting to know where it is finding torch at all.
I guess it's (somehow) finding an older installation, which might trigger the CUDA JIT for your compute capability, but I would rather expect conda to raise an error about a missing PyTorch installation.

Sorry, I didn’t check the torch.path, and I don’t want to roll back to the “slow” approach. It’s strange that

conda install torch

completes successfully, but conda list does not show torch. I guess it's a conda issue.
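`conda list` reports what conda knows about the environment, while `import torch` only needs the files to be reachable on `sys.path`, so the two can disagree (for example, when torch sits in a location outside the environment such as a user site-packages directory). A minimal sketch using the standard-library `importlib.metadata` (Python 3.8+) lists every torch distribution Python can see, with the contents of its `INSTALLER` file, which pip writes as `pip`; whether other installers write that file is an assumption, so the sketch falls back to `unknown`. `find_distributions` is a hypothetical helper name.

```python
from importlib import metadata


def find_distributions(name):
    """List (version, installer, location) for every visible copy of `name`."""
    found = []
    for dist in metadata.distributions():
        meta = dist.metadata
        dist_name = (meta.get("Name") if meta is not None else None) or ""
        if dist_name.lower() != name.lower():
            continue
        # INSTALLER is optional; missing files make read_text return None.
        installer = (dist.read_text("INSTALLER") or "unknown").strip()
        found.append((dist.version, installer, str(dist.locate_file(""))))
    return found


if __name__ == "__main__":
    for version, installer, location in find_distributions("torch"):
        print(f"torch {version} installed by {installer} in {location}")
```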


Hi @ptrblck,
I have 2 conda environments: the one with torch v1.9.0+cu111 still has the slow import problem, while the other one, with torch v0.4.1, is fine.
conda list shows torch in both, and
torch.__path__ is similar in both environments, which is
Does it give you any hint about what’s wrong?

I think the problem is with the CUDA version. I have the problem with PyTorch 1.9 + CUDA 11; it disappeared with PyTorch 1.9 + CUDA 10.

I don’t believe it’s a pure binary issue, as I haven’t seen this behavior on any system I’m using.
Based on your previous description:

I would still assume that conda finds multiple installations and triggers the CUDA JIT compiler for missing architectures.
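If several copies are installed, which one Python picks depends on the order of `sys.path`. A minimal sketch that lists every directory on `sys.path` containing a regular `torch` package (one with an `__init__.py`); it ignores zip/egg archives and custom import hooks, so treat it as a first check rather than a full import resolver. `find_package_dirs` is a hypothetical helper name.

```python
import os
import sys


def find_package_dirs(name, search_path=None):
    """Return every directory on the search path that holds regular package `name`."""
    paths = sys.path if search_path is None else search_path
    hits = []
    for entry in paths:
        # An empty sys.path entry stands for the current working directory.
        candidate = os.path.join(entry or os.getcwd(), name)
        if os.path.isfile(os.path.join(candidate, "__init__.py")):
            hits.append(os.path.abspath(candidate))
    return hits


if __name__ == "__main__":
    dirs = find_package_dirs("torch")
    # The first regular package found on sys.path is normally the one imported.
    for i, d in enumerate(dirs):
        print("first:  " if i == 0 else "later:  ", d)
    if len(dirs) > 1:
        print("more than one torch package is reachable from sys.path")
```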
Let me know if you have a way to reproduce this issue.

I think you are right.
If I load CUDA 11 before importing torch, the problem disappears. This is weird, since the torch binary installs its own CUDA; I never had to load CUDA separately before.
It’s a little complicated to reproduce, I’m running torch in an HPC which has some dependencies in path of packages.
Thank you very much for your comments; adding one more line of code to load CUDA 11 is trivial.