Hi everyone. Basically, as the title says, my code gets stuck when I try to load a state dict into the model. I'm running this code on a node with 4 GPUs, so multiprocessing is needed. Here is the train() method that each of the 4 spawned processes runs:
def train(gpu, args, log_dir, models_dir):
    torch.cuda.set_device(gpu)
    model = Model(...)
    if args.path_to_weights:
        model.load_state_dict(torch.load(args.path_to_weights, map_location=f'cuda:{gpu}'))
    model = model.cuda(gpu)
    # distributed data parallel
    model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
Unfortunately, with the load_state_dict() call in place the code gets stuck. If I remove the loading, the code runs perfectly fine. Does anyone have an idea how to fix this?
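For reference, the processes are spawned with the standard torch.multiprocessing.spawn pattern. This is only a minimal sketch (the port, directories and argument names are placeholders, not my exact code):

import os
import argparse
import torch.distributed as dist
import torch.multiprocessing as mp

def train(gpu, args, log_dir, models_dir):
    # one process per GPU; on a single node the rank equals the GPU index
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=args.gpus, rank=gpu)
    # ... model creation, load_state_dict and DDP wrapping as in the snippet above

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpus', type=int, default=4)
    parser.add_argument('--path_to_weights', default=None)
    args = parser.parse_args()

    os.environ['MASTER_ADDR'] = 'localhost'  # single-node job
    os.environ['MASTER_PORT'] = '12355'      # any free port
    mp.spawn(train, nprocs=args.gpus, args=(args, 'logs/', 'models/'))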
Hi Gabriele. I ran into the same problem. In my case there was a mismatch between the state dict of the instantiated model and the one from the checkpoint I loaded. In particular, the keys of the instantiated model started with "module." whereas those from the checkpoint started with "model.".
I could avoid the deadlock by making sure the keys matched:
checkpoint = torch.jit.load(
    hp.checkpoint_path,
    map_location=f"cuda:{rank}"
).state_dict()
checkpoint_renamed = {
    k.replace("model.", "module."): v
    for k, v in checkpoint.items()}
model.load_state_dict(checkpoint_renamed)
Not sure why such a trivial error should cause a deadlock, though.
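For what it's worth, the "module." prefix is the one DistributedDataParallel itself adds when it wraps a model, so another way to avoid the renaming is to save the unwrapped module and load the checkpoint before wrapping. A rough sketch under that assumption (ddp_model, Model, rank and the file name are placeholders):

# when saving from a DDP-wrapped model, drop the wrapper so the keys
# carry no 'module.' prefix
torch.save(ddp_model.module.state_dict(), 'checkpoint.pt')

# when loading, populate the plain model before wrapping it in DDP,
# so the keys match exactly
model = Model(...)
model.load_state_dict(torch.load('checkpoint.pt', map_location='cpu'))
model = model.cuda(rank)
model = DDP(model, device_ids=[rank])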
I found out that the problem was caused by the cluster's filesystem, which (for some reason unknown to me) could not retrieve the file specified in args.path_to_weights, and instead of throwing an error it got stuck in an endless loop.
I suppose this could be because, when I saved the model with torch.save, the state_dict also recorded the device the model was on during training. Since the cluster has thousands of GPUs, maybe the correct way to load the file is to first map it to the CPU and then load the state_dict. However, this is just a hypothesis.
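A minimal sketch of that CPU-first loading, assuming the same train() signature as in my first post (only the map_location changes):

def train(gpu, args, log_dir, models_dir):
    torch.cuda.set_device(gpu)
    model = Model(...)
    if args.path_to_weights:
        # map every tensor in the checkpoint to CPU first, regardless of the
        # GPU it was saved from, then move the whole model to this process' GPU
        state = torch.load(args.path_to_weights, map_location='cpu')
        model.load_state_dict(state)
    model = model.cuda(gpu)
    model = DDP(model, device_ids=[gpu], find_unused_parameters=True)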
Thanks for the update, Gabriele. It seems that deadlocks may arise whenever the state cannot be loaded, for whatever reason, in a distributed setup. In my case, I got some clues about what was going on by interrupting the training (Ctrl-C) and reading the error trace.