Resuming a model for training after wrapping in Data Parallel

Hi, I create a model and then wrap it in DataParallel:
model = DataParallel(model)

Now, can I resume the model from a checkpoint using load_state_dict()?
model.load_state_dict(checkpoint['model_state_dict'])

Will the new model weights be broadcast across all GPUs?

You can directly load the state_dict if it was also stored from an nn.DataParallel model.
Otherwise you would see key errors pointing to the missing module. prefix in the keys.
I would thus generally recommend storing and loading the state_dict from the plain model, before wrapping it in nn.DataParallel.
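For example (a minimal sketch; nn.Linear stands in for the real model):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for the real model

# Save from the plain (unwrapped) model, so the keys carry no 'module.' prefix.
torch.save({'model_state_dict': model.state_dict()}, 'checkpoint.pth')

# Later: load into the plain model first, then wrap it.
model = nn.Linear(10, 2)
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
model = nn.DataParallel(model)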


Hi @ptrblck, thank you for your reply.

The stored model was trained with DistributedDataParallel, and I am resuming training with DataParallel. If I load the state_dict into the plain model first, I have to remove the prefix 'module.' from the state_dict keys, and then wrap the model in DataParallel.
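The prefix removal I'm using looks roughly like this (a sketch; it assumes checkpoint and the plain model already exist):

# Strip the 'module.' prefix that DistributedDataParallel added to every key.
state_dict = checkpoint['model_state_dict']
state_dict = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
model = nn.DataParallel(model)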

If I wrap the model in DataParallel first, and then load the state_dict, I don't get any key mismatch issues. But I want to know: is this approach correct? Loading the state_dict after wrapping the model in DataParallel essentially means changing the weights of the model. Will the model be replicated across all GPUs after loading the state_dict?

Yes, nn.DataParallel will scatter the parameters in each forward pass, so it should work.
This is also the reason for DataParallel's lower performance compared to DDP, which has a smaller communication overhead.
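A small sketch of why wrap-then-load is fine (the checkpoint here is just a stand-in):

import torch
import torch.nn as nn

model = nn.DataParallel(nn.Linear(10, 2))

# Loading into the wrapped model writes into model.module, since the
# wrapped keys are just the plain keys prefixed with 'module.'.
checkpoint = {'model_state_dict': model.state_dict()}  # stand-in checkpoint
model.load_state_dict(checkpoint['model_state_dict'])

# The replicas are recreated from model.module in every forward pass,
# so they will all use the loaded weights.
out = model(torch.randn(4, 10))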


@ptrblck, thank you for replying.

So in DDP the parameters are not scattered in each forward pass? In the case of DDP, will wrapping the model first and then loading the state_dict afterwards work as well?

DDP will share the parameters as explained in the docs:

Construction: The DDP constructor takes a reference to the local module, and broadcasts state_dict() from the process with rank 0 to all other processes in the group to make sure that all model replicas start from the exact same state. Then, each DDP process creates a local Reducer, which later will take care of the gradients synchronization during the backward pass.

I haven't checked the different workflows of loading the state_dict before vs. after wrapping the model in utility classes such as DDP, as I stick to initializing the model first and then wrapping it in these classes.
However, I guess it should work, and you could verify it with a small script by checking the parameters on each rank.
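E.g. something like this (a sketch using the gloo backend on CPU; 'ddp_checkpoint.pth' is a hypothetical file saved from a DDP-wrapped model, so its keys carry the 'module.' prefix):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def run(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    model = DDP(nn.Linear(10, 2))  # wrap first ...
    # Hypothetical checkpoint saved from a DDP-wrapped model.
    state_dict = torch.load('ddp_checkpoint.pth', map_location='cpu')
    model.load_state_dict(state_dict)  # ... then load on every rank

    # If loading after wrapping works, all ranks print identical values.
    print(rank, model.module.weight.flatten()[:3])
    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(run, args=(2,), nprocs=2)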


Thank you @ptrblck. This helps a lot.