When loading weights from a file with model.load_state_dict(torch.load(model_file)), the following exception is raised:
THCudaCheck FAIL file=/data/users/soumith/builder/wheel/pytorch-src/torch/lib/THC/generic/THCStorage.c line=79 error=2 : out of memory
Segmentation fault (core dumped)
Previously this ran with no problem; in fact, two training processes are still running (on two other GPUs). It only breaks when I try to start an additional training process.
OK, I think I have found where the problem arises: the model weights saved with torch.save(model.state_dict(), file)
contain device info, and torch.load(model_file) will load the weights directly onto the device recorded in the file rather than onto the CPU. So, if the previously used device is short of memory, this loading process crashes.
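The usual way to force the checkpoint onto the CPU at load time is to pass map_location to torch.load. A minimal sketch of that idea (the actual map_loc used in the traceback below is not shown in the thread, so this is only an assumption about its shape):

    # Sketch: remap every saved storage to CPU instead of the GPU it was saved from.
    # The map_location callable receives (storage, location); returning the storage
    # unchanged keeps it on the CPU.
    state_dict = torch.load(model_file, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)

In the traceback below, however, the load fails even though a map_location (map_loc) is supplied: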
model weights loading...
THCudaCheck FAIL file=/data/users/soumith/builder/wheel/pytorch-src/torch/csrc/generic/serialization.cpp line=145 error=2 : out of memory
Traceback (most recent call last):
File "PTR_evaluation_pytorch.py", line 197, in <module>
model.load_state_dict(torch.load(model_file,map_location=map_loc))
File "/home/David/App/anaconda3/lib/python3.5/site-packages/torch/serialization.py", line 222, in load
return _load(f, map_location, pickle_module)
File "/home/David/App/anaconda3/lib/python3.5/site-packages/torch/serialization.py", line 377, in _load
deserialized_objects[key]._set_from_file(f, offset)
RuntimeError: cuda runtime error (2) : out of memory at /data/users/soumith/builder/wheel/pytorch-src/torch/csrc/generic/serialization.cpp:145
The target device is idle, with over 20GB of memory free.
There was a bug in the serialization where remapping devices still used device memory. This is fixed in master. I am working on binaries of version 0.1.11, and that release will have this fix.
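Until those binaries are available, one possible workaround (my assumption, not something stated in the thread) is to save a CPU-only copy of the state_dict, so that loading it later never has to allocate or remap GPU memory:

    # Workaround sketch (assumption, not from the thread): copy each parameter to
    # CPU before saving, so the checkpoint contains only CPU storages.
    cpu_state = {k: v.cpu() for k, v in model.state_dict().items()}
    torch.save(cpu_state, model_file)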
map_location – a function, torch.device, string or a dict specifying how to remap storage locations
# Pick the target device explicitly, then remap the checkpoint onto it.
if torch.cuda.is_available() and cfg.use_gpu is not None:
    device = torch.device(cfg.use_gpu)
else:
    device = torch.device("cpu")

checkpoint_data = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint_data['model'])
optimizer.load_state_dict(checkpoint_data['optimizer'])
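The other map_location forms mentioned in the docs behave the same way; a short sketch (the device names here are only examples):

    # Remap with a string: load every storage onto the CPU.
    checkpoint_data = torch.load(checkpoint_path, map_location="cpu")

    # Remap with a dict: move storages saved on cuda:1 onto cuda:0.
    checkpoint_data = torch.load(checkpoint_path, map_location={"cuda:1": "cuda:0"})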