Load_state_dict causing deadlock in multiprocessing training

Hi everyone. Basically, as the title says, my code gets stuck if I try to load a state dict into the model. I’m running this code on a node with 4 GPUs, so multiprocessing is needed. Here’s the main line of code I use to spawn my 4 different processes running the train() function:

torch.multiprocessing.spawn(train, args=(args, log_dir, models_dir), nprocs=args.gpus, join=True)

Here are some relevant lines of the train() function:

def train(gpu, args, log_dir, models_dir):
    torch.cuda.set_device(gpu)
    model = Model(...)
    if args.path_to_weights:
        model.load_state_dict(torch.load(args.path_to_weights, map_location=f'cuda:{gpu}'))
    
    model = model.cuda(gpu)

    # distributed data parallel
    model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

Unfortunately, the code gets stuck on the load_state_dict() call. If I remove the loading, the code runs perfectly fine. Does anyone have an idea of how to fix this?


Hi Gabriele. I ran into the same problem. In my case there was a mismatch between the state dict of the instantiated model and the one from the checkpoint I loaded. In particular, the keys of the instantiated model started with "module.", whereas those from the checkpoint started with "model.".
I could avoid the deadlock by making sure the keys matched:

# load the TorchScript checkpoint and grab its state dict
checkpoint = torch.jit.load(
    hp.checkpoint_path,
    map_location=f"cuda:{rank}"
).state_dict()
# rename the keys so they match the DDP-wrapped model's "module." prefix
checkpoint_renamed = {
    k.replace("model.", "module."): v
    for k, v in checkpoint.items()}
model.load_state_dict(checkpoint_renamed)

Not sure why such a trivial error should cause a deadlock, though.


Daniel, first of all thanks for your answer.

I found out that the problem was caused by the cluster’s filesystem, which (for some reason unknown to me) could not fetch the file specified in args.path_to_weights and, instead of throwing an error, got stuck in an endless loop.

I suppose this could be because, when I saved the model using torch.save, the state_dict also recorded the device the model was on during training; since the cluster has thousands of GPUs, maybe the correct way to load the file is to first map it to the CPU and then load the state_dict. However, this is just a hypothesis.
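
Something like the following is what I have in mind, just a rough sketch reusing the train() snippet above (not tested):

def train(gpu, args, log_dir, models_dir):
    torch.cuda.set_device(gpu)
    model = Model(...)

    if args.path_to_weights:
        # map the checkpoint to CPU first, regardless of the device it was saved on
        state_dict = torch.load(args.path_to_weights, map_location='cpu')
        model.load_state_dict(state_dict)

    # only afterwards move the model to this process's GPU and wrap it in DDP
    model = model.cuda(gpu)
    model = DDP(model, device_ids=[gpu], find_unused_parameters=True)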


Thanks for the update, Gabriele. It seems that deadlocks can arise whenever the state cannot be loaded, for whatever reason, in a distributed setup. In my case, I got some clues about what was going on by interrupting the training (Ctrl-C) and reading through the error trace.


Although the root cause of the OP’s problem turned out to be the cluster’s filesystem, I experienced the same symptom. Simply checking for multiple GPUs and calling model.module.load_state_dict() on multiple GPUs (vs. model.load_state_dict()) did the trick for me. Pseudocode I used:

    def load_checkpoint(model, PATH, gpu_id=None):
        device = f"cuda:{gpu_id}" if gpu_id is not None else "cpu"
        print(f"attempting to load checkpoint from: {PATH}")
        cp = torch.load(PATH, map_location=device)
        if gpu_id is not None:
            # multiple GPUs: move to the GPU, wrap in DDP, then load into the wrapped module
            model = DDP(model.cuda(gpu_id), device_ids=[gpu_id])
            model.module.load_state_dict(cp["MODEL"])  # model.module...
        else:
            # single process / CPU: load into the plain model
            model.load_state_dict(cp["MODEL"])  # just model...
        print("model object loaded")
        return model