Proper DistributedDataParallel Usage

I have a couple of questions with regard to the proper usage of DistributedDataParallel that don't seem to be covered anywhere.

The questions are below inline with the code.

import torch
import torch.distributed
import torch.nn as nn
import torch.optim as optim

def train(device, num_epochs=10):
    model = ToyModel().to(device)
    # QUESTION: Suppose each process has a different random generator state, when
    # `DistributedDataParallel` is initialized does each process need to have the same parameter
    # values?
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[device], output_device=device)
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()

    for i in range(num_epochs):
        # Training
        ddp_model.train()
        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(device)
        loss_fn(outputs, labels).backward()
        optimizer.step()

        # Evaluate on master
        if torch.distributed.get_rank() == 0:
            ddp_model.eval()
            # QUESTION: In order to evaluate, on one GPU, can we use `ddp_model.module`?
            # QUESTION: Can we use something like `EMA` to copy new parameters to `ddp_model.module`
            # and then restore them after evaluation? Learn more:
            # http://www.programmersought.com/article/28492072406/
            outputs = ddp_model.module(torch.randn(20, 10).to(device))
            labels = torch.randn(20, 5).to(device)
            print(loss_fn(outputs, labels))

        # Save checkpoint on master
        if torch.distributed.get_rank() == 0:
            # QUESTION: In order to save the model, can we use `ddp_model.module`?
            torch.save(ddp_model.module, 'checkpoint.pt')

        # QUESTION: Do we need to use `torch.distributed.barrier` so that the other processes
        # don't continue training while the master evaluates?

Thank you for the helpful tutorial https://pytorch.org/tutorials/intermediate/ddp_tutorial.html; I reused its example code for this question.

QUESTION: Suppose each process has a different random generator state, when DistributedDataParallel is initialized does each process need to have the same parameter values?

No. Rank 0 broadcasts its model states to all other ranks when you construct DDP, so every process ends up with identical parameter values regardless of its local random initialization. Code for that is here.
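
A minimal sketch of how you could verify this, assuming the process group is already initialized and `ToyModel`, `device` are defined as in the tutorial: each rank deliberately seeds its RNG differently, yet after wrapping the model in DDP all ranks hold the same parameters.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

torch.manual_seed(dist.get_rank())  # deliberately different init on every rank
model = ToyModel().to(device)
ddp_model = DDP(model, device_ids=[device], output_device=device)

# Verify: fetch rank 0's value of each parameter and compare to the local copy.
for p in ddp_model.parameters():
    p0 = p.detach().clone()
    dist.broadcast(p0, src=0)
    assert torch.equal(p0, p.detach()), "parameters differ across ranks"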

In order to evaluate, on one GPU, can we use ddp_model.module?

Yes, this should work.
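
For example (a sketch along the lines of the code in the question; `loss_fn` and `device` are as defined there), rank 0 can call the underlying module directly:

if torch.distributed.get_rank() == 0:
    ddp_model.eval()               # also switches the wrapped module to eval mode
    with torch.no_grad():          # no autograd, so no DDP gradient hooks fire
        outputs = ddp_model.module(torch.randn(20, 10).to(device))
        labels = torch.randn(20, 5).to(device)
        print(loss_fn(outputs, labels))
    ddp_model.train()              # back to training mode before the next epoch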

Can we use something like EMA to copy new parameters to ddp_model.module and then restore them after evaluation?

Yes, as long as you make sure the original parameter values are restored correctly afterward. Otherwise, if this introduces inconsistent parameter values across processes, DDP will not fix that for you, since DDP only syncs gradients, not parameters. This might be helpful to explain.
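
A sketch of that copy/restore pattern (`ema_params` here is a hypothetical dict of averaged tensors keyed by parameter name; the important part is that the original values are put back before training resumes):

module = ddp_model.module
backup = {name: p.detach().clone() for name, p in module.named_parameters()}

with torch.no_grad():
    for name, p in module.named_parameters():
        p.copy_(ema_params[name])   # swap in EMA weights for evaluation

# ... evaluate on rank 0 ...

with torch.no_grad():
    for name, p in module.named_parameters():
        p.copy_(backup[name])       # restore, so params stay consistent across ranks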

In order to save the model, can we use ddp_model.module?

Yes. And when you restore from the checkpoint, it’s better to reconstruct the DDP instance using the restored module to make sure that DDP starts from a clean state.
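
A sketch of that save/reload pattern, assuming `device` is the local GPU index (saving the module's state_dict rather than the whole module object also tends to be more robust):

# Save on rank 0 only.
if torch.distributed.get_rank() == 0:
    torch.save(ddp_model.module.state_dict(), 'checkpoint.pt')

# Later, on every rank: rebuild the module, load the weights, then wrap in a fresh DDP instance.
model = ToyModel().to(device)
model.load_state_dict(torch.load('checkpoint.pt', map_location=f'cuda:{device}'))
ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[device], output_device=device)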

Do we need to use torch.distributed.barrier so that the other processes don’t continue training while the master evaluates?

It's recommended. But if you are not consuming the checkpoint right away and are not worried about a timeout because rank 0 is doing extra work, it is not strictly necessary: the next DDP backward pass launches allreduce communication ops, which will synchronize the processes anyway. Some of this is also explained here.
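
If you do want the explicit synchronization, a sketch looks like this (`evaluate` is a hypothetical helper standing in for the evaluation code above):

if torch.distributed.get_rank() == 0:
    evaluate(ddp_model)             # hypothetical helper; see the evaluation snippet above
    torch.save(ddp_model.module.state_dict(), 'checkpoint.pt')
torch.distributed.barrier()         # non-zero ranks wait here until rank 0 finishes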
