Hi everyone. As the title says, my code gets stuck if I try to load a state dict into the model. I'm running this code on a node with 4 GPUs, so multiprocessing is needed. Here is the main code of the train() method, which I use to spawn my 4 processes:
```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def train(gpu, args, log_dir, models_dir):
    torch.cuda.set_device(gpu)
    model = Model(...)
    if args.path_to_weights:
        model.load_state_dict(torch.load(args.path_to_weights, map_location=f'cuda:{gpu}'))
    model = model.cuda(gpu)
    # distributed data parallel
    model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
```
Unfortunately, when I load the weights with model.load_state_dict(), the code gets stuck. If I remove the loading, the code runs perfectly fine. Does anyone have an idea how to fix this?
Hi Gabriele. I ran into the same problem. In my case, there was a mismatch between the state dict of the instantiated model and the one from the checkpoint I loaded: the keys of the instantiated model started with `module.`, whereas those from the checkpoint started with `model.`.
I could avoid the deadlock by making sure the keys matched:
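```python
import torch

checkpoint = torch.jit.load(
    hp.checkpoint_path,
    map_location=f"cuda:{rank}"
).state_dict()
# rename only the leading "model." prefix to "module."
checkpoint_renamed = {
    ("module." + k[len("model."):] if k.startswith("model.") else k): v
    for k, v in checkpoint.items()
}
model.load_state_dict(checkpoint_renamed)
```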
Not sure why such a trivial error should cause a deadlock, though.
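The renaming can be factored into a small helper. This is a sketch (`replace_key_prefix` is a hypothetical name, not a PyTorch function) that rewrites only a leading prefix. A plain `k.replace("model.", "module.")` would also rewrite `model.` in the middle of a key, so `model.submodel.weight` would become `module.submodule.weight` and no longer match. Passing an empty `new_prefix` strips the prefix instead.

```python
def replace_key_prefix(state_dict, old_prefix, new_prefix):
    """Return a copy of state_dict with old_prefix swapped for new_prefix.

    Only a prefix at the very start of each key is rewritten; keys that do
    not start with old_prefix are copied unchanged.
    """
    renamed = {}
    for key, value in state_dict.items():
        if key.startswith(old_prefix):
            key = new_prefix + key[len(old_prefix):]
        renamed[key] = value
    return renamed


if __name__ == "__main__":
    sd = {"model.fc.weight": 1, "model.submodel.bias": 2}
    print(replace_key_prefix(sd, "model.", "module."))
    # {'module.fc.weight': 1, 'module.submodel.bias': 2}
    print(replace_key_prefix({"module.fc.weight": 1}, "module.", ""))
    # {'fc.weight': 1}
```

For the stripping case specifically, PyTorch also provides `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix)`, which modifies the dict in place.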
I found out that the problem was caused by the cluster's filesystem, which (for reasons unknown to me) could not read the file specified in args.path_to_weights; instead of raising an error, it hung in an endless loop.
I suppose this could be caused by the fact that when I saved the model with torch.save, the state_dict also stored the device the model was on during training. Since the cluster has thousands of GPUs, maybe the correct way to load the file is to first load it onto the CPU and only then call load_state_dict. However, this is just a hypothesis.
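The CPU-first loading order suggested above might look like the sketch below. `load_weights_via_cpu` is a hypothetical helper and `nn.Linear` stands in for the real `Model`. It only changes where tensors are deserialized; it would not help if the filesystem itself cannot deliver the file.

```python
import os
import tempfile
from typing import Optional

import torch
import torch.nn as nn


def load_weights_via_cpu(model: nn.Module, path: str, gpu: Optional[int] = None) -> nn.Module:
    """Load a saved state_dict onto the CPU first, then move the model to `gpu`."""
    # map_location="cpu" remaps every tensor to the CPU, regardless of the
    # device it was on when torch.save was called.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    # Only now move the already-initialized model to this process's GPU.
    if gpu is not None and torch.cuda.is_available():
        model = model.cuda(gpu)
    return model


if __name__ == "__main__":
    # Round-trip a tiny model through a temporary checkpoint file.
    src = nn.Linear(4, 2)
    fd, path = tempfile.mkstemp(suffix=".pt")
    os.close(fd)
    torch.save(src.state_dict(), path)
    dst = load_weights_via_cpu(nn.Linear(4, 2), path)
    print(torch.equal(dst.weight, src.weight))  # True
    os.remove(path)
```

Inside train() this would be called as `model = load_weights_via_cpu(Model(...), args.path_to_weights, gpu)` before wrapping the result in DDP.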
Thanks for the update, Gabriele. It seems that in a distributed setup, deadlocks can arise whenever the state cannot be loaded, for whatever reason. In my case, I got some clues about what was going on by interrupting the training (Ctrl-C) and reading the error trace.
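One way to turn such a hang into an error is to have every rank vote on whether its load succeeded before going any further. The sketch below (`load_checkpoint_or_fail_everywhere` is a hypothetical name) catches the exception locally and then uses `dist.all_reduce` with `ReduceOp.MIN`, so if any rank failed, all ranks raise instead of some of them waiting forever in DDP's first collective. Assumptions: every rank calls it, and with the NCCL backend the flag tensor would need to be on the rank's GPU (the CPU tensor here suits gloo). It cannot help if `torch.load` itself blocks, as with the filesystem issue above. Without an initialized process group it behaves like a plain `torch.load` that wraps failures in a `RuntimeError`.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def load_checkpoint_or_fail_everywhere(path, map_location="cpu"):
    """Load a checkpoint; if loading fails on any rank, raise on every rank."""
    state, error = None, None
    try:
        state = torch.load(path, map_location=map_location)
    except Exception as exc:  # keep the error and report it after the vote
        error = exc

    all_ok = error is None
    if dist.is_available() and dist.is_initialized():
        # MIN over all ranks is 1 only if every rank loaded successfully.
        flag = torch.tensor([1 if all_ok else 0])
        dist.all_reduce(flag, op=dist.ReduceOp.MIN)
        all_ok = bool(flag.item())

    if not all_ok:
        raise RuntimeError(f"checkpoint {path!r} could not be loaded on every rank") from error
    return state


if __name__ == "__main__":
    # Single-process usage: no process group, so no vote is taken.
    fd, path = tempfile.mkstemp(suffix=".pt")
    os.close(fd)
    torch.save({"w": torch.ones(2)}, path)
    print(load_checkpoint_or_fail_everywhere(path)["w"])
    os.remove(path)
```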
Although it was not the root cause of the OP's problem (the cluster's filesystem), this is the same problem I ran into. Checking whether multiple GPUs are in use and, in that case, calling model.module.load_state_dict() instead of model.load_state_dict() did the trick for me. Here is roughly the code I used:
```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def load_checkpoint(model, PATH, gpu_id=None):
    device = f"cuda:{gpu_id}" if gpu_id is not None else "cpu"
    print(f"attempting to load checkpoint from: {PATH}")
    cp = torch.load(PATH, map_location=device)
    if device != "cpu":
        # wrap first: model.module only exists once the model is inside DDP
        model = DDP(model.to(device), device_ids=[gpu_id])
        model.module.load_state_dict(cp["MODEL"])  # model.module...
    else:
        model.load_state_dict(cp["MODEL"])  # just model...
    print("model object loaded")
    return model
```
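A variant of the same idea that avoids branching on the device: unwrap the model whenever it sits inside a wrapper, both when saving and when loading, so checkpoints never carry the `module.` prefix. The helper names below (`unwrap`, `load_weights`, `save_weights`) are hypothetical; `nn.DataParallel` is included only because it exposes `.module` the same way DDP does.

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def unwrap(model: nn.Module) -> nn.Module:
    """Return the module inside a DDP/DataParallel wrapper, or the model itself."""
    if isinstance(model, (DDP, nn.DataParallel)):
        return model.module
    return model


def load_weights(model: nn.Module, state_dict) -> None:
    # Same call whether or not the model has been wrapped yet.
    unwrap(model).load_state_dict(state_dict)


def save_weights(model: nn.Module, path: str) -> None:
    # Saving the unwrapped module keeps "module." out of the checkpoint keys.
    torch.save(unwrap(model).state_dict(), path)
```

With this convention the key-renaming step discussed earlier in the thread becomes unnecessary for checkpoints saved through `save_weights`.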