DataParallel and manually modifying parameters


I’d like to train my model on multiple GPUs, but unfortunately I’m getting a massive validation error after even 1 epoch (this doesn’t happen when I use only one GPU).

I think the reason is that I manually modify some of the model parameters after the optimization step:

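```python
import torch
import torch.nn as nn

torch.distributed.init_process_group(backend='nccl', init_method='env://')

model = nn.Sequential(...).cuda()
dmodel = nn.DataParallel(model)

for x, y in loader:  # training loop (loader, criterion, optimizer defined elsewhere)
    loss = criterion(dmodel(x), y)
    # ... loss.backward() and optimizer.step() ...
    with torch.no_grad():
        deterministic_modify(model[17])  # I need to manually modify some weights.
```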

I’m guessing there is a sync problem, because if I do this on a single GPU, things work as expected, but on multiple GPUs I get terrible validation error.

The way I understand nn.DataParallel (please correct me if I’m wrong) is that it’s a wrapper: each GPU has a copy of the model, nn.DataParallel splits the batch into two and gives each GPU half, computes the gradients, and then somehow syncs the model on both GPUs (how?).


The above is correct except for the model sync part. In the forward pass, DataParallel replicates the model to all devices (using broadcast), creates one thread per device, scatters the input across the devices (so that each thread works exclusively on one model replica with one input split), and finally gathers the outputs from all threads and uses them as the return value of the forward function.
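To make the four forward-pass steps concrete, here is a CPU-only imitation for a single linear layer. It is a sketch under stated assumptions, not how DataParallel is actually implemented: `tensor.clone()` stands in for copying parameters to another GPU, `torch.chunk` for scatter, a `ThreadPoolExecutor` for the per-device threads, and `torch.cat` for gather. The names `run_replica` and `data_parallel_forward` are hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn.functional as F


def run_replica(chunk, params):
    # Each "device" runs the same layer with its own copy of the parameters.
    weight, bias = params
    return F.linear(chunk, weight, bias)


def data_parallel_forward(x, weight, bias, n_replicas=2):
    """Hypothetical CPU imitation of one DataParallel forward pass."""
    # 1. replicate: fresh differentiable copies of the parameters, one per replica
    replicas = [(weight.clone(), bias.clone()) for _ in range(n_replicas)]
    # 2. scatter: split the batch along dim 0, one chunk per replica
    chunks = torch.chunk(x, n_replicas, dim=0)
    # 3. parallel_apply: one thread per replica, each working on its own chunk
    with ThreadPoolExecutor(max_workers=n_replicas) as pool:
        outputs = list(pool.map(run_replica, chunks, replicas))
    # 4. gather: concatenate the per-replica outputs back into one batch
    return torch.cat(outputs, dim=0)


torch.manual_seed(0)
W = torch.randn(3, 4, requires_grad=True)
b = torch.randn(3, requires_grad=True)
x = torch.randn(8, 4)

out = data_parallel_forward(x, W, b)
print(out.shape)                               # torch.Size([8, 3])
print(torch.allclose(out, F.linear(x, W, b)))  # same result as a single device

# Edit the original weight in place, as deterministic_modify would.
with torch.no_grad():
    W.mul_(0.5)
out_after = data_parallel_forward(x, W, b)     # new replicas see the edited W
```

Because DataParallel also re-replicates its wrapped module on every forward call, an in-place edit of the original module made under torch.no_grad() is picked up by the next forward pass; in this sketch that shows up as out_after matching F.linear with the halved W.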

So after the forward pass, the autograd graph contains multiple replicas of the original model, all of which point back to the same original model. In the backward pass, each model replica computes its own gradients, and because every replica has autograd edges pointing back to the original model, the gradients from the different replicas are accumulated into the same original module. So there is no synchronization across replicas; instead, all gradients from all replicas are accumulated into the original module.
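The gradient flow described above can be checked with a similar CPU stand-in (an assumption, not DataParallel itself): each replica is a `clone()` of one original weight `W`, and the backward of `clone()` passes gradients straight through, playing the role of the reduce step that sends replica gradients back to the original module. The sketch asks autograd for the gradient with respect to `W` and to each replica separately.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
W = torch.randn(3, 4, requires_grad=True)  # the single "original" parameter
x = torch.randn(8, 4)

# Forward: two replicas, each a differentiable copy of W that sees half the batch.
replicas = [W.clone() for _ in range(2)]
outs = [F.linear(chunk, r) for chunk, r in zip(x.chunk(2), replicas)]
loss = torch.cat(outs).pow(2).sum()

# Backward: gradients w.r.t. the original and w.r.t. each replica.
# Each replica's gradient reflects only its own half of the batch.
g_orig, g_rep0, g_rep1 = torch.autograd.grad(loss, [W, *replicas])

# Both replica gradients flow back along the clone edges and add up in W.
print(torch.allclose(g_orig, g_rep0 + g_rep1))

# The accumulated gradient matches a plain single-device pass on the full batch.
W_ref = W.detach().clone().requires_grad_(True)
F.linear(x, W_ref).pow(2).sum().backward()
print(torch.allclose(g_orig, W_ref.grad, atol=1e-5))
```

Since torch.autograd.grad returns gradients instead of accumulating them, W.grad stays None here; calling loss.backward() instead would accumulate the same sum into W.grad.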

I noticed that the code snippet above calls init_process_group, which is required for DistributedDataParallel but not necessary for DataParallel. DistributedDataParallel does synchronize gradients across multiple processes, and it should be faster than DataParallel.
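For comparison, a minimal DistributedDataParallel sketch of one training step followed by a manual weight edit. Assumptions: it runs as a single CPU process with the gloo backend and a temporary file:// store so it works without GPUs; real multi-GPU training would launch one process per GPU (for example with torchrun) and use backend='nccl' with init_method='env://' as in the question. `clamp_weights_` and `train_one_step` are hypothetical names, with `clamp_weights_` standing in for the poster's deterministic_modify. DDP broadcasts rank 0's parameters when the wrapper is constructed and afterwards only averages gradients, so a manual edit must be applied identically on every rank or the copies drift apart.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def clamp_weights_(layer: nn.Linear) -> None:
    # Hypothetical stand-in for deterministic_modify(): any in-place edit
    # that every rank applies in exactly the same way.
    layer.weight.clamp_(-0.1, 0.1)


def train_one_step(rank: int, world_size: int, init_method: str) -> nn.Module:
    # gloo works on CPU; one-process-per-GPU training would use backend="nccl".
    dist.init_process_group("gloo", init_method=init_method,
                            rank=rank, world_size=world_size)
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
    ddp_model = DDP(model)  # CPU module, so no device_ids
    opt = torch.optim.SGD(ddp_model.parameters(), lr=0.1)

    torch.manual_seed(rank)  # each rank draws its own batch
    x, y = torch.randn(16, 4), torch.randn(16, 1)
    loss = nn.functional.mse_loss(ddp_model(x), y)
    opt.zero_grad()
    loss.backward()  # gradients are all-reduced (averaged) across ranks here
    opt.step()

    with torch.no_grad():
        clamp_weights_(ddp_model.module[0])  # edit the wrapped module directly
    dist.destroy_process_group()
    return model


# Single-process demo; a file:// store in a fresh temp dir avoids choosing a port.
store_path = os.path.join(tempfile.mkdtemp(), "ddp_store")
trained = train_one_step(rank=0, world_size=1, init_method="file://" + store_path)
print(trained[0].weight.abs().max().item())  # bounded by the clamp to about 0.1
```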