Something strange with torch.load

Hi all,

I am currently trying to make predictions with the pretrained StereoNet network.

After some effort (there must be some push errors in this repository), I managed to load the pretrained network. Unfortunately, a nice message then warned me that my GPU did not have enough memory. So I removed every explicit call to CUDA and specified:

device = torch.device("cpu")
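A minimal sketch of keeping a script on the CPU, assuming the rest of the code routes every tensor through device. Hiding the GPUs with the CUDA_VISIBLE_DEVICES environment variable (set before torch is imported) is an extra safeguard that is not in the original post: any leftover .cuda() call then raises an error instead of quietly allocating VRAM. StandInNet is a hypothetical placeholder for StereoNet, not its real architecture.

```python
import os

# Hide all GPUs from this process. This must run before torch is imported,
# otherwise CUDA may already have enumerated the devices.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import torch
import torch.nn as nn

device = torch.device("cpu")


class StandInNet(nn.Module):
    """Hypothetical placeholder for StereoNet: takes a left/right image pair."""

    def __init__(self):
        super().__init__()
        # 6 input channels = two RGB images stacked along the channel axis
        self.conv = nn.Conv2d(6, 1, kernel_size=3, padding=1)

    def forward(self, left, right):
        return self.conv(torch.cat([left, right], dim=1))


# Model and inputs both go through the same device object
model = StandInNet().to(device).eval()
left = torch.rand(1, 3, 64, 64, device=device)
right = torch.rand(1, 3, 64, 64, device=device)

with torch.no_grad():
    disparity = model(left, right)

print(torch.cuda.is_available())  # False: no GPU is visible to this process
print(disparity.device, tuple(disparity.shape))
```

Using a single device variable means switching back to the GPU later is a one-line change, while the environment variable guarantees nothing slips through in the meantime.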


Then I ran the prediction on an image pair while nvidia-smi was running in a loop in another terminal. The machine became unusable, but I had time to notice that python was using 500 MB of VRAM, apparently because of endless exchanges between CPU and GPU. After a reboot, I looked for the line in my Python script where the GPU started being used.
It turned out to be the line that loads the network weights:

data = torch.load("checkpoint_pretrain_secneflow.pth")["state_dict"]

I stepped through all the other lines, skipping the prediction step, and VRAM usage did not change at any other point.
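Instead of watching nvidia-smi, one way to find the line that first touches the GPU is to ask PyTorch after each step whether its CUDA state has been initialized. A sketch, assuming a small checkpoint saved from CPU tensors stands in for checkpoint_pretrain_secneflow.pth; check_cuda and demo_checkpoint.pth are hypothetical names.

```python
import os
import tempfile

import torch


def check_cuda(step):
    """Print whether PyTorch has initialized CUDA (i.e. created a context)."""
    initialized = torch.cuda.is_initialized()
    print(f"{step}: CUDA initialized = {initialized}")
    return initialized


# Hypothetical stand-in checkpoint with the same layout as the one in the post
path = os.path.join(tempfile.mkdtemp(), "demo_checkpoint.pth")
torch.save({"state_dict": {"weight": torch.ones(3, 3)}}, path)

check_cuda("after save")
# Same call as in the post, without map_location
data = torch.load(path)["state_dict"]
check_cuda("after torch.load")
```

Because this stand-in was saved from CPU tensors, both checks report False. With a checkpoint saved from GPU tensors, the second check would be expected to report True on a machine with a GPU, matching the jump seen in nvidia-smi.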

Does anyone know how to avoid this GPU usage, and why simply reading a 2.6 MB file uses 500 MB of VRAM before the weights are even loaded into a network?

Thank you in advance for your help!


The CUDA state on the GPU is quite big; it takes a few hundred MB on its own.

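By default, torch.load moves each tensor back to the device it was saved from, so a checkpoint saved from the GPU initializes CUDA as soon as it is loaded. To load such a checkpoint onto the CPU instead, you can do: torch.load('checkpoint_pretrain_secneflow.pth', map_location=torch.device('cpu')).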
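A sketch of the suggested fix, together with two other forms that map_location accepts: the string "cpu", and a function that receives each storage with its saved location and returns the storage unchanged, which keeps it on the CPU. The checkpoint here is a hypothetical CPU-saved stand-in so the snippet also runs on machines without a GPU; the real checkpoint_pretrain_secneflow.pth would be loaded the same way.

```python
import os
import tempfile

import torch

# Hypothetical stand-in for checkpoint_pretrain_secneflow.pth
path = os.path.join(tempfile.mkdtemp(), "demo_checkpoint.pth")
torch.save({"state_dict": {"conv.weight": torch.randn(1, 6, 3, 3)}}, path)

# 1. A torch.device object, as in the reply above
sd_device = torch.load(path, map_location=torch.device("cpu"))["state_dict"]

# 2. The equivalent string form
sd_string = torch.load(path, map_location="cpu")["state_dict"]

# 3. A function called per storage: returning it unchanged keeps it on the CPU
sd_func = torch.load(path, map_location=lambda storage, loc: storage)["state_dict"]

for sd in (sd_device, sd_string, sd_func):
    assert all(t.device.type == "cpu" for t in sd.values())

print(torch.cuda.is_initialized())  # False: no CUDA context was created
```

All three forms give the same CPU tensors; the string is the shortest, while the function form is useful when tensors saved on different devices need different handling.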