Load_state_dict causing deadlock in multiprocessing training

Hi everyone. As the title says, my code gets stuck when I try to load a state dict into the model. I’m running this code on a node with 4 GPUs, so multiprocessing is needed. Here is the main line of code I use to spawn my 4 different processes running the train() function:

torch.multiprocessing.spawn(train, args=(args, log_dir, models_dir), nprocs=args.gpus, join=True)

Here are the relevant lines of the train() function:

def train(gpu, args, log_dir, models_dir):
    torch.cuda.set_device(gpu)
    model = Model(...)
    if args.path_to_weights:
        model.load_state_dict(torch.load(args.path_to_weights, map_location=f'cuda:{gpu}'))

    model = model.cuda(gpu)

    # distributed data parallel
    model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

Unfortunately, with the load_state_dict() call the code gets stuck. If I remove the loading, the code runs perfectly fine. Does anyone have an idea of how to fix this?


Hi Gabriele. I ran into the same problem. In my case there was a mismatch between the state dict of the instantiated model and the one from the checkpoint I loaded. In particular, the keys of the instantiated model started with module. whereas those from the checkpoint started with model.
I could avoid the deadlock by making sure the keys matched:

# load the TorchScript checkpoint on this process's GPU and extract its state dict
checkpoint = torch.jit.load(
    hp.checkpoint_path,
    map_location=f"cuda:{rank}"
).state_dict()
# rename the keys so they match those of the instantiated (DDP-wrapped) model
checkpoint_renamed = {
    k.replace("model.", "module."): v
    for k, v in checkpoint.items()}
model.load_state_dict(checkpoint_renamed)

Not sure why such a trivial error should cause a deadlock, though.


Daniel, first of all thanks for your answer.

I found out that the problem was caused by the cluster’s filesystem, which (for some reason unknown to me) could not retrieve the file specified in args.path_to_weights; instead of throwing an error, it got stuck in an endless loop.
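For reference, a simple guard I could have added before the loading call, so the process fails fast instead of hanging (just a sketch, reusing the names from my snippet above):

import os

# hypothetical guard: fail fast if the checkpoint is not visible on this node,
# instead of letting torch.load hang on the filesystem
if args.path_to_weights and not os.path.isfile(args.path_to_weights):
    raise FileNotFoundError(f'Checkpoint not found: {args.path_to_weights}')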

I suppose this could be because, when I saved the model with torch.save, the state_dict also recorded the device the model was on during training. Since the cluster has thousands of GPUs, maybe the correct way to load the file is to first map it to the CPU and then load the state_dict. However, it is just a hypothesis.
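If that hypothesis is right, the loading part of train() would look roughly like this (just a sketch, assuming the same arguments as before):

# hypothetical CPU-first loading: map the checkpoint to the CPU regardless of the
# device it was saved from, then move the model to this process's GPU afterwards
model = Model(...)
if args.path_to_weights:
    state_dict = torch.load(args.path_to_weights, map_location='cpu')
    model.load_state_dict(state_dict)
model = model.cuda(gpu)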


Thanks for the update, Gabriele. It seems that deadlocks may arise whenever the state cannot be loaded, for whatever reason, in a distributed setup. In my case, I got some clues about what was going on by interrupting the training (Ctrl-C) and browsing the error trace.
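If interrupting the job is not convenient, another way to get a stack trace out of a hung process (just a suggestion, not something I used here) is Python's built-in faulthandler module:

import faulthandler, signal

# hypothetical: dump the tracebacks of all threads to stderr when the process
# receives SIGUSR1, e.g. via `kill -USR1 <pid>` on the stuck worker
faulthandler.register(signal.SIGUSR1, all_threads=True)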
